Skip to content

Commit 3df4260

Browse files
Arthi Sureshfacebook-github-bot
authored andcommitted
Fix shape issue with remapped indices in sharded ManagedCollisionEmbeddingCollection (#3010)
Summary: Pull Request resolved: #3010 Fixing shapes of remapped indices returned by ManagedCollisionEmbeddingCollection to be [B, 1] instead of [B], which causes an issue in SequenceEmbeddingsAllToAll (embedding_dim=a2a_sequence_embs_tensor.shape[1]) related to the tuple index being out of range. Reviewed By: cx-yin, kausv, xing-liu Differential Revision: D75251224 Privacy Context Container: L1292699 fbshipit-source-id: 5afa863f778a74f6fbb51a9afba71986eff8c7e2
1 parent d2a3e56 commit 3df4260

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchrec/distributed/mc_modules.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,9 @@ def _kjt_list_to_tensor_list(
671671
vals.append(feature_split.values() + offset)
672672
remapped_ids_ret.append(torch.cat(vals).view(-1, 1))
673673
else:
674-
remapped_ids_ret.append(kjt.values() + self._table_to_offset[tables[0]])
674+
remapped_ids_ret.append(
675+
(kjt.values() + self._table_to_offset[tables[0]]).unsqueeze(-1)
676+
)
675677
return remapped_ids_ret
676678

677679
def global_to_local_index(

0 commit comments

Comments
 (0)