 * https://github.com/ClashLuke/HeavyBall (added improvements)
 
 """
+import logging
 import string
 import random
+import warnings
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
+
 try:
     # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
     import opt_einsum
     opt_einsum.enabled = True
     opt_einsum.strategy = "auto-hq"
     import torch.backends.opt_einsum
+    has_opt_einsum = True
 except ImportError:
-    opt_einsum = None
+    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
     has_dynamo = True
 except AttributeError:
     has_dynamo = False
 
+_logger = logging.getLogger(__name__)
+
 
 def precond_update_prob_schedule(
         n: float,
         max_prob: float = 1.0,
         min_prob: float = 0.03,
         decay: float = 0.001,
         flat_start: float = 500,
-):
+) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
 
     PSGD benefits from more preconditioner updates at the beginning of training,
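The body of the schedule falls outside this hunk. As a rough sketch only (function name and exact form assumed, not taken from the commit), an exponential anneal with a flat start that is consistent with the defaults above, holding 1.0 for about 500 steps and reaching roughly 0.03 by step 4000, could look like this:

```python
import torch

def example_prob_schedule(n, max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
    """Hypothetical sketch of an anneal consistent with the defaults above."""
    # Hold at max_prob for `flat_start` steps, then decay exponentially;
    # exp(-0.001 * (4000 - 500)) ~= 0.03, matching the docstring's "0.03 by 4000 steps".
    prob = max_prob * torch.exp(torch.tensor(-decay * max(n - flat_start, 0), dtype=torch.float32))
    return prob.clamp_(min=min_prob, max=max_prob)
```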
@@ -57,38 +64,43 @@ class Kron(torch.optim.Optimizer):
     """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
 
     Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): Learning rate.
-        momentum (float): Momentum parameter.
-        weight_decay (float): Weight decay (L2 penalty).
-        preconditioner_update_probability (callable or float, optional): Probability of
-            updating the preconditioner. If None, defaults to a schedule that anneals
-            from 1.0 to 0.03 by 4000 steps.
-        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
-        min_ndim_triangular (int): Minimum number of dimensions a layer needs to have triangular preconditioners.
-        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+        params: Iterable of parameters to optimize or dicts defining parameter groups.
+        lr: Learning rate.
+        momentum: Momentum parameter.
+        weight_decay: Weight decay (L2 penalty).
+        preconditioner_update_probability: Probability of updating the preconditioner.
+            If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular: Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
+        memory_save_mode: 'one_diag', or 'all_diag', None is default
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
-        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+        momentum_into_precond_update: whether to send momentum into preconditioner
             update instead of raw gradients.
-        mu_dtype (torch.dtype, optional): Dtype of the momentum accumulator.
-        precond_dtype (torch.dtype, optional): Dtype of the preconditioner.
+        mu_dtype: Dtype of the momentum accumulator.
+        precond_dtype: Dtype of the preconditioner.
+        decoupled_decay: AdamW style decoupled-decay.
+        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
     def __init__(
             self,
             params,
-            lr=0.001,
-            momentum=0.9,
-            weight_decay=0.0,
-            preconditioner_update_probability=None,
-            max_size_triangular=2048,
-            min_ndim_triangular=2,
-            memory_save_mode=None,
-            momentum_into_precond_update=True,
-            mu_dtype=None,
-            precond_dtype=None,
+            lr: float = 0.001,
+            momentum: float = 0.9,
+            weight_decay: float = 0.0,
+            preconditioner_update_probability: Optional[Union[Callable, float]] = None,
+            max_size_triangular: int = 2048,
+            min_ndim_triangular: int = 2,
+            memory_save_mode: Optional[str] = None,
+            momentum_into_precond_update: bool = True,
+            mu_dtype: Optional[torch.dtype] = None,
+            precond_dtype: Optional[torch.dtype] = None,
+            decoupled_decay: bool = False,
+            deterministic: bool = False,
     ):
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= momentum < 1.0:
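For context on the two new constructor arguments, a minimal usage sketch (the model, data, and the availability of `Kron` in the importing scope are placeholders, not part of this commit):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(128, 10)
optimizer = Kron(                  # assumes Kron is imported from this module
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-2,
    decoupled_decay=True,          # new: AdamW-style decay applied to the weights directly
    deterministic=True,            # new: CPU torch.Generator for reproducible save/load (slower)
)

x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```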
@@ -109,13 +121,18 @@ def __init__(
             precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
+            decoupled_decay=decoupled_decay,
         )
         super(Kron, self).__init__(params, defaults)
 
+        self._param_exprs = {}
         self._tiny = torch.finfo(torch.bfloat16).tiny
-        self._prob_step = 0
-        self._update_counter = 0
-        self.rng = random.Random(5318008)
+        self.rng = random.Random(1337)
+        if deterministic:
+            # Use a Generator to try to be more deterministic across resume (save/load)
+            self.torch_rng = torch.Generator().manual_seed(1337)
+        else:
+            self.torch_rng = None
 
         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -129,6 +146,44 @@ def __init__(
             self._precond_grad = _precond_grad
             self._balance_Q = _balance_Q
 
+    def __getstate__(self):
+        _dict = super().__getstate__()
+        _dict["rng"] = self.rng
+        _dict["torch_rng"] = self.torch_rng
+        return _dict
+
+    def state_dict(self) -> Dict[str, Any]:
+        # Get the optimizer's state dict
+        optimizer_state = super().state_dict()
+
+        # Add the generator state
+        optimizer_state['rng_state'] = self.rng.getstate()
+        if self.torch_rng is not None:
+            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
+
+        return optimizer_state
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # Extract and remove the RNG state from the state dict
+        rng_state = state_dict.pop('rng_state', None)
+        torch_rng_state = state_dict.pop('torch_rng_state', None)
+
+        # Load the optimizer state
+        super().load_state_dict(state_dict)
+
+        # Restore the RNG state if it exists
+        if rng_state is not None:
+            self.rng.setstate(rng_state)
+            state_dict['rng_state'] = rng_state  # put it back if caller still using state_dict
+        if torch_rng_state is not None:
+            if self.torch_rng is not None:
+                self.torch_rng.set_state(torch_rng_state)
+            state_dict['torch_rng_state'] = torch_rng_state  # put it back if caller still using state_dict
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._param_exprs = {}
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
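The overridden `state_dict` / `load_state_dict` above carry the Python and torch RNG streams through a checkpoint, so a resumed run makes the same update/balance decisions and draws the same noise. A hedged round-trip sketch, continuing the construction example earlier (the file name is a placeholder):

```python
# Save: the returned dict now also contains 'rng_state' and, when
# deterministic=True, 'torch_rng_state'.
torch.save({"optimizer": optimizer.state_dict()}, "checkpoint.pth")

# Resume: both RNG streams are restored before training continues.
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
optimizer.load_state_dict(checkpoint["optimizer"])
```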
@@ -141,25 +196,11 @@ def step(self, closure=None):
         total_precond_size = 0
         total_precond_mb = 0
 
-        # update preconditioners all together deterministically
-        update_prob = self.param_groups[0]["preconditioner_update_probability"]
-        if update_prob is None:
-            update_prob = precond_update_prob_schedule
-        if callable(update_prob):
-            update_prob = update_prob(self._prob_step)
-        self._update_counter += 1
-        do_update = self._update_counter >= 1 / update_prob
-        if do_update:
-            self._update_counter = 0
-        self._prob_step += 1
-
-        # balance preconditioners roughly every 100 updates
-        balance = self.rng.random() < 0.01 and do_update
-
         for group in self.param_groups:
             mu_dtype = group.get("mu_dtype")
             precond_dtype = group.get("precond_dtype", torch.float32)
             momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+            update_prob = group.get("preconditioner_update_probability", None)
 
             for p in group["params"]:
                 if p.grad is None:
@@ -170,17 +211,19 @@ def step(self, closure=None):
 
                 if len(state) == 0:
                     state["step"] = 0
+                    state["update_counter"] = 0
                     state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
-                    state["Q"], state["exprs"] = _init_Q_exprs(
+                    state["Q"], exprs = _init_Q_exprs(
                         p,
                         group["precond_init_scale"],
                         group["max_size_triangular"],
                         group["min_ndim_triangular"],
                         group["memory_save_mode"],
                         dtype=precond_dtype,
                     )
+                    self._param_exprs[p] = exprs
 
-                    # Print sizes
+                    # Accumulate sizes for log
                     momentum_size = state["momentum_buffer"].numel()
                     momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
                     total_momentum_size += momentum_size
@@ -190,6 +233,29 @@ def step(self, closure=None):
                     precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
                     total_precond_size += precond_size
                     total_precond_mb += precond_mb
+                elif p not in self._param_exprs:
+                    exprs = _init_Q_exprs(
+                        p,
+                        group["precond_init_scale"],
+                        group["max_size_triangular"],
+                        group["min_ndim_triangular"],
+                        group["memory_save_mode"],
+                        dtype=precond_dtype,
+                        init_q=False,
+                    )
+                    self._param_exprs[p] = exprs
+                else:
+                    exprs = self._param_exprs[p]
+
+                # update preconditioners all together deterministically
+                if update_prob is None:
+                    update_prob = precond_update_prob_schedule
+                if callable(update_prob):
+                    update_prob = update_prob(state["step"])
+                state["update_counter"] += 1
+                do_update = state["update_counter"] >= 1 / update_prob
+                if do_update:
+                    state["update_counter"] = 0
 
                 state["step"] += 1
 
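The counter logic added above turns the update probability into a deterministic per-parameter cadence: with probability `p`, the preconditioner is refreshed every `ceil(1/p)` steps rather than by coin flip. A toy illustration (plain Python, not part of the optimizer):

```python
update_prob = 0.25
update_counter = 0
fired = []
for step in range(8):
    update_counter += 1
    do_update = update_counter >= 1 / update_prob
    if do_update:
        update_counter = 0
    fired.append(do_update)
print(fired)  # [False, False, False, True, False, False, False, True]
```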
@@ -198,21 +264,30 @@ def step(self, closure=None):
                 bias_correction = 1 - beta ** state["step"]
                 momentum_buffer = state["momentum_buffer"]
                 momentum_buffer.mul_(group["momentum"]).add_(grad, alpha=1 - group["momentum"])
+
                 # Restore momentum dtype
                 if mu_dtype is not None:
-                    momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype, non_blocking=True))
-                debiased_momentum = momentum_buffer / bias_correction
-                debiased_momentum = debiased_momentum.to(dtype=precond_dtype, non_blocking=True)
+                    momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
+                debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)
 
-                # balance preconditioners about every 100 updates
+                # Balance preconditioners roughly every 100 updates
+                balance = self.rng.random() < 0.01 and do_update
                 if grad.dim() > 1 and balance:
                     self._balance_Q(state["Q"])
 
                 # Update preconditioner
                 if do_update:
-                    exprA, exprGs, _ = state["exprs"]
+                    exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.torch_rng is None:
+                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    else:
+                        # Restoring generator state to device is messy. For now,
+                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
+                        # FIXME Need a better approach
+                        V = torch.randn(
+                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
+                        V = V.to(debiased_momentum.device)
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
@@ -225,54 +300,59 @@ def step(self, closure=None):
                         if q.dim() < 2:
                             tmp *= q
                             tmp /= (term1 + term2).norm(float("inf")) + self._tiny
-                            q.sub_(tmp)
                         else:
                             tmp = torch.triu(tmp)
                             tmp /= _norm_lower_bound(term1 + term2) + self._tiny
                             tmp @= q
-                            q.sub_(tmp)
-
-                    # _update_precond(
-                    #     state["Q"],
-                    #     state["exprs"],
-                    #     torch.randn_like(debiased_momentum, dtype=precond_dtype),
-                    #     debiased_momentum if momentum_into_precond_update else grad,
-                    #     group["precond_lr"],
-                    #     self._tiny,
-                    # )
+                        q.sub_(tmp)
 
                 # Precondition gradients
                 pre_grad = self._precond_grad(
                     state["Q"],
-                    state["exprs"],
+                    exprs,
                     debiased_momentum,
-                ).to(dtype=p.dtype, non_blocking=True)
+                ).to(dtype=p.dtype)
 
                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
-                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt() + 1e-6), max=1.0))
+                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
 
-                # Apply weight decay and update parameters
-                if group["weight_decay"] != 0 and p.dim() >= 2:
-                    pre_grad.add_(p, alpha=group["weight_decay"])
+                # Apply weight decay
+                if group["weight_decay"] != 0:
+                    if group["decoupled_decay"]:
+                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                    else:
+                        pre_grad.add_(p, alpha=group["weight_decay"])
+
+                # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
 
         if total_momentum_size > 0:
-            print(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
-            print(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")
+            _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
+            _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")
 
         return loss
 
 
-def _init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
+def _init_Q_exprs(
+        t,
+        scale,
+        max_size,
+        min_ndim_triangular,
+        memory_save_mode,
+        dtype=None,
+        init_q=True,
+):
     """For a scalar or tensor t, we initialize its preconditioner Q and
     reusable einsum expressions for updating Q and preconditioning gradient.
     """
     letters = string.ascii_lowercase + string.ascii_uppercase
 
     dtype = dtype if dtype is not None else t.dtype
     shape = t.shape
+    Q = []
     if len(shape) == 0:  # scalar
-        Q = [scale * torch.ones_like(t, dtype=dtype)]
+        if init_q:
+            Q.append(scale * torch.ones_like(t, dtype=dtype))
         exprA = ",->"
         exprGs = [",->"]
         exprP = ",,->"
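The clipping line in the hunk above rescales the preconditioned gradient only when its RMS exceeds 1.1, since PSGD already normalizes it toward 1.0. A standalone toy illustration of the same expression (names and values are placeholders):

```python
import torch

def cap_rms(pre_grad, limit=1.1, eps=1e-8):
    # Scale down only if RMS > limit; clamp(max=1.0) leaves small gradients untouched.
    return pre_grad * torch.clamp(limit / (pre_grad.square().mean().sqrt() + eps), max=1.0)

small = torch.full((4,), 0.5)   # RMS 0.5 -> unchanged
large = torch.full((4,), 5.0)   # RMS 5.0 -> rescaled to RMS ~1.1
print(cap_rms(small).square().mean().sqrt(), cap_rms(large).square().mean().sqrt())
```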
@@ -288,13 +368,18 @@ def _init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dty
             rev_sorted_dims = np.argsort(shape)[::-1]
             dim_diag = [False for _ in shape]
             dim_diag[rev_sorted_dims[0]] = True
+        elif memory_save_mode == "smart_one_diag":
+            dim_diag = [False for _ in shape]
+            rev_sorted_dims = np.argsort(shape)[::-1]
+            sorted_shape = sorted(shape)
+            if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+                dim_diag[rev_sorted_dims[0]] = True
         elif memory_save_mode == "all_diag":
             dim_diag = [True for _ in shape]
         else:
             raise ValueError(
                 f"Invalid memory_save_mode: {memory_save_mode}, must be one of [None, 'one_diag', 'all_diag']")
 
-        Q = []
         piece1A, piece2A, piece3A = ([], "", "")
         exprGs = []
         piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
@@ -306,7 +391,8 @@ def _init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dty
                 or dim_d
             ):
                 # use diagonal matrix as preconditioner for this dim
-                Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
+                if init_q:
+                    Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
 
                 piece1A.append(letters[i])
                 piece2A = piece2A + letters[i]
@@ -322,7 +408,8 @@ def _init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dty
                 piece4P = piece4P + letters[i + 13]
             else:
                 # use triangular matrix as preconditioner for this dim
-                Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
+                if init_q:
+                    Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
 
                 piece1A.append(letters[i] + letters[i + 13])
                 piece2A = piece2A + letters[i + 13]
@@ -343,7 +430,10 @@ def _init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dty
         exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
 
     exprGs = tuple(exprGs)
-    return [Q, (exprA, exprGs, exprP)]
+    if init_q:
+        return [Q, (exprA, exprGs, exprP)]
+    else:
+        return exprA, exprGs, exprP
 
 
 def _lb(A, max_abs):
@@ -368,10 +458,10 @@ def _norm_lower_bound(A):
 def _solve_triangular_right(X, A):
     """X @ inv(A)"""
     orig_dtype = X.dtype
-    X = X.to(dtype=torch.float32, non_blocking=True)
-    A = A.to(dtype=torch.float32, non_blocking=True)
+    X = X.to(dtype=torch.float32)
+    A = A.to(dtype=torch.float32)
     out = torch.linalg.solve_triangular(A, X.reshape(-1, X.size(-1)), upper=True, left=False).reshape_as(X)
-    return out.to(dtype=orig_dtype, non_blocking=True)
+    return out.to(dtype=orig_dtype)
 
 
 def _balance_Q(Q_in):