
Commit d1a2990

Wang Zhou authored and facebook-github-bot committed
TREC handling iter singular value optimizer state (#2491)
Summary:
Pull Request resolved: #2491

Reland of D63909559, with special handling of `momentum1` for twrw sharding.

Reviewed By: dracifer, iamzainhuda

Differential Revision: D64406941

fbshipit-source-id: 044c671f64f488120b4b54dfcd5e401302b45e23
1 parent 876222e · commit d1a2990

1 file changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 61 additions & 12 deletions
@@ -211,6 +211,42 @@ class ShardParams:
     local_metadata: List[ShardMetadata]
     embedding_weights: List[torch.Tensor]
 
+def get_optimizer_single_value_shard_metadata_and_global_metadata(
+    table_global_metadata: ShardedTensorMetadata,
+    optimizer_state: torch.Tensor,
+) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
+    table_global_shards_metadata: List[ShardMetadata] = (
+        table_global_metadata.shards_metadata
+    )
+
+    table_shard_metadata_to_optimizer_shard_metadata = {}
+    for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
+        table_shard_metadata_to_optimizer_shard_metadata[
+            table_shard_metadata
+        ] = ShardMetadata(
+            shard_sizes=[1],  # single value optimizer state
+            shard_offsets=[offset],  # offset increases by 1 for each shard
+            placement=table_shard_metadata.placement,
+        )
+
+    tensor_properties = TensorProperties(
+        dtype=optimizer_state.dtype,
+        layout=optimizer_state.layout,
+        requires_grad=False,
+    )
+    single_value_optimizer_st_metadata = ShardedTensorMetadata(
+        shards_metadata=list(
+            table_shard_metadata_to_optimizer_shard_metadata.values()
+        ),
+        size=torch.Size([len(table_global_shards_metadata)]),
+        tensor_properties=tensor_properties,
+    )
+
+    return (
+        table_shard_metadata_to_optimizer_shard_metadata,
+        single_value_optimizer_st_metadata,
+    )
+
 def get_optimizer_rowwise_shard_metadata_and_global_metadata(
     table_global_metadata: ShardedTensorMetadata,
     optimizer_state: torch.Tensor,
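
Editor's note: the new helper maps each shard of a table to a one-element optimizer-state shard whose offset is simply the shard's index, so a per-table scalar state (e.g. an iteration counter) checkpoints as a ShardedTensor with one value per table shard. Below is a minimal standalone sketch of that mapping; it is not part of the commit, the table shapes and placements are invented, and only torch metadata types are used.

import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor.metadata import (
    ShardedTensorMetadata,
    TensorProperties,
)

# Hypothetical (200, 16) table split row-wise across two ranks.
table_shards = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[100, 16], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[100, 0], shard_sizes=[100, 16], placement="rank:1/cuda:1"),
]
optimizer_state = torch.zeros(1)  # single-value state, e.g. an iteration counter

# Mirrors the helper's logic: every table shard gets a 1-element optimizer
# shard, offset by its position in the shard list.
mapping = {
    shard: ShardMetadata(
        shard_sizes=[1],
        shard_offsets=[i],
        placement=shard.placement,
    )
    for i, shard in enumerate(table_shards)
}
global_metadata = ShardedTensorMetadata(
    shards_metadata=list(mapping.values()),
    size=torch.Size([len(table_shards)]),  # one value per table shard
    tensor_properties=TensorProperties(
        dtype=optimizer_state.dtype,
        layout=optimizer_state.layout,
        requires_grad=False,
    ),
)
print(global_metadata.size)  # torch.Size([2])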
@@ -356,7 +392,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             if optimizer_states:
                 optimizer_state_values = tuple(optimizer_states.values())
                 for optimizer_state_value in optimizer_state_values:
-                    assert table_config.local_rows == optimizer_state_value.size(0)
+                    assert (
+                        table_config.local_rows == optimizer_state_value.size(0)
+                        or optimizer_state_value.nelement() == 1  # single value state
+                    )
                 optimizer_states_keys_by_table[table_config.name] = list(
                     optimizer_states.keys()
                 )
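
Editor's note: the loosened assert now admits two legal state shapes per table shard, which the small standalone check below illustrates; the shapes are made up and state_shape_ok is not a torchrec function, just a restatement of the predicate.

import torch

def state_shape_ok(local_rows: int, state: torch.Tensor) -> bool:
    # Mirrors the relaxed assert: per-row state, or a single value per table.
    return local_rows == state.size(0) or state.nelement() == 1

local_rows = 100  # hypothetical shard
assert state_shape_ok(local_rows, torch.zeros(100))      # rowwise state
assert state_shape_ok(local_rows, torch.zeros(100, 16))  # pointwise state
assert state_shape_ok(local_rows, torch.zeros(1))        # single-value state (new)
assert not state_shape_ok(local_rows, torch.zeros(50))   # shape mismatch still fails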
@@ -430,34 +469,44 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
                 opt_state is not None for opt_state in shard_params.optimizer_states
             ):
                 # pyre-ignore
-                def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
+                def get_sharded_optim_state(
+                    momentum_idx: int, state_key: str
+                ) -> ShardedTensor:
                     assert momentum_idx > 0
                     momentum_local_shards: List[Shard] = []
                     optimizer_sharded_tensor_metadata: ShardedTensorMetadata
 
-                    is_rowwise_optimizer_state: bool = (
-                        # pyre-ignore
-                        shard_params.optimizer_states[0][momentum_idx - 1].dim()
-                        == 1
-                    )
-
-                    if is_rowwise_optimizer_state:
+                    optim_state = shard_params.optimizer_states[0][momentum_idx - 1]  # pyre-ignore[16]
+                    if (
+                        optim_state.nelement() == 1 and state_key != "momentum1"
+                    ):  # special handling for backward compatibility, momentum1 is rowwise state for rowwise_adagrad
+                        # single value state: one value per table
+                        (
+                            table_shard_metadata_to_optimizer_shard_metadata,
+                            optimizer_sharded_tensor_metadata,
+                        ) = get_optimizer_single_value_shard_metadata_and_global_metadata(
+                            table_config.global_metadata,
+                            optim_state,
+                        )
+                    elif optim_state.dim() == 1:
+                        # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
                         (
                             table_shard_metadata_to_optimizer_shard_metadata,
                             optimizer_sharded_tensor_metadata,
                         ) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
                             table_config.global_metadata,
-                            shard_params.optimizer_states[0][momentum_idx - 1],
+                            optim_state,
                             sharding_dim,
                             is_grid_sharded,
                         )
                     else:
+                        # pointwise state: param.shape == state.shape
                         (
                             table_shard_metadata_to_optimizer_shard_metadata,
                             optimizer_sharded_tensor_metadata,
                         ) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
                             table_config.global_metadata,
-                            shard_params.optimizer_states[0][momentum_idx - 1],
+                            optim_state,
                         )
 
                     for optimizer_state, table_shard_local_metadata in zip(
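
Editor's note: the dispatch above is the core of the reland. Single-value states get the new per-table metadata, except `momentum1`: for rowwise_adagrad it has always been checkpointed as a rowwise state, and a 1-element tensor may just be a one-row shard, so it must keep the rowwise layout for checkpoint compatibility. A standalone sketch of the classification follows; the function name and return labels are illustrative only, not torchrec API.

import torch

def classify_optim_state(state: torch.Tensor, state_key: str) -> str:
    # Single value per table, unless the key is momentum1, which keeps its
    # historical rowwise layout even when the shard holds one element.
    if state.nelement() == 1 and state_key != "momentum1":
        return "single_value"
    # One value per row of the parameter shard.
    if state.dim() == 1:
        return "rowwise"
    # Same shape as the parameter shard.
    return "pointwise"

print(classify_optim_state(torch.zeros(1), "iter"))             # single_value
print(classify_optim_state(torch.zeros(1), "momentum1"))        # rowwise (carve-out)
print(classify_optim_state(torch.zeros(100), "momentum1"))      # rowwise
print(classify_optim_state(torch.zeros(100, 16), "momentum2"))  # pointwise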
@@ -499,7 +548,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
                         cur_state_key = optimizer_state_keys[cur_state_idx]
 
                         state[weight][f"{table_config.name}.{cur_state_key}"] = (
-                            get_sharded_optim_state(cur_state_idx + 1)
+                            get_sharded_optim_state(cur_state_idx + 1, cur_state_key)
                         )
 
         super().__init__(params, state, [param_group])
