
Commit 71d1741
More additions to Kron
1 parent f759d12


timm/optim/kron.py

Lines changed: 37 additions & 27 deletions
@@ -15,15 +15,6 @@
 import numpy as np
 import torch
 
-try:
-    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-    import opt_einsum
-    opt_einsum.enabled = True
-    opt_einsum.strategy = "auto-hq"
-    import torch.backends.opt_einsum
-    has_opt_einsum = True
-except ImportError:
-    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -67,19 +58,20 @@ class Kron(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups.
         lr: Learning rate.
         momentum: Momentum parameter.
-        weight_decay: Weight decay (L2 penalty).
+        weight_decay: Weight decay.
         preconditioner_update_probability: Probability of updating the preconditioner.
             If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
         max_size_triangular: Max size for dim's preconditioner to be triangular.
         min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
-        memory_save_mode: 'one_diag', or 'all_diag', None is default
+        memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
         momentum_into_precond_update: whether to send momentum into preconditioner
             update instead of raw gradients.
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
-        decoupled_decay: AdamW style decoupled-decay.
+        decoupled_decay: AdamW style decoupled weight decay
+        flatten_dim: Flatten dim >= 2 instead of relying on expressions
        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
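For reference, a minimal construction sketch using the arguments documented above; the Linear module, hyper-parameter values, and import path (inferred from the file location timm/optim/kron.py) are illustrative, not part of the commit:

import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(512, 256)
opt = Kron(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4,
    decoupled_decay=True,               # AdamW-style decoupled weight decay
    memory_save_mode='smart_one_diag',  # option newly documented in this commit
    flatten_dim=True,                   # flag newly added in this commit
)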
@@ -97,10 +89,18 @@ def __init__(
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
+            flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        if not has_opt_einsum:
+        try:
+            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+            import opt_einsum
+            opt_einsum.enabled = True
+            opt_einsum.strategy = "auto-hq"
+            import torch.backends.opt_einsum
+        except ImportError:
             warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )
+
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= momentum < 1.0:
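As a rough standalone illustration (not part of the commit), the same routing of einsum through opt_einsum can be requested via torch's backend knobs. The attribute names below come from torch.backends.opt_einsum; the strategy string is kept to the conservative "auto" here, whereas the commit sets "auto-hq" on the opt_einsum module itself. Availability depends on the installed torch and opt_einsum versions:

import warnings
import torch

try:
    import opt_einsum  # noqa: F401  # torch only uses opt_einsum when it is importable
    import torch.backends.opt_einsum
    torch.backends.opt_einsum.enabled = True
    torch.backends.opt_einsum.strategy = "auto"  # the commit requests "auto-hq" instead
except ImportError:
    warnings.warn("opt_einsum not installed; torch.einsum falls back to its default contraction path")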
@@ -122,10 +122,11 @@ def __init__(
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
+            flatten_dim=flatten_dim,
         )
         super(Kron, self).__init__(params, defaults)
 
-        self._param_exprs = {}
+        self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
         if deterministic:
@@ -165,20 +166,21 @@ def state_dict(self) -> Dict[str, Any]:
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
-        rng_state = state_dict.pop('rng_state', None)
-        torch_rng_state = state_dict.pop('torch_rng_state', None)
+        rng_states = {}
+        if 'rng_state' in state_dict:
+            rng_states['rng_state'] = state_dict.pop('rng_state')
+        if 'torch_rng_state' in state_dict:
+            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
 
         # Load the optimizer state
         super().load_state_dict(state_dict)
+        state_dict.update(rng_states)  # add back
 
         # Restore the RNG state if it exists
-        if rng_state is not None:
-            self.rng.setstate(rng_state)
-            state_dict['rng_state'] = rng_state  # put it back if caller still using state_dict
-        if torch_rng_state is not None:
-            if self.torch_rng is not None:
-                self.torch_rng.set_state(torch_rng_state)
-            state_dict['torch_rng_state'] = torch_rng_state  # put it back if caller still using state_dict
+        if 'rng_state' in rng_states:
+            self.rng.setstate(rng_states['rng_state'])
+        if 'torch_rng_state' in rng_states:
+            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
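A small round-trip sketch of what the reworked load_state_dict is meant to preserve. The tiny Linear model and hyper-parameters are placeholders, and it assumes state_dict() stores the RNG entries when deterministic=True:

import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(8, 8)
opt = Kron(model.parameters(), lr=1e-3, deterministic=True)

sd = opt.state_dict()     # should carry 'rng_state' (and 'torch_rng_state') alongside the usual entries
opt2 = Kron(model.parameters(), lr=1e-3, deterministic=True)
opt2.load_state_dict(sd)  # RNG entries are popped before the base-class load, then put back
# `sd` still holds the RNG keys afterwards, so a caller that keeps the dict
# around (e.g. to re-save a checkpoint) is not surprised by missing entries.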
@@ -208,13 +210,16 @@ def step(self, closure=None):
 
                 grad = p.grad
                 state = self.state[p]
+                if group['flatten_dim']:
+                    grad = grad.view(grad.size(0), -1)
 
                 if len(state) == 0:
                     state["step"] = 0
                     state["update_counter"] = 0
-                    state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
+                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
+                    # init Q and einsum expressions on first step
                     state["Q"], exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
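Shape walk-through of the flatten_dim path (illustrative tensors only): a 4D conv-style gradient collapses to 2D, so two Kronecker factors are kept instead of four, and the update is reshaped back before being applied (see the second flatten_dim hunk further down):

import torch

p = torch.randn(64, 32, 3, 3)        # e.g. a conv weight
grad = torch.randn_like(p)

flat = grad.view(grad.size(0), -1)   # [64, 288]; preconditioning then sees a 2D tensor
# ... momentum / preconditioner updates operate on the 2D view ...
restored = flat.view(p.shape)        # back to [64, 32, 3, 3] before the weight update
assert restored.shape == p.shape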
@@ -234,8 +239,9 @@ def step(self, closure=None):
                     total_precond_size += precond_size
                     total_precond_mb += precond_mb
                 elif p not in self._param_exprs:
+                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
                     exprs = _init_Q_exprs(
-                        p,
+                        grad,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
@@ -245,6 +251,7 @@ def step(self, closure=None):
                     )
                     self._param_exprs[p] = exprs
                 else:
+                    # retrieve cached expressions
                     exprs = self._param_exprs[p]
 
                 # update preconditioners all together deterministically
@@ -315,6 +322,8 @@ def step(self, closure=None):
 
                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
+                if group['flatten_dim']:
+                    pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
                 if group["weight_decay"] != 0:
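The RMS cap shown in context above, in isolation (sketch with made-up numbers): the update is rescaled only when its RMS exceeds 1.1, while already-small updates pass through unchanged:

import torch

pre_grad = torch.randn(10_000) * 5.0                    # RMS around 5
rms = pre_grad.square().mean().sqrt()
pre_grad = pre_grad * torch.clamp(1.1 / (rms + 1e-8), max=1.0)
print(pre_grad.square().mean().sqrt())                  # now roughly 1.1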
@@ -369,9 +378,10 @@ def _init_Q_exprs(
         dim_diag = [False for _ in shape]
         dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "smart_one_diag":
-        dim_diag = [False for _ in shape]
+        # addition proposed by Lucas Nestler
         rev_sorted_dims = np.argsort(shape)[::-1]
         sorted_shape = sorted(shape)
+        dim_diag = [False for _ in shape]
         if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
             dim_diag[rev_sorted_dims[0]] = True
     elif memory_save_mode == "all_diag":
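The reordered 'smart_one_diag' branch, pulled out as a standalone helper to show its effect on a couple of shapes (the helper name and example shapes are invented for illustration):

import numpy as np

def smart_one_diag(shape):
    rev_sorted_dims = np.argsort(shape)[::-1]
    sorted_shape = sorted(shape)
    dim_diag = [False for _ in shape]
    # the largest dim only goes diagonal when it strictly dominates the runner-up
    if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
        dim_diag[rev_sorted_dims[0]] = True
    return dim_diag

print(smart_one_diag((50257, 768)))  # [True, False]  -> embedding-like layer saves memory
print(smart_one_diag((768, 768)))    # [False, False] -> square layer keeps both factors triangular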
