Skip to content

Commit afd5726

Browse files
Wang Zhou authored; facebook-github-bot committed
Back out "Add iter singular value into TBE optimizer state" (#2487)
Summary: Pull Request resolved: #2487. Backs out the torchrec changes from D63909559 to unblock MVAI. Reviewed By: dragonxlwang. Differential Revision: D64406709. fbshipit-source-id: 59ad4e25c567ba55d08447283d5bb7ee14f564d2
1 parent b6e784e commit afd5726

File tree

1 file changed

+10
-55
lines changed

1 file changed

+10
-55
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 10 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -211,42 +211,6 @@ class ShardParams:
211211
local_metadata: List[ShardMetadata]
212212
embedding_weights: List[torch.Tensor]
213213

214-
def get_optimizer_single_value_shard_metadata_and_global_metadata(
215-
table_global_metadata: ShardedTensorMetadata,
216-
optimizer_state: torch.Tensor,
217-
) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
218-
table_global_shards_metadata: List[ShardMetadata] = (
219-
table_global_metadata.shards_metadata
220-
)
221-
222-
table_shard_metadata_to_optimizer_shard_metadata = {}
223-
for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
224-
table_shard_metadata_to_optimizer_shard_metadata[
225-
table_shard_metadata
226-
] = ShardMetadata(
227-
shard_sizes=[1], # single value optimizer state
228-
shard_offsets=[offset], # offset increases by 1 for each shard
229-
placement=table_shard_metadata.placement,
230-
)
231-
232-
tensor_properties = TensorProperties(
233-
dtype=optimizer_state.dtype,
234-
layout=optimizer_state.layout,
235-
requires_grad=False,
236-
)
237-
single_value_optimizer_st_metadata = ShardedTensorMetadata(
238-
shards_metadata=list(
239-
table_shard_metadata_to_optimizer_shard_metadata.values()
240-
),
241-
size=torch.Size([len(table_global_shards_metadata)]),
242-
tensor_properties=tensor_properties,
243-
)
244-
245-
return (
246-
table_shard_metadata_to_optimizer_shard_metadata,
247-
single_value_optimizer_st_metadata,
248-
)
249-
250214
def get_optimizer_rowwise_shard_metadata_and_global_metadata(
251215
table_global_metadata: ShardedTensorMetadata,
252216
optimizer_state: torch.Tensor,
@@ -392,10 +356,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
392356
if optimizer_states:
393357
optimizer_state_values = tuple(optimizer_states.values())
394358
for optimizer_state_value in optimizer_state_values:
395-
assert (
396-
table_config.local_rows == optimizer_state_value.size(0)
397-
or optimizer_state_value.nelement() == 1 # single value state
398-
)
359+
assert table_config.local_rows == optimizer_state_value.size(0)
399360
optimizer_states_keys_by_table[table_config.name] = list(
400361
optimizer_states.keys()
401362
)
@@ -474,35 +435,29 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
474435
momentum_local_shards: List[Shard] = []
475436
optimizer_sharded_tensor_metadata: ShardedTensorMetadata
476437

477-
optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16]
478-
if optim_state.nelement() == 1:
479-
# single value state: one value per table
480-
(
481-
table_shard_metadata_to_optimizer_shard_metadata,
482-
optimizer_sharded_tensor_metadata,
483-
) = get_optimizer_single_value_shard_metadata_and_global_metadata(
484-
table_config.global_metadata,
485-
optim_state,
486-
)
487-
elif optim_state.dim() == 1:
488-
# rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
438+
is_rowwise_optimizer_state: bool = (
439+
# pyre-ignore
440+
shard_params.optimizer_states[0][momentum_idx - 1].dim()
441+
== 1
442+
)
443+
444+
if is_rowwise_optimizer_state:
489445
(
490446
table_shard_metadata_to_optimizer_shard_metadata,
491447
optimizer_sharded_tensor_metadata,
492448
) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
493449
table_config.global_metadata,
494-
optim_state,
450+
shard_params.optimizer_states[0][momentum_idx - 1],
495451
sharding_dim,
496452
is_grid_sharded,
497453
)
498454
else:
499-
# pointwise state: param.shape == state.shape
500455
(
501456
table_shard_metadata_to_optimizer_shard_metadata,
502457
optimizer_sharded_tensor_metadata,
503458
) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
504459
table_config.global_metadata,
505-
optim_state,
460+
shard_params.optimizer_states[0][momentum_idx - 1],
506461
)
507462

508463
for optimizer_state, table_shard_local_metadata in zip(

0 commit comments

Comments
 (0)