Skip to content

Commit 3e8de05

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
add sharding_type argument to pipeline benchmark (#2495)
Summary: Pull Request resolved: #2495 # context * add sharding_type argument to the pipeline benchmark * better control of different sharding types Reviewed By: iamzainhuda Differential Revision: D64676132 fbshipit-source-id: 8ceeba667b6b2f8aded3dcf3d894cf8cbca31d3e
1 parent dbca437 commit 3e8de05

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,8 @@ def get_sharded_optim_state(
476476
momentum_local_shards: List[Shard] = []
477477
optimizer_sharded_tensor_metadata: ShardedTensorMetadata
478478

479-
optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16]
479+
# pyre-ignore [16]
480+
optim_state = shard_params.optimizer_states[0][momentum_idx - 1]
480481
if (
481482
optim_state.nelement() == 1 and state_key != "momentum1"
482483
): # special handling for backward compatibility, momentum1 is rowwise state for rowwise_adagrad

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def _gen_pipelines(
9393
default=8192,
9494
help="Batch size.",
9595
)
96+
@click.option(
97+
"--sharding_type",
98+
type=ShardingType,
99+
default=ShardingType.TABLE_WISE,
100+
help="ShardingType.",
101+
)
96102
@click.option(
97103
"--pooling_factor",
98104
type=int,
@@ -129,6 +135,7 @@ def main(
129135
dim_emb: int,
130136
n_batches: int,
131137
batch_size: int,
138+
sharding_type: ShardingType,
132139
pooling_factor: int,
133140
input_type: str,
134141
pipeline: str,
@@ -178,7 +185,7 @@ def main(
178185
callable=runner,
179186
tables=tables,
180187
weighted_tables=weighted_tables,
181-
sharding_type=ShardingType.TABLE_WISE.value,
188+
sharding_type=sharding_type.value,
182189
kernel_type=EmbeddingComputeKernel.FUSED.value,
183190
batches=batches,
184191
fused_params={},
@@ -190,7 +197,7 @@ def main(
190197
single_runner(
191198
tables=tables,
192199
weighted_tables=weighted_tables,
193-
sharding_type=ShardingType.TABLE_WISE.value,
200+
sharding_type=sharding_type.value,
194201
kernel_type=EmbeddingComputeKernel.FUSED.value,
195202
batches=batches,
196203
fused_params={},

0 commit comments

Comments
 (0)