
Cannot use LR Schedulers with param groups on AdamW #2574

@Disty0

Using a learning rate scheduler with per-param-group learning rates on torchao's AdamW raises a RuntimeError ("lr was changed to a non-Tensor object"):

RuntimeError                              Traceback (most recent call last)
Cell In[36], line 20
     18 loss = 4 - z
     19 loss.backward()
---> 20 optimizer.step()
     21 lr_scheduler.step()
     22 optimizer.zero_grad()

File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:124, in LRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper(*args, **kwargs)
    122 opt = opt_ref()
    123 opt._opt_called = True  # type: ignore[union-attr]
--> 124 return func.__get__(opt, opt.__class__)(*args, **kwargs)

File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torch/optim/optimizer.py:485, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    480         else:
    481             raise RuntimeError(
    482                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    483             )
--> 485 out = func(*args, **kwargs)
    486 self._optimizer_step_code()
    488 # call optimizer step post hooks

File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torchao/optim/adam.py:125, in _AdamBase.step(self, closure)
    122 state["step"] += 1
    124 if not isinstance(group["lr"], Tensor):
--> 125     raise RuntimeError(
    126         "lr was changed to a non-Tensor object. If you want to update lr, please use "
    127         "optim.param_groups[0]['lr'].fill_(new_lr)"
    128     )
    130 # without calling p.detach(), torch.compile() will have issues with FSDP2 in some cases
    131 # https://github.com/pytorch/ao/issues/652#issuecomment-2285040894
    132 # thus, by calling p.detach(), DTensor won't have .grad anymore, which is ok since we
    133 # are passing grad separately anyway.
    134 torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
    135     p.detach(),
    136     grad,
   (...)
    147     self.bf16_stochastic_round and p.dtype is torch.bfloat16,
    148 )

RuntimeError: lr was changed to a non-Tensor object. If you want to update lr, please use optim.param_groups[0]['lr'].fill_(new_lr)
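The error message asks for the learning rate to be updated in place with `fill_` instead of being reassigned, and it only mentions `param_groups[0]`; with several param groups, as in this report, each group has to be updated. Below is a minimal sketch, assuming PyTorch is installed. `set_group_lrs` is a hypothetical helper (not part of torch or torchao), and the usage uses a stand-in object that only exposes `param_groups`, which every `torch.optim.Optimizer` subclass also provides.

```python
from types import SimpleNamespace

import torch


def set_group_lrs(optimizer, new_lrs):
    """Hypothetical helper: write one learning rate per param group.

    Tensor learning rates are updated in place with Tensor.fill_, as the
    torchao error message recommends, so the optimizer keeps the same
    tensor object. Float learning rates are simply reassigned.
    """
    if len(new_lrs) != len(optimizer.param_groups):
        raise ValueError("need exactly one learning rate per param group")
    for group, new_lr in zip(optimizer.param_groups, new_lrs):
        if isinstance(group["lr"], torch.Tensor):
            group["lr"].fill_(new_lr)  # in place: the tensor object is preserved
        else:
            group["lr"] = new_lr


# Usage with a stand-in for an optimizer (hypothetical): only param_groups is used.
opt = SimpleNamespace(param_groups=[{"lr": torch.tensor(4e-6)}, {"lr": torch.tensor(1e-5)}])
lr_tensor = opt.param_groups[0]["lr"]
set_group_lrs(opt, [2e-6, 5e-6])
print(opt.param_groups[0]["lr"] is lr_tensor)  # True
print([float(g["lr"]) for g in opt.param_groups])
```

Updating in place rather than swapping in a new tensor matters because the torchao check shown in the traceback only accepts a `Tensor` in `group["lr"]`.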

Passing the learning rates as tensors instead of floats avoids the error, but it completely breaks the learning rate scheduler:

[Image: learning rate curves from torchao AdamW with tensor learning rates]

This is the expected outcome with the standard torch.optim.AdamW:

[Image: expected learning rate curves from torch.optim.AdamW]
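For comparison, the PyTorch documentation describes CosineAnnealingLR with the closed form eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_max)) / 2. The pure-Python sketch below evaluates that formula for the two param-group learning rates (4e-6 and 1e-5) with `T_max=1024` and the default `eta_min=0` used in the reproduction. It approximates the curves the standard AdamW run should trace. The scheduler itself updates learning rates recursively, so its values may differ from these by floating-point rounding. `cosine_annealing_lr` is a hypothetical name, not a PyTorch function.

```python
import math


def cosine_annealing_lr(base_lr, step, t_max, eta_min=0.0):
    """Closed-form cosine annealing, as written in the CosineAnnealingLR docs."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


T_MAX = 1024
BASE_LRS = [4e-6, 1e-5]  # the per-group lrs from the reproduction

# The reproduction records get_last_lr() after each lr_scheduler.step(),
# so entry i of each curve corresponds to scheduler step i + 1 (1..T_MAX).
expected = [
    [cosine_annealing_lr(base, step, T_MAX) for step in range(1, T_MAX + 1)]
    for base in BASE_LRS
]

for curve in expected:
    print(curve[0], curve[-1])  # just below the base lr, then 0.0
```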

Minimal code to reproduce the issue:

import torch
import torchao.optim
import matplotlib.pyplot as plt

x = torch.nn.Parameter(torch.tensor(1.0))
y = torch.nn.Parameter(torch.tensor(1.0))

# torchao AdamW with float per-group lrs: raises the RuntimeError above
optimizer = torchao.optim._AdamW([{"params": [x], "lr": 4e-6}, {"params": [y], "lr": 1e-5}])
# Standard AdamW: produces the expected curves
#optimizer = torch.optim.AdamW([{"params": [x], "lr": 4e-6}, {"params": [y], "lr": 1e-5}])
# torchao AdamW with tensor per-group lrs: no error, but the schedule is broken
#optimizer = torchao.optim._AdamW([{"params": [x], "lr": torch.tensor(4e-6)}, {"params": [y], "lr": torch.tensor(1e-5)}])

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1024)

lrs = []   # lr of param group 0 (x) after each scheduler step
lrs2 = []  # lr of param group 1 (y) after each scheduler step
for i in range(1024):
    z = x * y
    loss = 4 - z
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    last_lr = lr_scheduler.get_last_lr()
    lrs.append(last_lr[0])
    lrs2.append(last_lr[1])

print(lrs[0], lrs[-1])
print(lrs2[0], lrs2[-1])

plt.plot(lrs)
plt.plot(lrs2)
plt.show()

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)