
Commit d9cc3e0

colin2328 authored and facebook-github-bot committed
in_backward_optimizer_filter on torchrec callsites (#899)
Summary:
Pull Request resolved: #899
X-link: pytorch/torchsnapshot#134

Same change as D41964643, but as a separate diff to mitigate the target determinator.

Reviewed By: YLGH

Differential Revision: D42061412

fbshipit-source-id: ea3e0f9aa1739ee718e1dea2664deadb12d44df0
1 parent 9f941ec commit d9cc3e0

File tree

9 files changed: +27 −12 lines
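
All nine callsites below make the same change: the dense (non-fused) optimizer is now built over in_backward_optimizer_filter(model.named_parameters()) instead of the raw named_parameters(), so parameters whose optimizer already runs inside the backward pass are not handed to a second optimizer and updated twice. A minimal sketch of the resulting pattern, consolidated from the diffs below (the build_dense_optimizer helper name and the Adam hyperparameters are illustrative, not part of this commit):

import torch

from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter


def build_dense_optimizer(model: DistributedModelParallel) -> CombinedOptimizer:
    # Only parameters WITHOUT an in-backward (fused) optimizer are handed to the
    # dense optimizer; in_backward_optimizer_filter skips the rest.
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: torch.optim.Adam(params, lr=1e-1),
    )
    # The fused optimizer (already applied during backward) and the dense optimizer
    # are stepped together through a single CombinedOptimizer, as in the callsites below.
    return CombinedOptimizer([model.fused_optimizer, dense_optimizer])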

contrib/dynamic_embedding/tests/test_integral_precision.py

Lines changed: 9 additions & 1 deletion

@@ -1,3 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import unittest
 
 import torch
@@ -14,6 +21,7 @@
 
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 
 from torchrec_dynamic_embedding.id_transformer_group import IDTransformerGroup
 from utils import init_dist, register_memory_io
@@ -93,7 +101,7 @@ def get_dmp(model):
     model = DMP(module=model, device=device, plan=plan, sharders=sharders)
 
     dense_optimizer = KeyedOptimizerWrapper(
-        dict(model.named_parameters()),
+        dict(in_backward_optimizer_filter(model.named_parameters())),
         lambda params: torch.optim.Adam(params, lr=1e-1),
     )
     optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])

examples/bert4rec/bert4rec_main.py

Lines changed: 3 additions & 1 deletion

@@ -26,6 +26,7 @@
 from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
 from torchrec.distributed.types import ModuleSharder
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from tqdm import tqdm
 
@@ -497,11 +498,12 @@ def main(argv: List[str]) -> None:
             ],
         )
         dense_optimizer = KeyedOptimizerWrapper(
-            dict(model.named_parameters()),
+            dict(in_backward_optimizer_filter(model.named_parameters())),
             lambda params: optim.Adam(
                 params, lr=args.lr, weight_decay=args.weight_decay
             ),
         )
+
         optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
     else:
         device_ids = [rank] if backend == "nccl" else None

examples/golden_training/train_dlrm.py

Lines changed: 2 additions & 1 deletion

@@ -27,6 +27,7 @@
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
 from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
 from tqdm import tqdm
 
@@ -132,7 +133,7 @@ def train(
     )
 
     non_fused_optimizer = KeyedOptimizerWrapper(
-        dict(model.named_parameters()),
+        dict(in_backward_optimizer_filter(model.named_parameters())),
         lambda params: torch.optim.Adagrad(params, lr=learning_rate),
     )
     # Overlap comm/compute/device transfer during training through train_pipeline
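
The dense-optimizer filter above pairs with apply_optimizer_in_backward, which train_dlrm.py already uses to fuse the embedding optimizer into the backward pass. A rough sketch of how the two calls fit together, assuming the sparse parameters live under a sparse_arch submodule (the configure_optimizers helper, attribute name, and hyperparameters are illustrative, not taken from this diff):

import torch

from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad


def configure_optimizers(model: torch.nn.Module, learning_rate: float) -> KeyedOptimizerWrapper:
    # 1) Attach a row-wise Adagrad optimizer to the embedding parameters so they are
    #    updated during the backward pass (no separate optimizer.step() for them).
    apply_optimizer_in_backward(
        RowWiseAdagrad,
        model.sparse_arch.parameters(),  # illustrative attribute name
        {"lr": learning_rate},
    )
    # 2) Build the non-fused optimizer only over what remains;
    #    in_backward_optimizer_filter drops everything handled in step 1.
    return KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: torch.optim.Adagrad(params, lr=learning_rate),
    )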

examples/nvt_dataloader/train_torchrec.py

Lines changed: 2 additions & 3 deletions

@@ -19,8 +19,6 @@
 import torchrec
 import torchrec.distributed as trec_dist
 import torchrec.optim as trec_optim
-
-from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
 from nvt_binary_dataloader import NvtBinaryDataloader
 from pyre_extensions import none_throws
 from torchrec import EmbeddingBagCollection
@@ -40,6 +38,7 @@
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer
 from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 
 
 def parse_args(argv: List[str]) -> argparse.Namespace:
@@ -270,7 +269,7 @@ def main(argv: List[str]):
     )
 
     non_fused_optimizer = KeyedOptimizerWrapper(
-        dict(model.named_parameters()),
+        dict(in_backward_optimizer_filter(model.named_parameters())),
         lambda params: torch.optim.Adagrad(params, lr=args.learning_rate)
         if args.adagrad
         else torch.optim.SGD(params, lr=args.learning_rate),

examples/ray/train_torchrec.py

Lines changed: 2 additions & 1 deletion

@@ -24,6 +24,7 @@
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 from tqdm import tqdm
 
 
@@ -110,7 +111,7 @@ def train(
 
     # Overlap comm/compute/device transfer during training through train_pipeline
     non_fused_optimizer = KeyedOptimizerWrapper(
-        dict(model.named_parameters()),
+        dict(in_backward_optimizer_filter(model.named_parameters())),
         lambda params: torch.optim.Adagrad(params, lr=learning_rate),
     )
     train_pipeline = TrainPipelineSparseDist(

examples/torcharrow/run.py

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
 from torchrec.models.dlrm import DLRM
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 
 
 @record
@@ -94,7 +95,7 @@ def main(
     )
 
     optimizer = KeyedOptimizerWrapper(
-        dict(model.named_parameters()),
+        dict(in_backward_optimizer_filter(model.named_parameters())),
         lambda params: torch.optim.SGD(params, lr=0.01),
     )
 
test_installation.py

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@
 from torchrec.models.dlrm import DLRM
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 
 if sys.platform not in ["linux", "linux2"]:
     raise EnvironmentError(
@@ -118,7 +119,7 @@ def main(argv: List[str]) -> None:
         device=device,
     )
     optimizer = KeyedOptimizerWrapper(
-        dict(model.named_parameters()),
+        dict(in_backward_optimizer_filter(model.named_parameters())),
         lambda params: torch.optim.SGD(params, lr=0.01),
     )
 
torchrec/distributed/test_utils/test_sharding.py

Lines changed: 2 additions & 1 deletion

@@ -43,6 +43,7 @@
 from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig
 from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 from typing_extensions import Protocol
 
 
@@ -343,7 +344,7 @@ def sharding_single_rank_test(
     )
 
     dense_optim = KeyedOptimizerWrapper(
-        dict(local_model.named_parameters()),
+        dict(in_backward_optimizer_filter(local_model.named_parameters())),
         lambda params: torch.optim.SGD(params, lr=0.1),
     )
     local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim])

torchrec/distributed/tests/test_train_pipeline.py

Lines changed: 3 additions & 2 deletions

@@ -44,6 +44,7 @@
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 
 from torchrec.optim.keyed import KeyedOptimizerWrapper
+from torchrec.optim.optimizers import in_backward_optimizer_filter
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Pipelineable
 from torchrec.test_utils import get_free_port, init_distributed_single_host
@@ -197,7 +198,7 @@ def _test_feature_processor_helper(
         copy_state_dict(unsharded_model.state_dict(), distributed_model.state_dict())
         optimizer_cpu = optim.SGD(unsharded_model.parameters(), lr=0.1)
         optimizer_distributed = KeyedOptimizerWrapper(
-            dict(distributed_model.named_parameters()),
+            dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
             lambda params: optim.SGD(params, lr=0.1),
         )
         pipeline = TrainPipelineSparseDist(
@@ -289,7 +290,7 @@ def _test_move_cpu_gpu_helper(
         )
         optimizer_cpu = optim.SGD(model_cpu.parameters(), lr=0.1)
         optimizer_distributed = KeyedOptimizerWrapper(
-            dict(distributed_model.named_parameters()),
+            dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
             lambda params: optim.SGD(params, lr=0.1),
         )
         pipeline = TrainPipelineSparseDist(
