File tree Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -476,7 +476,8 @@ def get_sharded_optim_state(
476
476
momentum_local_shards : List [Shard ] = []
477
477
optimizer_sharded_tensor_metadata : ShardedTensorMetadata
478
478
479
- optim_state = shard_params .optimizer_states [0 ][momentum_idx - 1 ] # pyre-ignore[16]
479
+ # pyre-ignore [16]
480
+ optim_state = shard_params .optimizer_states [0 ][momentum_idx - 1 ]
480
481
if (
481
482
optim_state .nelement () == 1 and state_key != "momentum1"
482
483
): # special handling for backward compatibility, momentum1 is rowwise state for rowwise_adagrad
Original file line number Diff line number Diff line change @@ -93,6 +93,12 @@ def _gen_pipelines(
93
93
default = 8192 ,
94
94
help = "Batch size." ,
95
95
)
96
+ @click .option (
97
+ "--sharding_type" ,
98
+ type = ShardingType ,
99
+ default = ShardingType .TABLE_WISE ,
100
+ help = "ShardingType." ,
101
+ )
96
102
@click .option (
97
103
"--pooling_factor" ,
98
104
type = int ,
@@ -129,6 +135,7 @@ def main(
129
135
dim_emb : int ,
130
136
n_batches : int ,
131
137
batch_size : int ,
138
+ sharding_type : ShardingType ,
132
139
pooling_factor : int ,
133
140
input_type : str ,
134
141
pipeline : str ,
@@ -178,7 +185,7 @@ def main(
178
185
callable = runner ,
179
186
tables = tables ,
180
187
weighted_tables = weighted_tables ,
181
- sharding_type = ShardingType . TABLE_WISE .value ,
188
+ sharding_type = sharding_type .value ,
182
189
kernel_type = EmbeddingComputeKernel .FUSED .value ,
183
190
batches = batches ,
184
191
fused_params = {},
@@ -190,7 +197,7 @@ def main(
190
197
single_runner (
191
198
tables = tables ,
192
199
weighted_tables = weighted_tables ,
193
- sharding_type = ShardingType . TABLE_WISE .value ,
200
+ sharding_type = sharding_type .value ,
194
201
kernel_type = EmbeddingComputeKernel .FUSED .value ,
195
202
batches = batches ,
196
203
fused_params = {},
You can’t perform that action at this time.
0 commit comments