Skip to content

Commit 1a37915

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
jagged_tensor minor refactoring (#2860)
Summary: Pull Request resolved: #2860 # context * split test_jagged_tensor into multiple files to increase readability * use hypothesis.strategies to generate the test cases (the standard approach) * add some comments in jagged_tensor.py * use `torch.diff` to replace `_to_lengths` function. Reviewed By: ezyang, aporialiao Differential Revision: D56177133 fbshipit-source-id: 245536e7a1d10eee132a4e87b5813840249b793b
1 parent 679f2fc commit 1a37915

File tree

5 files changed

+2300
-2241
lines changed

5 files changed

+2300
-2241
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@
5252

5353

5454
def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
55+
"""
56+
moving a tensor from cpu to cuda using pinned memory (non_blocking) is generally faster
57+
"""
5558
if is_torchdynamo_compiling():
5659
# TODO: remove once FakeTensor supports pin_memory() and to(..., non_blocking=True)
5760
return tensor.to(device=device)
@@ -64,6 +67,9 @@ def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
6467

6568

6669
def _cumsum(o: List[int]) -> List[int]:
70+
"""
71+
python-list version of converting lengths --> offsets
72+
"""
6773
ret = [0] * (len(o) + 1)
6874
for i in range(len(o)):
6975
ret[i + 1] = ret[i] + o[i]
@@ -92,7 +98,7 @@ def _maybe_compute_lengths(
9298
) -> torch.Tensor:
9399
if lengths is None:
94100
assert offsets is not None
95-
lengths = _to_lengths(offsets)
101+
lengths = torch.diff(offsets)
96102
return lengths
97103

98104

0 commit comments

Comments
 (0)