Skip to content

Commit 932b4ef

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Update the _maybe_compute_stride_kjt logic to calculate stride based off of inverse_indices for VBE KJTs. (#2925)
Summary: See D73051959 for context. Update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` for VBE KJTs. Currently, stride of VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off of `inverse_indices`. This causes issues in IR module serialization: [debug doc](https://docs.google.com/document/d/1yQhI484cgVloSqIBPAeTQhzfb3ltjvMRiLaQDceHGOU/edit?tab=t.0#heading=h.c66chahhl8df). Differential Revision: D73824764
1 parent ced6a20 commit 932b4ef

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,10 +1099,13 @@ def _maybe_compute_stride_kjt(
10991099
lengths: Optional[torch.Tensor],
11001100
offsets: Optional[torch.Tensor],
11011101
stride_per_key_per_rank: Optional[List[List[int]]],
1102+
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
11021103
) -> int:
11031104
if stride is None:
11041105
if len(keys) == 0:
11051106
stride = 0
1107+
elif inverse_indices is not None and inverse_indices[1].numel() > 0:
1108+
return inverse_indices[1].shape[1]
11061109
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
11071110
stride = max([sum(s) for s in stride_per_key_per_rank])
11081111
elif offsets is not None and offsets.numel() > 0:
@@ -2174,6 +2177,7 @@ def stride(self) -> int:
21742177
self._lengths,
21752178
self._offsets,
21762179
self._stride_per_key_per_rank,
2180+
self._inverse_indices,
21772181
)
21782182
self._stride = stride
21792183
return stride

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,7 @@ def test_flatten_unflatten_with_vbe(self) -> None:
10331033
kjt.stride_per_key_per_rank(), unflattened_kjt.stride_per_key_per_rank()
10341034
)
10351035
self.assertEqual(kjt.inverse_indices(), unflattened_kjt.inverse_indices())
1036+
self.assertEqual(kjt.stride(), kjt.inverse_indices()[1].shape[1])
10361037

10371038

10381039
class TestKeyedJaggedTensorScripting(unittest.TestCase):

0 commit comments

Comments
 (0)