Commit 22078ab

TroyGarden authored and facebook-github-bot committed
simplify the KJT.split function when segment is the original KJT (#3014)
Summary:
Pull Request resolved: #3014

# context
* in the KJT.split function, when segment == len(keys), the returned KJT contains the same data as the original KJT
* however, the function builds a new KJT anyway, which adds unnecessary cost
* this diff removes the redundant KJT creation

# analysis
* when segment == len(keys), start has to be zero, so stride_per_key_per_rank is the same as the original one.
* the following KJT init produces the same KJT as self
```
KeyedJaggedTensor(
    keys=self._keys,
    values=self._values,
    weights=self.weights_or_none(),
    lengths=self._lengths,
    offsets=self._offsets,
    stride=self._stride,
    stride_per_key_per_rank=stride_per_key_per_rank,
    stride_per_key=None,
    length_per_key=self._length_per_key,
    lengths_offset_per_key=None,
    offset_per_key=self._offset_per_key,
    index_per_key=self._index_per_key,
    jt_dict=self._jt_dict,
    inverse_indices=None,
)
```

Reviewed By: iamzainhuda

Differential Revision: D70756397

fbshipit-source-id: e299c67792d4219a2220ef52335dcd9fc4cd100d
1 parent 3df4260 commit 22078ab

1 file changed: +1 -18 lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 18 deletions
@@ -2354,24 +2354,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
             )
             if segment == len(self._keys):
                 # no torch slicing required
-                split_list.append(
-                    KeyedJaggedTensor(
-                        keys=self._keys,
-                        values=self._values,
-                        weights=self.weights_or_none(),
-                        lengths=self._lengths,
-                        offsets=self._offsets,
-                        stride=self._stride,
-                        stride_per_key_per_rank=stride_per_key_per_rank,
-                        stride_per_key=None,
-                        length_per_key=self._length_per_key,
-                        lengths_offset_per_key=None,
-                        offset_per_key=self._offset_per_key,
-                        index_per_key=self._index_per_key,
-                        jt_dict=self._jt_dict,
-                        inverse_indices=None,
-                    )
-                )
+                split_list.append(self)
             elif segment == 0:
                 empty_int_list: List[int] = torch.jit.annotate(List[int], [])
                 split_list.append(
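
The fast path in this hunk can be illustrated with a minimal sketch. It uses a toy container, not torchrec: `ToyKeyedLists`, its fields, and the checks at the top of `split` are hypothetical stand-ins introduced only for illustration, and the real `KeyedJaggedTensor` carries far more state (lengths, offsets, strides, caches). The sketch assumes segments are non-negative and sum to at most `len(keys)`, which is what makes `start == 0` whenever `segment == len(keys)`.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyKeyedLists:
    """Hypothetical, much-simplified stand-in for a keyed jagged container."""

    keys: List[str]
    values: Dict[str, List[int]]

    def split(self, segments: List[int]) -> List["ToyKeyedLists"]:
        # Assumptions of this sketch, not checks taken from torchrec.
        assert all(s >= 0 for s in segments), "segments must be non-negative"
        assert sum(segments) <= len(self.keys), "segments exceed number of keys"
        out: List["ToyKeyedLists"] = []
        start = 0
        for segment in segments:
            end = start + segment
            if segment == len(self.keys):
                # The segment spans every key, so start must be 0 and the
                # slice would equal the whole container: reuse self.
                out.append(self)
            elif segment == 0:
                # Empty segment: an empty container.
                out.append(ToyKeyedLists(keys=[], values={}))
            else:
                # General case: build a new container for keys[start:end].
                part = self.keys[start:end]
                out.append(ToyKeyedLists(part, {k: self.values[k] for k in part}))
            start = end
        return out


toy = ToyKeyedLists(["a", "b", "c"], {"a": [1], "b": [2, 3], "c": []})
whole = toy.split([3])          # fast path: one segment covering all keys
parts = toy.split([1, 2])       # general path: two new containers
with_empty = toy.split([0, 3])  # empty segment followed by the fast path
```

In this toy, `whole[0] is toy` holds because the fast path returns the original object instead of an equal copy. A caller that mutates a split result in place would therefore also change the original, which may be worth keeping in mind when relying on a shortcut of this kind.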