Skip to content

Commit f36d26d

Browse files
Microvefacebook-github-bot
authored andcommitted
Back out "JaggedTensor permute - less CPU ops" (#2862)
Summary: Pull Request resolved: #2862 Original commit changeset: 257a9a45b204 Original Phabricator Diff: D70609204 The original diff causes our integration tests failing: https://fburl.com/mlhub/q87huj8u Example of failed test job: f714668378 The error message is "AssertionError: inverse indices must be provided from KJT if using variable batch size per feature." Backout this diff to unblock our tests Reviewed By: dshi7 Differential Revision: D72187365 fbshipit-source-id: 58ff117c3583af7427eee2193cb510e3e0ea79dc
1 parent fa11751 commit f36d26d

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,23 +2517,22 @@ def permute(
25172517
permuted_stride_per_key_per_rank: List[List[int]] = []
25182518
permuted_length_per_key: List[int] = []
25192519
permuted_length_per_key_sum = 0
2520-
keys = self._keys
2521-
variable_stride_per_key = self.variable_stride_per_key()
2522-
stride_per_key_per_rank = self.stride_per_key_per_rank()
25232520
for index in indices:
2524-
key = keys[index]
2521+
key = self.keys()[index]
25252522
permuted_keys.append(key)
25262523
permuted_length_per_key.append(length_per_key[index])
2527-
if variable_stride_per_key:
2528-
permuted_stride_per_key_per_rank.append(stride_per_key_per_rank[index])
2524+
if self.variable_stride_per_key():
2525+
permuted_stride_per_key_per_rank.append(
2526+
self.stride_per_key_per_rank()[index]
2527+
)
25292528

25302529
permuted_length_per_key_sum = sum(permuted_length_per_key)
25312530
if not torch.jit.is_scripting() and is_non_strict_exporting():
25322531
torch._check_is_size(permuted_length_per_key_sum)
25332532
torch._check(permuted_length_per_key_sum != -1)
25342533
torch._check(permuted_length_per_key_sum != 0)
25352534

2536-
if variable_stride_per_key:
2535+
if self.variable_stride_per_key():
25372536
length_per_key_tensor = _pin_and_move(
25382537
torch.tensor(self.length_per_key()), self.device()
25392538
)
@@ -2578,7 +2577,7 @@ def permute(
25782577
permuted_length_per_key_sum,
25792578
)
25802579
stride_per_key_per_rank = (
2581-
permuted_stride_per_key_per_rank if variable_stride_per_key else None
2580+
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
25822581
)
25832582
kjt = KeyedJaggedTensor(
25842583
keys=permuted_keys,

0 commit comments

Comments
 (0)