import numpy as np
import torch

-try:
-    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-    import opt_einsum
-    opt_einsum.enabled = True
-    opt_einsum.strategy = "auto-hq"
-    import torch.backends.opt_einsum
-    has_opt_einsum = True
-except ImportError:
-    has_opt_einsum = False
-

try:
    torch._dynamo.config.cache_size_limit = 1_000_000
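The block removed above configures PyTorch's opt_einsum backend and is re-added inside `__init__` later in this diff. For reference, a minimal standalone check of that backend, assuming a recent PyTorch where `torch.backends.opt_einsum` exists; this snippet is not part of the diff:

```python
# Standalone check of the opt_einsum backend (assumption: recent PyTorch).
import torch

if torch.backends.opt_einsum.is_available():
    print("opt_einsum installed, strategy:", torch.backends.opt_einsum.strategy)
else:
    # Without opt_einsum, torch.einsum contracts operands left to right,
    # which can blow up memory for multi-operand preconditioner contractions.
    print("opt_einsum not installed")
```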
@@ -67,19 +58,20 @@ class Kron(torch.optim.Optimizer):
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        momentum: Momentum parameter.
-        weight_decay: Weight decay (L2 penalty).
+        weight_decay: Weight decay.
        preconditioner_update_probability: Probability of updating the preconditioner.
            If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
        max_size_triangular: Max size for dim's preconditioner to be triangular.
        min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
-        memory_save_mode: 'one_diag', or 'all_diag', None is default
+        memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
            to set all preconditioners to be triangular, 'one_diag' sets the largest
            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
        momentum_into_precond_update: whether to send momentum into preconditioner
            update instead of raw gradients.
        mu_dtype: Dtype of the momentum accumulator.
        precond_dtype: Dtype of the preconditioner.
-        decoupled_decay: AdamW style decoupled-decay.
+        decoupled_decay: AdamW style decoupled weight decay
+        flatten_dim: Flatten dim >= 2 instead of relying on expressions
        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
    """
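To make the new arguments concrete, a hedged usage sketch follows; the import path is an assumption (adjust it to wherever this `Kron` class lives), and the values shown are illustrative rather than defaults:

```python
# Hedged usage sketch; `from kron import Kron` assumes this file is importable as `kron`.
import torch
from kron import Kron

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 30 * 30, 10),
)
optimizer = Kron(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-2,
    decoupled_decay=True,               # AdamW-style decoupled decay
    memory_save_mode="smart_one_diag",  # option documented in this diff
    flatten_dim=True,                   # option added in this diff
)

loss = model(torch.randn(2, 3, 32, 32)).sum()
loss.backward()
optimizer.step()
```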
@@ -97,10 +89,18 @@ def __init__(
            mu_dtype: Optional[torch.dtype] = None,
            precond_dtype: Optional[torch.dtype] = None,
            decoupled_decay: bool = False,
+            flatten_dim: bool = False,
            deterministic: bool = False,
    ):
-        if not has_opt_einsum:
+        try:
+            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+            import opt_einsum
+            opt_einsum.enabled = True
+            opt_einsum.strategy = "auto-hq"
+            import torch.backends.opt_einsum
+        except ImportError:
            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum < 1.0:
@@ -122,10 +122,11 @@ def __init__(
            mu_dtype=mu_dtype,
            precond_dtype=precond_dtype,
            decoupled_decay=decoupled_decay,
+            flatten_dim=flatten_dim,
        )
        super(Kron, self).__init__(params, defaults)

-        self._param_exprs = {}
+        self._param_exprs = {}  # cache for einsum expr
        self._tiny = torch.finfo(torch.bfloat16).tiny
        self.rng = random.Random(1337)
        if deterministic:
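The new `# cache for einsum expr` comment is the point of `_param_exprs`: opt_einsum expressions precompute a contraction path once and are then reused every step. A standalone sketch of that pattern, with subscripts and shapes that are illustrative only, not the optimizer's actual preconditioner contraction:

```python
# Illustrative opt_einsum expression caching (shapes and subscripts are made up).
import opt_einsum
import torch

expr = opt_einsum.contract_expression("ij,jk->ik", (64, 288), (288, 288))
g = torch.randn(64, 288)
q = torch.randn(288, 288)
out = expr(g, q)  # the precomputed contraction path is reused on every call
```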
@@ -165,20 +166,21 @@ def state_dict(self) -> Dict[str, Any]:
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Extract and remove the RNG state from the state dict
-        rng_state = state_dict.pop('rng_state', None)
-        torch_rng_state = state_dict.pop('torch_rng_state', None)
+        rng_states = {}
+        if 'rng_state' in state_dict:
+            rng_states['rng_state'] = state_dict.pop('rng_state')
+        if 'torch_rng_state' in state_dict:
+            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

        # Load the optimizer state
        super().load_state_dict(state_dict)
+        state_dict.update(rng_states)  # add back

        # Restore the RNG state if it exists
-        if rng_state is not None:
-            self.rng.setstate(rng_state)
-            state_dict['rng_state'] = rng_state  # put it back if caller still using state_dict
-        if torch_rng_state is not None:
-            if self.torch_rng is not None:
-                self.torch_rng.set_state(torch_rng_state)
-            state_dict['torch_rng_state'] = torch_rng_state  # put it back if caller still using state_dict
+        if 'rng_state' in rng_states:
+            self.rng.setstate(rng_states['rng_state'])
+        if 'torch_rng_state' in rng_states:
+            self.torch_rng.set_state(rng_states['torch_rng_state'])

    def __setstate__(self, state):
        super().__setstate__(state)
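A hedged sketch of the save/resume flow this rework targets: the RNG entries are popped for loading but written back, so a caller that keeps using its own dict still sees them. `opt` stands for any already constructed `Kron` instance; the round-trip through `torch.save`/`torch.load` is illustrative:

```python
# Hedged resume sketch; `opt` is an already constructed Kron optimizer.
import torch

ckpt = {"optimizer": opt.state_dict()}
torch.save(ckpt, "ckpt.pt")

sd = torch.load("ckpt.pt", weights_only=False)["optimizer"]  # state holds plain Python objects
opt.load_state_dict(sd)
# Assumption: state_dict() stored 'rng_state'; load_state_dict adds it back to sd.
assert "rng_state" in sd
```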
@@ -208,13 +210,16 @@ def step(self, closure=None):
                grad = p.grad
                state = self.state[p]
+                if group['flatten_dim']:
+                    grad = grad.view(grad.size(0), -1)

                if len(state) == 0:
                    state["step"] = 0
                    state["update_counter"] = 0
-                    state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
+                    state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
+                    # init Q and einsum expressions on first step
                    state["Q"], exprs = _init_Q_exprs(
-                        p,
+                        grad,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
@@ -234,8 +239,9 @@ def step(self, closure=None):
                    total_precond_size += precond_size
                    total_precond_mb += precond_mb
                elif p not in self._param_exprs:
+                    # init only the einsum expressions, called after state load, Q are loaded from state_dict
                    exprs = _init_Q_exprs(
-                        p,
+                        grad,
                        group["precond_init_scale"],
                        group["max_size_triangular"],
                        group["min_ndim_triangular"],
@@ -245,6 +251,7 @@ def step(self, closure=None):
                    )
                    self._param_exprs[p] = exprs
                else:
+                    # retrieve cached expressions
                    exprs = self._param_exprs[p]

                # update preconditioners all together deterministically
@@ -315,6 +322,8 @@ def step(self, closure=None):
                # RMS of pre_grad should be 1.0, so let's cap at 1.1
                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
+                if group['flatten_dim']:
+                    pre_grad = pre_grad.view(p.shape)

                # Apply weight decay
                if group["weight_decay"] != 0:
@@ -369,9 +378,10 @@ def _init_Q_exprs(
        dim_diag = [False for _ in shape]
        dim_diag[rev_sorted_dims[0]] = True
    elif memory_save_mode == "smart_one_diag":
-        dim_diag = [False for _ in shape]
+        # addition proposed by Lucas Nestler
        rev_sorted_dims = np.argsort(shape)[::-1]
        sorted_shape = sorted(shape)
+        dim_diag = [False for _ in shape]
        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
            dim_diag[rev_sorted_dims[0]] = True
    elif memory_save_mode == "all_diag":
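To make the reordered `smart_one_diag` branch concrete, a standalone walk-through over a few shapes: only when the largest dim is strictly larger than the next largest does it get a diagonal preconditioner:

```python
# Standalone walk-through of the smart_one_diag branch above.
import numpy as np

for shape in [(1024, 1024), (4096, 1024), (50257, 768), (768,)]:
    rev_sorted_dims = np.argsort(shape)[::-1]
    sorted_shape = sorted(shape)
    dim_diag = [False for _ in shape]
    if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
        dim_diag[rev_sorted_dims[0]] = True
    print(shape, dim_diag)
# (1024, 1024) -> [False, False]  equal dims: both stay triangular
# (4096, 1024) -> [True, False]   the strictly larger dim becomes diagonal
# (50257, 768) -> [True, False]   e.g. an embedding or output head weight
# (768,)       -> [False]         1D params are unaffected by this branch
```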