from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.optim import Optimizer


def _get_scalar_dtype():
    """Get the scalar dtype that the optimizer uses for state"""
    return torch.float64


def _factored_dims(
    shape: Tuple[int, ...],
    factored: bool,
    min_dim_size_to_factor: int,
) -> Optional[Tuple[int, int]]:
    """Whether to use a factored second moment estimator.

    This function returns a tuple with the two largest axes to reduce over.
    If no two dimensions have size >= min_dim_size_to_factor, return None.

    Args:
        shape: an input shape
        factored: whether to use factored second-moment estimator for > 2d vars.
        min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.

    Returns:
        None or a tuple of ints
    """
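    # Illustrative example (not part of the original code): for a conv weight of
    # shape (256, 128, 3, 3) the two largest dims are 256 (axis 0) and 128 (axis 1),
    # so this returns (1, 0), i.e. the second-largest axis first, while (256, 3, 3)
    # returns None because 3 < min_dim_size_to_factor (default 32).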
    if not factored or len(shape) < 2:
        return None
    sorted_dims = np.argsort(shape)
    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2]), int(sorted_dims[-1])


class AdafactorBigVision(Optimizer):
    """
    PyTorch implementation of BigVision's Adafactor variant with both single- and multi-tensor implementations.

    NOTE: only the single-tensor path is implemented so far; calling step() with foreach=True asserts
    until the multi-tensor function is added.
    """

    def __init__(
            self,
            params,
            lr: float = 1.0,
            min_dim_size_to_factor: int = 32,
            decay_rate: float = 0.8,
            decay_offset: int = 0,
            beta2_cap: float = 0.999,
            momentum: Optional[float] = 0.9,
            momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
            eps: float = 1e-30,
            weight_decay: float = 0.0,
            clipping_threshold: Optional[float] = None,
            unscaled_wd: bool = False,
            *,
            foreach: Optional[bool] = False,
    ):
        if isinstance(momentum_dtype, str):
            if momentum_dtype == 'float16':
                momentum_dtype = torch.float16
            elif momentum_dtype == 'bfloat16':
                momentum_dtype = torch.bfloat16
            else:
                assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                momentum_dtype = torch.float32

        defaults = dict(
            lr=lr,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            beta2_cap=beta2_cap,
            momentum=momentum,
            momentum_dtype=momentum_dtype,
            eps=eps,
            weight_decay=weight_decay,
            clipping_threshold=clipping_threshold,
            unscaled_wd=unscaled_wd,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('foreach', None)
            for p in group['params']:
                p_state = self.state.get(p, {})
                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                    p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
        """Computes beta2 according to the step schedule"""
        t = float(step + 1)
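        # Illustrative numbers (not from the original): with decay_rate=0.8 this gives
        # beta2 ~= 0.84 at t=10, ~0.97 at t=100 and ~0.996 at t=1000, approaching
        # (and capped at) beta2_cap as training progresses.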
        return min(beta2_cap, 1.0 - t ** (-decay_rate))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avg_sq_rs = []
            exp_avg_sq_cs = []
            exp_avg_sqs = []
            state_steps = []
            exp_avgs = []  # For momentum

            for p in group['params']:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradients not supported")

                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # NOTE: step is kept on CPU; needs some more thought to make this capturable
                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                    shape = p.grad.shape
                    factored_dims = _factored_dims(
                        shape,
                        factored=True,
                        min_dim_size_to_factor=group['min_dim_size_to_factor'],
                    )

                    if factored_dims is not None:
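                        # keep the reduced dim as size 1 so the factored accumulators
                        # broadcast against the gradient without reshaping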
                        d1, d0 = factored_dims
                        row_shape = list(p.grad.shape)
                        row_shape[d0] = 1
                        col_shape = list(p.grad.shape)
                        col_shape[d1] = 1
                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                    if group['momentum'] is not None:
                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=group['momentum_dtype'])

                state_steps.append(state['step'])
                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
                exp_avg_sqs.append(state.get('exp_avg_sq', None))
                exp_avgs.append(state.get('exp_avg', None))

            if group['foreach']:
                func = _multi_tensor_adafactor
            else:
                func = _single_tensor_adafactor

            func(
                params=params_with_grad,
                grads=grads,
                exp_avg_sq_rs=exp_avg_sq_rs,
                exp_avg_sq_cs=exp_avg_sq_cs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avgs=exp_avgs,
                state_steps=state_steps,
                beta2_decay=group['decay_rate'],
                beta2_cap=group['beta2_cap'],
                min_dim_size_to_factor=group['min_dim_size_to_factor'],
                eps=group['eps'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
            )

        return loss

def _single_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg_sq_r = exp_avg_sq_rs[i]
        exp_avg_sq_c = exp_avg_sq_cs[i]
        exp_avg_sq = exp_avg_sqs[i]
        exp_avg = exp_avgs[i]
        step_t = state_steps[i]

        # Update step
        step_t += 1
        beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
        one_minus_beta2_t = 1 - beta2_t

        if exp_avg_sq is None:
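            # Factored path: approximate the elementwise second moment by the outer
            # product of running row/column means normalized by their overall mean,
            # as in Adafactor (Shazeer & Stern, 2018), instead of storing a full
            # exp_avg_sq the size of the parameter.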
            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
            grad_sqr = torch.square(grad) + eps
            exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
            exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)

            # The accumulators keep reduced dims (keepdim=True), so the row stats are
            # reduced over d1 directly; no index shift is needed here.
            row_col_mean = exp_avg_sq_r.mean(dim=d1, keepdim=True)
            row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
            col_factor = exp_avg_sq_c.rsqrt()

            update = grad * row_factor * col_factor
        else:
            # Handle non-factored
            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
            update = grad * exp_avg_sq.add(eps).rsqrt_()

        # Clip by RMS value: divide the update by max(1, RMS(update) / clipping_threshold)
        if clipping_threshold is not None:
            denom = (update.norm(2) / (update.numel() ** 0.5) / clipping_threshold).clamp_(min=1.0)
            update.div_(denom)

        # Apply momentum (in different dtype)
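        # Keeping the momentum EMA in momentum_dtype (bfloat16 by default) halves its
        # memory footprint relative to an fp32 buffer at some precision cost; the
        # smoothed update is cast back to the gradient dtype before being applied.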
        if momentum is not None and exp_avg is not None:
            if momentum_dtype != grad.dtype:
                exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum)  # ema
                update = exp_avg.to(grad.dtype)
            else:
                exp_avg.lerp_(update, 1 - momentum)  # ema
                update = exp_avg.clone()

        # Scale by learning rate
        update.mul_(lr)

        # Perform weight decay
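        # Both branches apply decoupled decay (the weights are shrunk directly rather
        # than wd * param being added to the gradient); 'unscaled_wd' selects whether
        # the shrink factor is multiplied by the learning rate.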
        if weight_decay != 0:
            if unscaled_wd:
                # match big vision impl, 'fully decoupled' decay w/o LR scaling
                param.mul_(1. - weight_decay)
            else:
                # match typical PyTorch behaviour for decoupled decay, e.g. AdamW, where wd is scaled by LR
                param.mul_(1. - lr * weight_decay)

        # Update parameters
        param.add_(update, alpha=-1.0)

def _multi_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    assert False, 'multi-tensor fn (foreach=True) not implemented yet'
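
# Minimal usage sketch (illustrative only; the model and hyper-parameters below are
# arbitrary and not part of this module):
#
#   model = torch.nn.Linear(512, 256)
#   opt = AdafactorBigVision(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4)
#   loss = model(torch.randn(8, 512)).sum()
#   loss.backward()
#   opt.step()
#   opt.zero_grad()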