1
- """ PyTorch Implementation of the Kron PSGD optimizer
1
+ """ PyTorch Implementation of the Kron (PSGD) optimizer
2
2
3
- FIXME attribution
4
- * https://github.com/evanatyourservice/kron_torch (direct source)
5
- * https://github.com/lixilinx/psgd_torch (original)
6
- * https://github.com/ClashLuke/HeavyBall (added improvements)
3
+ This is a PSGD optimizer using a Kronecker-factored preconditioner.
4
+
5
+ This impl was adapted from https://github.com/evanatyourservice/kron_torch
6
+ by Evan Walters, licensed CC-BY-4.0.
7
+
8
+ Improvements and fixes to the above were made by
9
+ * Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
10
+ * Omead Pooladzandi https://github.com/opooladz
11
+
12
+ The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
13
+
14
+ This `timm` impl
15
+ * works with a wider variety of torch versions
16
+ * fixes some checkpoint save/restore (resume) issues
17
+ * adds decoupled weight-decay option
7
18
8
19
"""
9
20
import logging
30
41
except AttributeError :
31
42
has_dynamo = False
32
43
44
+ from ._types import ParamsT
45
+
33
46
_logger = logging .getLogger (__name__ )
34
47
35
48
@@ -85,7 +98,7 @@ class Kron(torch.optim.Optimizer):
85
98
86
99
def __init__ (
87
100
self ,
88
- params ,
101
+ params : ParamsT ,
89
102
lr : float = 0.001 ,
90
103
momentum : float = 0.9 ,
91
104
weight_decay : float = 0.0 ,
@@ -94,6 +107,8 @@ def __init__(
94
107
min_ndim_triangular : int = 2 ,
95
108
memory_save_mode : Optional [str ] = None ,
96
109
momentum_into_precond_update : bool = True ,
110
+ precond_lr : float = 0.1 ,
111
+ precond_init_scale : float = 1.0 ,
97
112
mu_dtype : Optional [torch .dtype ] = None ,
98
113
precond_dtype : Optional [torch .dtype ] = None ,
99
114
decoupled_decay : bool = False ,
@@ -119,8 +134,8 @@ def __init__(
119
134
min_ndim_triangular = min_ndim_triangular ,
120
135
memory_save_mode = memory_save_mode ,
121
136
momentum_into_precond_update = momentum_into_precond_update ,
122
- precond_lr = 0.1 , # precond lr hardcoded to 0.1
123
- precond_init_scale = 1.0 , # precond init scale hardcoded to 1.0
137
+ precond_lr = precond_lr ,
138
+ precond_init_scale = precond_init_scale ,
124
139
mu_dtype = mu_dtype ,
125
140
precond_dtype = precond_dtype ,
126
141
decoupled_decay = decoupled_decay ,
0 commit comments