Skip to content

Commit 5b3880b

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Update the _maybe_compute_stride_kjt logic to calculate stride based off of inverse_indices for VBE KJTs.
Summary: See D73051959 for context. Update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` for VBE KJTs. Currently, stride of VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off of `inverse_indices`. This causes issues in IR module serialization: [debug doc].(https://docs.google.com/document/d/1yQhI484cgVloSqIBPAeTQhzfb3ltjvMRiLaQDceHGOU/edit?tab=t.0#heading=h.c66chahhl8df). Differential Revision: D73824764
1 parent 0330c74 commit 5b3880b

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,10 +1099,13 @@ def _maybe_compute_stride_kjt(
10991099
lengths: Optional[torch.Tensor],
11001100
offsets: Optional[torch.Tensor],
11011101
stride_per_key_per_rank: Optional[List[List[int]]],
1102+
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
11021103
) -> int:
11031104
if stride is None:
11041105
if len(keys) == 0:
11051106
stride = 0
1107+
elif inverse_indices is not None and inverse_indices[1].numel() > 0:
1108+
return inverse_indices[1].shape[1]
11061109
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
11071110
stride = max([sum(s) for s in stride_per_key_per_rank])
11081111
elif offsets is not None and offsets.numel() > 0:
@@ -2171,6 +2174,7 @@ def stride(self) -> int:
21712174
self._lengths,
21722175
self._offsets,
21732176
self._stride_per_key_per_rank,
2177+
self._inverse_indices,
21742178
)
21752179
self._stride = stride
21762180
return stride

0 commit comments

Comments
 (0)