
Commit 0102122

spcyppt authored and facebook-github-bot committed
Replace LR access with wrapper (#2832)
Summary:
Pull Request resolved: #2832
X-link: pytorch/FBGEMM#3849
X-link: facebookresearch/FBGEMM#937

Currently, `learning_rate` is accessed directly through `optimizer_args.learning_rate`, so any change to how `learning_rate` is stored affects every call site. This diff adds a wrapper function for accessing `learning_rate`.

**Usage**

```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(....)
lr = emb_op.get_learning_rate()
```

We plan to remove `learning_rate` from `optimizer_args` to avoid recompilation in PT2:

- PT2 adds a guard on float inputs, so the module is recompiled whenever the value changes.
- Recompilation is expensive.
- In particular, during the warm-up stage of e2e training, the learning rate changes gradually at each iteration; with 10k warm-up steps, this means 10k recompilations.
- Hence, we cannot keep `learning_rate` as a float.

More context in D65511904.

Reviewed By: sryap, TroyGarden

Differential Revision: D71444136

fbshipit-source-id: e175eecac4c554c9e436b8dff8c9a7eb28b6578e
1 parent 76446e7 commit 0102122
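To make the motivation concrete, below is a minimal, hypothetical sketch (not FBGEMM or TorchRec code; `scale_grad` is an invented name) of the PT2 behavior described in the summary: `torch.compile` can specialize on a Python float argument, so a learning rate that changes every iteration can keep invalidating guards and triggering recompilation.

```python
import torch

@torch.compile
def scale_grad(grad: torch.Tensor, lr: float) -> torch.Tensor:
    # torch.compile may guard on the concrete value of the Python float `lr`;
    # a new value can then invalidate the guard and force a recompile.
    return grad * lr

grad = torch.randn(8)
for step in range(3):
    lr = 0.001 * (step + 1)   # warm-up-style schedule: lr changes every step
    scale_grad(grad, lr)      # can recompile whenever `lr` takes a new value
```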

File tree

2 files changed (+3, −3 lines)


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -190,7 +190,7 @@ def __init__(
         state: Dict[Any, Any] = {}
         param_group: Dict[str, Any] = {
             "params": [],
-            "lr": emb_module.optimizer_args.learning_rate,
+            "lr": emb_module.get_learning_rate(),
         }

         params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
@@ -383,7 +383,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         state: Dict[Any, Any] = {}
         param_group: Dict[str, Any] = {
             "params": [],
-            "lr": emb_module.optimizer_args.learning_rate,
+            "lr": emb_module.get_learning_rate(),
         }

         params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
```

torchrec/modules/fused_embedding_modules.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -68,7 +68,7 @@ def __init__( # noqa C901
         state: Dict[Any, Any] = {}
         param_group: Dict[str, Any] = {
             "params": [],
-            "lr": emb_module.optimizer_args.learning_rate,
+            "lr": emb_module.get_learning_rate(),
         }

         params: Dict[str, torch.Tensor] = {}
```
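As a rough, hypothetical illustration of the direction stated in the summary (moving `learning_rate` out of `optimizer_args` and behind an accessor), a wrapper can keep the value in a 0-d tensor and update it in place, so call sites like the ones changed above never depend on how the value is stored. `_LearningRateHolder` and its methods are invented names, not the FBGEMM implementation.

```python
import torch


class _LearningRateHolder:
    """Hypothetical holder: keeps the learning rate in a 0-d tensor so that
    callers go through an accessor instead of reading a raw Python float."""

    def __init__(self, lr: float) -> None:
        self._lr = torch.tensor(lr, dtype=torch.float32)

    def set_learning_rate(self, lr: float) -> None:
        # In-place update keeps the tensor object stable across iterations.
        self._lr.fill_(lr)

    def get_learning_rate(self) -> float:
        # Mirrors the emb_module.get_learning_rate() calls in the diffs above.
        return float(self._lr.item())


holder = _LearningRateHolder(0.01)
param_group = {"params": [], "lr": holder.get_learning_rate()}
```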
