@@ -211,6 +211,42 @@ class ShardParams:
             local_metadata: List[ShardMetadata]
             embedding_weights: List[torch.Tensor]

+        def get_optimizer_single_value_shard_metadata_and_global_metadata(
+            table_global_metadata: ShardedTensorMetadata,
+            optimizer_state: torch.Tensor,
+        ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
+            table_global_shards_metadata: List[ShardMetadata] = (
+                table_global_metadata.shards_metadata
+            )
+
+            table_shard_metadata_to_optimizer_shard_metadata = {}
+            for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
+                table_shard_metadata_to_optimizer_shard_metadata[
+                    table_shard_metadata
+                ] = ShardMetadata(
+                    shard_sizes=[1],  # single value optimizer state
+                    shard_offsets=[offset],  # offset increases by 1 for each shard
+                    placement=table_shard_metadata.placement,
+                )
+
+            tensor_properties = TensorProperties(
+                dtype=optimizer_state.dtype,
+                layout=optimizer_state.layout,
+                requires_grad=False,
+            )
+            single_value_optimizer_st_metadata = ShardedTensorMetadata(
+                shards_metadata=list(
+                    table_shard_metadata_to_optimizer_shard_metadata.values()
+                ),
+                size=torch.Size([len(table_global_shards_metadata)]),
+                tensor_properties=tensor_properties,
+            )
+
+            return (
+                table_shard_metadata_to_optimizer_shard_metadata,
+                single_value_optimizer_st_metadata,
+            )
+
         def get_optimizer_rowwise_shard_metadata_and_global_metadata(
             table_global_metadata: ShardedTensorMetadata,
             optimizer_state: torch.Tensor,
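For context, a minimal sketch of what the new helper produces for a table split into two row-wise shards. The concrete inputs below (shard sizes, placements, dtype) are made-up example values; ShardMetadata, ShardedTensorMetadata, and TensorProperties are the same torch.distributed._shard types the diff uses.

import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor.metadata import (
    ShardedTensorMetadata,
    TensorProperties,
)

# A hypothetical 8x16 table split into two row-wise shards on two ranks.
table_shards = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 16], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[4, 0], shard_sizes=[4, 16], placement="rank:1/cuda:1"),
]

# The helper maps each table shard to a length-1 shard of the optimizer
# state, offset by its index and co-located with the table shard.
opt_shards = [
    ShardMetadata(shard_offsets=[i], shard_sizes=[1], placement=s.placement)
    for i, s in enumerate(table_shards)
]

# The global view is a 1-D tensor with one element per table shard.
global_meta = ShardedTensorMetadata(
    shards_metadata=opt_shards,
    size=torch.Size([len(table_shards)]),
    tensor_properties=TensorProperties(
        dtype=torch.float32, layout=torch.strided, requires_grad=False
    ),
)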
@@ -356,7 +392,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             if optimizer_states:
                 optimizer_state_values = tuple(optimizer_states.values())
                 for optimizer_state_value in optimizer_state_values:
-                    assert table_config.local_rows == optimizer_state_value.size(0)
+                    assert (
+                        table_config.local_rows == optimizer_state_value.size(0)
+                        or optimizer_state_value.nelement() == 1  # single value state
+                    )
                 optimizer_states_keys_by_table[table_config.name] = list(
                     optimizer_states.keys()
                 )
@@ -430,34 +469,44 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
                 opt_state is not None for opt_state in shard_params.optimizer_states
             ):
                 # pyre-ignore
-                def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
+                def get_sharded_optim_state(
+                    momentum_idx: int, state_key: str
+                ) -> ShardedTensor:
                     assert momentum_idx > 0
                     momentum_local_shards: List[Shard] = []
                     optimizer_sharded_tensor_metadata: ShardedTensorMetadata

-                    is_rowwise_optimizer_state: bool = (
-                        # pyre-ignore
-                        shard_params.optimizer_states[0][momentum_idx - 1].dim()
-                        == 1
-                    )
-
-                    if is_rowwise_optimizer_state:
+                    optim_state = shard_params.optimizer_states[0][momentum_idx - 1]  # pyre-ignore[16]
+                    if (
+                        optim_state.nelement() == 1 and state_key != "momentum1"
+                    ):  # special handling for backward compatibility, momentum1 is rowwise state for rowwise_adagrad
+                        # single value state: one value per table
+                        (
+                            table_shard_metadata_to_optimizer_shard_metadata,
+                            optimizer_sharded_tensor_metadata,
+                        ) = get_optimizer_single_value_shard_metadata_and_global_metadata(
+                            table_config.global_metadata,
+                            optim_state,
+                        )
+                    elif optim_state.dim() == 1:
+                        # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
                         (
                             table_shard_metadata_to_optimizer_shard_metadata,
                             optimizer_sharded_tensor_metadata,
                         ) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
                             table_config.global_metadata,
-                            shard_params.optimizer_states[0][momentum_idx - 1],
+                            optim_state,
                             sharding_dim,
                             is_grid_sharded,
                         )
                     else:
+                        # pointwise state: param.shape == state.shape
                         (
                             table_shard_metadata_to_optimizer_shard_metadata,
                             optimizer_sharded_tensor_metadata,
                         ) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
                             table_config.global_metadata,
-                            shard_params.optimizer_states[0][momentum_idx - 1],
+                            optim_state,
                         )

                     for optimizer_state, table_shard_local_metadata in zip(
@@ -499,7 +548,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
                         cur_state_key = optimizer_state_keys[cur_state_idx]

                     state[weight][f"{table_config.name}.{cur_state_key}"] = (
-                        get_sharded_optim_state(cur_state_idx + 1)
+                        get_sharded_optim_state(cur_state_idx + 1, cur_state_key)
                     )

         super().__init__(params, state, [param_group])
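Taken together, the patched get_sharded_optim_state now distinguishes three state layouts before building sharded-tensor metadata. A standalone restatement of that classification rule, not part of the PR (the state keys "step" and "momentum2" below are chosen purely for illustration):

import torch

# "momentum1" is exempt from the single-value branch because for
# rowwise_adagrad it holds rowwise state, which can also have
# nelement() == 1, e.g. when a shard contains a single row.
def classify_optimizer_state(state: torch.Tensor, state_key: str) -> str:
    if state.nelement() == 1 and state_key != "momentum1":
        return "single_value"  # one scalar per table, sharded as length-1 pieces
    if state.dim() == 1:
        return "rowwise"  # one value per embedding row
    return "pointwise"  # same shape as the embedding weights

assert classify_optimizer_state(torch.zeros(1), "step") == "single_value"
assert classify_optimizer_state(torch.zeros(1), "momentum1") == "rowwise"
assert classify_optimizer_state(torch.zeros(128, 16), "momentum2") == "pointwise"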