
Commit b66ef4c
Committed Jan 28, 2025

Prep Kron for merge, add detail to attributions note.

1 parent: 80a0205

File tree

1 file changed: +23 −8 lines

timm/optim/kron.py
Lines changed: 23 additions & 8 deletions
@@ -1,9 +1,20 @@
-""" PyTorch Implementation of the Kron PSGD optimizer
+""" PyTorch Implementation of the Kron (PSGD) optimizer
 
-FIXME attribution
-* https://github.com/evanatyourservice/kron_torch (direct source)
-* https://github.com/lixilinx/psgd_torch (original)
-* https://github.com/ClashLuke/HeavyBall (added improvements)
+This is a PSGD optimizer using a Kronecker-factored preconditioner.
+
+This impl was adapted from https://github.com/evanatyourservice/kron_torch
+by Evan Walters, licensed CC-BY-4.0.
+
+Improvements, fixes to the above made by
+* Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
+* Omead Pooladzandi https://github.com/opooladz
+
+The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
+
+This `timm` impl
+* works with a wider variety of torch versions
+* fixes some checkpoint save/restore (resume) issues
+* adds a decoupled weight-decay option
 
 """
 import logging
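
The new docstring mentions a decoupled weight-decay option; as a rough illustration, a minimal usage sketch (the toy model, data, and hyperparameter values are illustrative assumptions, not part of this commit):

```python
import torch
import torch.nn.functional as F
from timm.optim.kron import Kron

model = torch.nn.Linear(10, 2)  # toy model, illustrative only

# decoupled_decay=True applies weight decay directly to the parameters,
# decoupled from the preconditioned gradient update.
opt = Kron(model.parameters(), lr=1e-3, weight_decay=1e-2, decoupled_decay=True)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```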
@@ -30,6 +41,8 @@
 except AttributeError:
     has_dynamo = False
 
+from ._types import ParamsT
+
 _logger = logging.getLogger(__name__)
 
 
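`ParamsT` (imported from the package's `_types` module) is the usual optimizer-params type alias: an iterable of tensors or of param-group dicts. Since `Kron` subclasses `torch.optim.Optimizer`, per-group overrides should work as usual; a sketch with illustrative groups:

```python
import torch
from timm.optim.kron import Kron

backbone = torch.nn.Linear(10, 10)
head = torch.nn.Linear(10, 2)

# An iterable of param-group dicts, one form that ParamsT covers;
# the per-group lr override follows standard torch.optim behavior.
opt = Kron(
    [
        {'params': backbone.parameters(), 'lr': 1e-4},
        {'params': head.parameters()},  # falls back to the default lr below
    ],
    lr=1e-3,
)
```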

@@ -85,7 +98,7 @@ class Kron(torch.optim.Optimizer):
 
     def __init__(
             self,
-            params,
+            params: ParamsT,
             lr: float = 0.001,
             momentum: float = 0.9,
             weight_decay: float = 0.0,
@@ -94,6 +107,8 @@ def __init__(
             min_ndim_triangular: int = 2,
             memory_save_mode: Optional[str] = None,
             momentum_into_precond_update: bool = True,
+            precond_lr: float = 0.1,
+            precond_init_scale: float = 1.0,
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
@@ -119,8 +134,8 @@ def __init__(
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
             momentum_into_precond_update=momentum_into_precond_update,
-            precond_lr=0.1,  # precond lr hardcoded to 0.1
-            precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+            precond_lr=precond_lr,
+            precond_init_scale=precond_init_scale,
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
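
With the previously hardcoded values now exposed as constructor arguments, the preconditioner learning rate and initial scale can be tuned per run; the defaults (`precond_lr=0.1`, `precond_init_scale=1.0`) match the old hardcoded behavior. The override values below are illustrative only:

```python
import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(10, 2)  # toy model, illustrative only

# Defaults (precond_lr=0.1, precond_init_scale=1.0) match the previously
# hardcoded values; override them to experiment.
opt = Kron(
    model.parameters(),
    lr=1e-3,
    precond_lr=0.3,          # faster preconditioner adaptation (illustrative)
    precond_init_scale=0.1,  # smaller initial preconditioner (illustrative)
)
```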
