Skip to content

Commit 12eb3bf

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
add stride into KJT pytree (#2587)
Summary: Pull Request resolved: #2587 # context * Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly. ``` _fields = [ "_values", "_weights", "_lengths", "_offsets", ] ``` * Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`: ``` def _maybe_compute_stride_kjt( keys: List[str], stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[List[List[int]]], ) -> int: if stride is None: if len(keys) == 0: stride = 0 elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: stride = max([sum(s) for s in stride_per_key_per_rank]) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: stride = lengths.numel() // len(keys) else: stride = 0 return stride ``` * The previously stored pytree flatten specs are enough if the `batch_size` is static, however, this no longer holds true in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`. * An example is that with `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static). * During ir_export, the output shape will be calculated from `kjt.stride()` function, which would be incorrect if the pytree specs only contains the `keys`. * This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value. Reviewed By: PaulZhang12 Differential Revision: D66400821 fbshipit-source-id: 741a54c6ad5bae1646fef45b11cf62823b3ecc27
1 parent f126ded commit 12eb3bf

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

.github/scripts/install_libs.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ elif [ "$CHANNEL" = "test" ]; then
2828
fi
2929

3030

31-
${CONDA_RUN} pip install importlib-metadata
31+
${CONDA_RUN} pip install importlib-metadata click PyYAML

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
black
2+
click
23
cmake
34
fbgemm-gpu
45
hypothesis==6.70.1
6+
importlib-metadata
57
iopath
68
numpy
79
pandas
@@ -13,6 +15,7 @@ torchx
1315
tqdm
1416
usort
1517
parameterized
18+
PyYAML
1619

1720
# for tests
1821
# https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3

0 commit comments

Comments
 (0)