Trying to use a learning rate scheduler with the torchao AdamW optimizer raises a "lr was changed to a non-Tensor object" RuntimeError:
RuntimeError Traceback (most recent call last)
Cell In[36], line 20
18 loss = 4 - z
19 loss.backward()
---> 20 optimizer.step()
21 lr_scheduler.step()
22 optimizer.zero_grad()
File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:124, in LRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper(*args, **kwargs)
122 opt = opt_ref()
123 opt._opt_called = True # type: ignore[union-attr]
--> 124 return func.__get__(opt, opt.__class__)(*args, **kwargs)
File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torch/optim/optimizer.py:485, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
480 else:
481 raise RuntimeError(
482 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
483 )
--> 485 out = func(*args, **kwargs)
486 self._optimizer_step_code()
488 # call optimizer step post hooks
File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /mnt/DataSSD/AI/Apps/ipex/venv/lib/python3.12/site-packages/torchao/optim/adam.py:125, in _AdamBase.step(self, closure)
122 state["step"] += 1
124 if not isinstance(group["lr"], Tensor):
--> 125 raise RuntimeError(
126 "lr was changed to a non-Tensor object. If you want to update lr, please use "
127 "optim.param_groups[0]['lr'].fill_(new_lr)"
128 )
130 # without calling p.detach(), torch.compile() will have issues with FSDP2 in some cases
131 # https://github.com/pytorch/ao/issues/652#issuecomment-2285040894
132 # thus, by calling p.detach(), DTensor won't have .grad anymore, which is ok since we
133 # are passing grad separately anyway.
134 torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
135 p.detach(),
136 grad,
(...)
147 self.bf16_stochastic_round and p.dtype is torch.bfloat16,
148 )
RuntimeError: lr was changed to a non-Tensor object. If you want to update lr, please use optim.param_groups[0]['lr'].fill_(new_lr)
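For reference, the in-place update that the error message asks for looks like this (a minimal sketch; new_lr is only a placeholder float, and the lr stored in the param group must already be a Tensor):

new_lr = 2e-6  # placeholder value for illustration
optimizer.param_groups[0]["lr"].fill_(new_lr)  # modify the Tensor in place instead of replacing it with a float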
Changing the learning rates from floats to tensors avoids the RuntimeError, but it completely breaks the learning rate scheduler:

For comparison, this is the expected outcome with standard torch.optim.AdamW:

Minimal code to reproduce the issue:
import torch
import torchao
import matplotlib.pyplot as plt
x = torch.nn.Parameter(torch.tensor(1.0))
y = torch.nn.Parameter(torch.tensor(1.0))
optimizer = torchao.optim._AdamW([{"params": [x], "lr": 4e-6}, {"params": [y], "lr": 1e-5}])
#optimizer = torch.optim.AdamW([{"params": [x], "lr": 4e-6}, {"params": [y], "lr": 1e-5}])  # standard AdamW: scheduler works as expected
#optimizer = torchao.optim._AdamW([{"params": [x], "lr": torch.tensor(4e-6)}, {"params": [y], "lr": torch.tensor(1e-5)}])  # tensor lr: avoids the error but breaks the scheduler
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1024)
lrs = []
lrs2 = []
for i in range(1024):
    z = x * y
    loss = 4 - z
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    last_lr = lr_scheduler.get_last_lr()
    lrs.append(last_lr[0])
    lrs2.append(last_lr[1])
print(lrs[0], lrs[-1])
print(lrs2[0], lrs2[-1])
plt.plot(lrs)
plt.plot(lrs2)
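Until this is fixed, one possible workaround is to bypass the PyTorch scheduler entirely: compute the cosine schedule by hand each iteration and write the new value into the Tensor lr in place, which is the update style the error message asks for. This is only a sketch, under the assumption that the torchao optimizer keeps group["lr"] as a Tensor (which the check in adam.py implies), not an official torchao recipe:

import math
import torch
import torchao

x = torch.nn.Parameter(torch.tensor(1.0))
y = torch.nn.Parameter(torch.tensor(1.0))
base_lrs = [4e-6, 1e-5]
optimizer = torchao.optim._AdamW(
    [{"params": [x], "lr": base_lrs[0]}, {"params": [y], "lr": base_lrs[1]}]
)

T_max = 1024
for step in range(T_max):
    z = x * y
    loss = 4 - z
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    for group, base_lr in zip(optimizer.param_groups, base_lrs):
        # CosineAnnealingLR with eta_min=0: lr = base_lr * (1 + cos(pi * t / T_max)) / 2
        new_lr = base_lr * 0.5 * (1 + math.cos(math.pi * (step + 1) / T_max))
        group["lr"].fill_(new_lr)  # in-place update; assumes group["lr"] is a Tensor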