Skip to content

Commit 03c041a

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Update docstrings for stride_per_key_per_rank (#3120)
Summary: Pull Request resolved: #3120 Created from CodeHub with https://fburl.com/edit-in-codehub Reviewed By: TroyGarden Differential Revision: D76999897 fbshipit-source-id: ace2321cb2653ecb2e91ddbaa115d0b54b1882a0
1 parent f659b6a commit 03c041a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,9 +1679,9 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
16791679
offsets (Optional[torch.Tensor]): jagged slices, represented as cumulative
16801680
offsets.
16811681
stride (Optional[int]): number of examples per batch.
1682-
stride_per_key_per_rank (Optional[List[List[int]]]): batch size
1683-
(number of examples) per key per rank, with the outer list representing the
1684-
keys and the inner list representing the values.
1682+
stride_per_key_per_rank (Optional[Union[torch.IntTensor, List[List[int]]]]):
1683+
batch size (number of examples) per key per rank, with the outer list
1684+
representing the keys and the inner list representing the values.
16851685
Each value in the inner list represents the number of examples in the batch
16861686
from the rank of its index in a distributed context.
16871687
length_per_key (Optional[List[int]]): start length for each key.

0 commit comments

Comments
 (0)