Commit a4d93cf

An impl of Adafactor as per the Big Vision (Scaling ViT) changes
1 parent d4dde48 commit a4d93cf

3 files changed, +292 -0 lines


timm/optim/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 from .adamp import AdamP
 from .adamw import AdamW
 from .adan import Adan
+from .adafactor_bv import AdafactorBigVision
 from .lamb import Lamb
 from .lars import Lars
 from .lookahead import Lookahead

timm/optim/adafactor_bv.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.optim import Optimizer


def _get_scalar_dtype():
    """Get the scalar dtype that the optimizer uses for state"""
    return torch.float64


def _factored_dims(
        shape: Tuple[int, ...],
        factored: bool,
        min_dim_size_to_factor: int,
) -> Optional[Tuple[int, int]]:
    """Whether to use a factored second moment estimator.

    This function returns a tuple with the two largest axes to reduce over.
    If no two dimensions have size >= min_dim_size_to_factor, return None.

    Args:
        shape: an input shape
        factored: whether to use factored second-moment estimator for > 2d vars.
        min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.

    Returns:
        None or a tuple of ints
    """
    if not factored or len(shape) < 2:
        return None
    sorted_dims = np.argsort(shape)
    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2]), int(sorted_dims[-1])


class AdafactorBigVision(Optimizer):
    """
    PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor
    implementations (the multi-tensor, foreach=True path is not implemented yet).

    Adapted from the Big Vision (Scaling ViT) codebase: factored second-moment accumulators for
    sufficiently large dims, a step-dependent beta2 schedule, optional momentum kept in a reduced
    precision dtype, optional update clipping by RMS, and decoupled weight decay.
    """

    def __init__(
            self,
            params,
            lr: float = 1.0,
            min_dim_size_to_factor: int = 32,
            decay_rate: float = 0.8,
            decay_offset: int = 0,
            beta2_cap: float = 0.999,
            momentum: Optional[float] = 0.9,
            momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
            eps: float = 1e-30,
            weight_decay: float = 0.0,
            clipping_threshold: Optional[float] = None,
            unscaled_wd: bool = False,
            *,
            foreach: Optional[bool] = False,
    ):
        if isinstance(momentum_dtype, str):
            if momentum_dtype == 'float16':
                momentum_dtype = torch.float16
            elif momentum_dtype == 'bfloat16':
                momentum_dtype = torch.bfloat16
            else:
                assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                momentum_dtype = torch.float32

        defaults = dict(
            lr=lr,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            beta2_cap=beta2_cap,
            momentum=momentum,
            momentum_dtype=momentum_dtype,
            eps=eps,
            weight_decay=weight_decay,
            clipping_threshold=clipping_threshold,
            unscaled_wd=unscaled_wd,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('foreach', None)
            for p in group['params']:
                p_state = self.state.get(p, {})
                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                    p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
        """Computes beta2 according to the step schedule"""
        t = float(step + 1)
        return min(beta2_cap, 1.0 - t ** (-decay_rate))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avg_sq_rs = []
            exp_avg_sq_cs = []
            exp_avg_sqs = []
            state_steps = []
            exp_avgs = []  # For momentum

            for p in group['params']:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradients not supported")

                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # NOTE step on CPU, probably need some more thought to make capturable
                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                    shape = p.grad.shape
                    factored_dims = _factored_dims(
                        shape,
                        factored=True,
                        min_dim_size_to_factor=group['min_dim_size_to_factor'],
                    )

                    if factored_dims is not None:
                        # Factored second moment: keep one row and one column accumulator
                        d1, d0 = factored_dims
                        row_shape = list(p.grad.shape)
                        row_shape[d0] = 1
                        col_shape = list(p.grad.shape)
                        col_shape[d1] = 1
                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                    if group['momentum'] is not None:
                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=group['momentum_dtype'])

                state_steps.append(state['step'])
                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
                exp_avg_sqs.append(state.get('exp_avg_sq', None))
                exp_avgs.append(state.get('exp_avg', None))

            if group['foreach']:
                func = _multi_tensor_adafactor
            else:
                func = _single_tensor_adafactor

            func(
                params=params_with_grad,
                grads=grads,
                exp_avg_sq_rs=exp_avg_sq_rs,
                exp_avg_sq_cs=exp_avg_sq_cs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avgs=exp_avgs,
                state_steps=state_steps,
                beta2_decay=group['decay_rate'],
                beta2_cap=group['beta2_cap'],
                min_dim_size_to_factor=group['min_dim_size_to_factor'],
                eps=group['eps'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
            )

        return loss


def _single_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg_sq_r = exp_avg_sq_rs[i]
        exp_avg_sq_c = exp_avg_sq_cs[i]
        exp_avg_sq = exp_avg_sqs[i]
        exp_avg = exp_avgs[i]
        step_t = state_steps[i]

        # Update step
        step_t += 1
        beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
        one_minus_beta2_t = 1 - beta2_t

        if exp_avg_sq is None:
            # Factored second moment: update row & column accumulators
            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
            grad_sqr = torch.square(grad) + eps
            exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
            exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)

            reduced_d1 = d1 - 1 if d1 > d0 else d1
            row_col_mean = exp_avg_sq_r.mean(dim=reduced_d1, keepdim=True)
            row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
            col_factor = exp_avg_sq_c.rsqrt()

            update = grad * row_factor * col_factor
        else:
            # Handle non-factored
            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
            update = grad * exp_avg_sq.add(eps).rsqrt_()

        # Clip by RMS value
        if clipping_threshold is not None:
            denom = (update.norm(2) / ((update.numel() ** 0.5) / clipping_threshold)).clamp_(max=1.0)
            update.div_(denom)

        # Apply momentum (in different dtype)
        if momentum is not None and exp_avg is not None:
            if momentum_dtype != grad.dtype:
                exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum)  # ema
                update = exp_avg.to(grad.dtype)
            else:
                exp_avg.lerp_(update, 1 - momentum)  # ema
                update = exp_avg.clone()

        # Scale by learning rate
        update.mul_(lr)

        # Perform weight decay
        if weight_decay != 0:
            if unscaled_wd:
                # match big vision impl, 'fully decoupled' decay w/o LR scaling
                param.mul_(1. - weight_decay)
            else:
                # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
                param.mul_(1. - lr * weight_decay)

        # Update parameters
        param.add_(update, alpha=-1.0)


def _multi_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    assert False, 'multi-tensor fn (foreach=True) not implemented yet'
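
For reference, a minimal usage sketch of the new optimizer follows. It constructs AdafactorBigVision directly via the export added in timm/optim/__init__.py above; the toy module and hyper-parameter values are illustrative assumptions, not part of the commit.

# Minimal usage sketch (illustrative values, not from the commit).
import torch
from timm.optim import AdafactorBigVision

model = torch.nn.Linear(128, 64)  # stand-in for a real net; both weight dims >= 32, so the factored path is used
optimizer = AdafactorBigVision(
    model.parameters(),
    lr=1e-3,                    # constructor default is 1.0
    weight_decay=1e-4,
    clipping_threshold=1.0,     # optional per-update RMS clipping
    momentum_dtype='float32',   # the momentum buffer defaults to bfloat16
)

loss = model(torch.randn(8, 128)).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()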

timm/optim/optim_factory.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 import torch.optim as optim
 
 from timm.models import group_parameters
+from . import AdafactorBigVision
 
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
@@ -356,6 +357,8 @@ def create_optimizer_v2(
     elif opt_lower == 'lion':
         opt_args.pop('eps', None)
         optimizer = Lion(parameters, **opt_args)
+    elif opt_lower == 'adafactorbv':
+        optimizer = AdafactorBigVision(parameters, **opt_args)
 
     # second order
     elif opt_lower == 'adahessian':
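
And a sketch of the factory path registered above; it assumes timm's existing create_optimizer_v2 API, and the model name and hyper-parameter values are illustrative, not part of the commit.

# Creating the optimizer through the factory (illustrative values).
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('vit_small_patch16_224', pretrained=False)
optimizer = create_optimizer_v2(
    model,
    opt='adafactorbv',        # the string this commit maps to AdafactorBigVision
    lr=1e-3,
    weight_decay=1e-4,
    clipping_threshold=1.0,   # extra kwargs are forwarded to the optimizer constructor
)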
