Skip to content

Commit 679f2fc

Browse files
aporialiaofacebook-github-bot
authored andcommitted
Simplify permute indices for Sharded EBC output_dist creation (#2856)
Summary: Pull Request resolved: #2856 Simplify create_output_dist call See suggestion in next diff Reviewed By: iamzainhuda Differential Revision: D72079015 fbshipit-source-id: 37d9bada0b6742db27da69e1b62ab3073145a98e
1 parent 5b57613 commit 679f2fc

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,16 +1172,13 @@ def _create_output_dist(self) -> None:
11721172
for i, name in enumerate(self._uncombined_embedding_names):
11731173
embedding_name_order.setdefault(name, i)
11741174

1175-
def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
1176-
index, name = input
1177-
return (embedding_name_order[name], embedding_shard_offsets[index])
1178-
1179-
permute_indices = [
1180-
i
1181-
for i, _ in sorted(
1182-
enumerate(self._uncombined_embedding_names), key=sort_key
1183-
)
1184-
]
1175+
permute_indices = sorted(
1176+
range(len(self._uncombined_embedding_names)),
1177+
key=lambda i: (
1178+
embedding_name_order[self._uncombined_embedding_names[i]],
1179+
embedding_shard_offsets[i],
1180+
),
1181+
)
11851182
self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
11861183
self._uncombined_embedding_dims, permute_indices, self._device
11871184
)

0 commit comments

Comments
 (0)