Commit f35498d

colin2328 authored and facebook-github-bot committed
add in_backward_optimizer_filter to work with in_backward_optimizers (#315)
Summary:
X-link: facebookresearch/dlrm#315

Pull Request resolved: #892

as title

Reviewed By: lequytra, YLGH

Differential Revision: D42009102

fbshipit-source-id: a51cbe59a66a1ce4d82f5c46d39eca89b326be7a
1 parent d107533 commit f35498d
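
For context on the feature named in the title: an "in-backward" optimizer is one registered with apply_optimizer_in_backward (exercised by the new test below), which fuses the optimizer update for the given parameters into backward(); those parameters must then be excluded when constructing a regular optimizer, which is the gap the new filter fills. A minimal sketch of the registration step, assuming a hypothetical module sparse_arch and an illustrative learning rate:

import torch
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

# After this call, gradients for sparse_arch's parameters are applied during
# backward() by a fused SGD, so no explicit optimizer.step() covers them.
apply_optimizer_in_backward(
    torch.optim.SGD,
    sparse_arch.parameters(),      # hypothetical module whose params update in backward
    optimizer_kwargs={"lr": 0.1},  # illustrative value
)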

File tree

2 files changed: +62 -1 lines changed

  torchrec/optim/optimizers.py
  torchrec/optim/tests/test_optim.py

torchrec/optim/optimizers.py

Lines changed: 20 additions & 1 deletion
@@ -7,13 +7,32 @@

 #!/usr/bin/env python3

-from typing import Iterable
+from typing import Iterable, Iterator, Tuple

 import torch
+from torch import nn

 from torch.optim.optimizer import Optimizer


+def in_backward_optimizer_filter(
+    named_parameters: Iterator[Tuple[str, nn.Parameter]], include: bool = False
+) -> Iterator[Tuple[str, nn.Parameter]]:
+    """
+    Filters named_parameters based on whether or not they are params that use
+    the in_backward_optimizer.
+
+    Args:
+        named_parameters (Iterator[Tuple[str, nn.Parameter]]): named_parameters to filter.
+        include (bool): If True, only yields params with an in_backward_optimizer; if False, yields the complementary set.
+            Defaults to yielding params that are not in_backward (False).
+    """
+    for fqn, param in named_parameters:
+        # TODO: change to _in_backward_optimizer
+        if hasattr(param, "_overlapped_optimizer") == include:
+            yield fqn, param
+
+
 class SGD(Optimizer):
     r"""
     Placeholder for SGD. This optimizer will not functionally run.
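
A minimal usage sketch for the new helper (not part of this commit; model and the learning rate are illustrative): with the default include=False, the filter yields only the parameters that are not tied to an in-backward optimizer, which is exactly the set to hand to a conventional dense optimizer.

import torch
from torchrec.optim.optimizers import in_backward_optimizer_filter

# model is assumed to be an nn.Module whose embedding parameters were already
# registered via apply_optimizer_in_backward; the filter (include=False by
# default) yields the remaining "dense" parameters.
dense_params = [param for _, param in in_backward_optimizer_filter(model.named_parameters())]
dense_optimizer = torch.optim.Adagrad(dense_params, lr=0.01)  # lr is illustrative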

torchrec/optim/tests/test_optim.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
+from torchrec.optim.optimizers import in_backward_optimizer_filter
+
+
+class TestInBackwardOptimizerFilter(unittest.TestCase):
+    def test_in_backward_optimizer_filter(self) -> None:
+        ebc = EmbeddingBagCollection(
+            tables=[
+                EmbeddingBagConfig(
+                    name="t1", embedding_dim=4, num_embeddings=2, feature_names=["f1"]
+                ),
+                EmbeddingBagConfig(
+                    name="t2", embedding_dim=4, num_embeddings=2, feature_names=["f2"]
+                ),
+            ]
+        )
+        apply_optimizer_in_backward(
+            torch.optim.SGD,
+            ebc.embedding_bags["t1"].parameters(),
+            optimizer_kwargs={"lr": 1.0},
+        )
+        in_backward_params = dict(
+            in_backward_optimizer_filter(ebc.named_parameters(), include=True)
+        )
+        non_in_backward_params = dict(
+            in_backward_optimizer_filter(ebc.named_parameters(), include=False)
+        )
+        assert set(in_backward_params.keys()) == {"embedding_bags.t1.weight"}
+        assert set(non_in_backward_params.keys()) == {"embedding_bags.t2.weight"}
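
As a follow-up sketch building on the test above (the input batch and learning rate are illustrative, and KeyedJaggedTensor comes from torchrec.sparse.jagged_tensor): only the non-in-backward parameters get a conventional optimizer, while t1's weight is updated inside backward() by the fused SGD.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Dense optimizer over the complement set only (here, embedding_bags.t2.weight).
dense_optimizer = torch.optim.SGD(list(non_in_backward_params.values()), lr=1.0)

# Illustrative single-sample batch: feature f1 looks up row 0, f2 looks up row 1.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1]),
    lengths=torch.tensor([1, 1]),
)

loss = ebc(features).values().sum()
dense_optimizer.zero_grad()
loss.backward()          # t1's weight is updated here by the in-backward SGD
dense_optimizer.step()   # t2's weight is updated here by the dense SGD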
