You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update the _maybe_compute_stride_kjt logic to calculate stride based off of inverse_indices for VBE KJTs. (#2925)
Summary:
See D73051959 for context.
Update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` for VBE KJTs.
Currently, stride of VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off of `inverse_indices`. This causes issues in IR module serialization: [debug doc](https://docs.google.com/document/d/1yQhI484cgVloSqIBPAeTQhzfb3ltjvMRiLaQDceHGOU/edit?tab=t.0#heading=h.c66chahhl8df).
Differential Revision: D73824764
0 commit comments