@@ -211,42 +211,6 @@ class ShardParams:
             local_metadata: List[ShardMetadata]
             embedding_weights: List[torch.Tensor]

-        def get_optimizer_single_value_shard_metadata_and_global_metadata(
-            table_global_metadata: ShardedTensorMetadata,
-            optimizer_state: torch.Tensor,
-        ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
-            table_global_shards_metadata: List[ShardMetadata] = (
-                table_global_metadata.shards_metadata
-            )
-
-            table_shard_metadata_to_optimizer_shard_metadata = {}
-            for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
-                table_shard_metadata_to_optimizer_shard_metadata[
-                    table_shard_metadata
-                ] = ShardMetadata(
-                    shard_sizes=[1],  # single value optimizer state
-                    shard_offsets=[offset],  # offset increases by 1 for each shard
-                    placement=table_shard_metadata.placement,
-                )
-
-            tensor_properties = TensorProperties(
-                dtype=optimizer_state.dtype,
-                layout=optimizer_state.layout,
-                requires_grad=False,
-            )
-            single_value_optimizer_st_metadata = ShardedTensorMetadata(
-                shards_metadata=list(
-                    table_shard_metadata_to_optimizer_shard_metadata.values()
-                ),
-                size=torch.Size([len(table_global_shards_metadata)]),
-                tensor_properties=tensor_properties,
-            )
-
-            return (
-                table_shard_metadata_to_optimizer_shard_metadata,
-                single_value_optimizer_st_metadata,
-            )
-
         def get_optimizer_rowwise_shard_metadata_and_global_metadata(
             table_global_metadata: ShardedTensorMetadata,
             optimizer_state: torch.Tensor,
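For orientation, the surrounding file keys its metadata helpers off the shape of each optimizer state tensor. Below is a minimal sketch with plain `torch` tensors (shard sizes invented for illustration) of the three layouts involved: the per-table single-value layout whose helper is deleted above, and the rowwise and pointwise layouts that remain.

```python
import torch

# Hypothetical embedding table shard: 1000 rows x 64 dims (made-up sizes).
rows, dims = 1000, 64
param = torch.empty(rows, dims)

# Pointwise state (e.g. Adam's exp_avg): same shape as the parameter.
pointwise_state = torch.zeros_like(param)
assert pointwise_state.dim() == 2

# Rowwise state (e.g. rowwise Adagrad's momentum): one value per row.
rowwise_state = torch.zeros(rows)
assert rowwise_state.dim() == 1 and rowwise_state.size(0) == param.size(0)

# Single-value state: one scalar per table. This is the layout whose
# metadata handling the hunk above removes.
single_value_state = torch.zeros(1)
assert single_value_state.nelement() == 1
```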
@@ -392,10 +356,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             if optimizer_states:
                 optimizer_state_values = tuple(optimizer_states.values())
                 for optimizer_state_value in optimizer_state_values:
-                    assert (
-                        table_config.local_rows == optimizer_state_value.size(0)
-                        or optimizer_state_value.nelement() == 1  # single value state
-                    )
+                    assert table_config.local_rows == optimizer_state_value.size(0)
                 optimizer_states_keys_by_table[table_config.name] = list(
                     optimizer_states.keys()
                 )
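With the single-value escape hatch gone, the assertion now requires every state tensor to be row-aligned with its shard. A small illustrative check of that invariant, using stand-in values for `table_config.local_rows` and the fused optimizer's state dict (names mirror the diff but are hypothetical here):

```python
import torch

local_rows = 1000  # stand-in for table_config.local_rows
optimizer_states = {
    "momentum1": torch.zeros(local_rows),       # rowwise: passes
    "exp_avg_sq": torch.zeros(local_rows, 64),  # pointwise: passes
}

for state in optimizer_states.values():
    # The tightened invariant: every state is row-aligned with the shard.
    # A per-table scalar state (torch.zeros(1)) would now fail here.
    assert local_rows == state.size(0)
```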
@@ -474,35 +435,29 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
             momentum_local_shards: List[Shard] = []
             optimizer_sharded_tensor_metadata: ShardedTensorMetadata

-            optim_state = shard_params.optimizer_states[0][momentum_idx - 1]  # pyre-ignore[16]
-            if optim_state.nelement() == 1:
-                # single value state: one value per table
-                (
-                    table_shard_metadata_to_optimizer_shard_metadata,
-                    optimizer_sharded_tensor_metadata,
-                ) = get_optimizer_single_value_shard_metadata_and_global_metadata(
-                    table_config.global_metadata,
-                    optim_state,
-                )
-            elif optim_state.dim() == 1:
-                # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
+            is_rowwise_optimizer_state: bool = (
+                # pyre-ignore
+                shard_params.optimizer_states[0][momentum_idx - 1].dim()
+                == 1
+            )
+
+            if is_rowwise_optimizer_state:
                 (
                     table_shard_metadata_to_optimizer_shard_metadata,
                     optimizer_sharded_tensor_metadata,
                 ) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
                     table_config.global_metadata,
-                    optim_state,
+                    shard_params.optimizer_states[0][momentum_idx - 1],
                     sharding_dim,
                     is_grid_sharded,
                 )
             else:
-                # pointwise state: param.shape == state.shape
                 (
                     table_shard_metadata_to_optimizer_shard_metadata,
                     optimizer_sharded_tensor_metadata,
                 ) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
                     table_config.global_metadata,
-                    optim_state,
+                    shard_params.optimizer_states[0][momentum_idx - 1],
                 )

             for optimizer_state, table_shard_local_metadata in zip(
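After this hunk, the dispatch reduces to a single shape test. A hypothetical `describe_state` helper (not part of the diff) sketching the rule:

```python
import torch

def describe_state(state: torch.Tensor) -> str:
    # Mirrors the two-way dispatch above: after the revert, a state is
    # either rowwise (1-D, one value per row) or pointwise (parameter-shaped).
    return "rowwise" if state.dim() == 1 else "pointwise"

print(describe_state(torch.zeros(1000)))      # rowwise
print(describe_state(torch.zeros(1000, 64)))  # pointwise
print(describe_state(torch.zeros(1)))         # rowwise, no single-value path
```

Note that a 1-element 1-D tensor now falls into the rowwise path, which is why the tightened assertion in the earlier hunk matters.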