
Commit 80a0205

Committed Jan 27, 2025

Move opt_einsum import back out of class __init__

Parent: 71d1741

File tree

1 file changed (+16, −14 lines)


timm/optim/kron.py

Lines changed: 16 additions & 14 deletions
@@ -14,7 +14,15 @@
 
 import numpy as np
 import torch
-
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -26,11 +34,11 @@
 
 
 def precond_update_prob_schedule(
-    n: float,
-    max_prob: float = 1.0,
-    min_prob: float = 0.03,
-    decay: float = 0.001,
-    flat_start: float = 500,
+        n: float,
+        max_prob: float = 1.0,
+        min_prob: float = 0.03,
+        decay: float = 0.001,
+        flat_start: float = 500,
 ) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
 
@@ -92,14 +100,8 @@ def __init__(
             flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
 
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
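What the change accomplishes: the guarded opt_einsum import now runs once at module import time, so the einsum backend is configured globally before any optimizer instance is constructed, and __init__ only checks the module-level has_opt_einsum flag. It also fixes a bug in the removed block, which set enabled and strategy on the opt_einsum package itself rather than on torch.backends.opt_einsum. A minimal standalone sketch of the resulting pattern (MyOptimizer is an illustrative stand-in, not the actual timm Kron class):

import warnings

try:
    # opt_einsum lets torch.einsum search for a cheap contraction order,
    # avoiding the memory blow-up of naive left-to-right contraction.
    import opt_einsum  # noqa: F401
    import torch.backends.opt_einsum
    torch.backends.opt_einsum.enabled = True
    torch.backends.opt_einsum.strategy = "auto-hq"
    has_opt_einsum = True
except ImportError:
    has_opt_einsum = False


class MyOptimizer:  # illustrative stand-in for the Kron optimizer
    def __init__(self):
        # The constructor no longer imports or configures anything itself;
        # it only warns if the module-level probe failed.
        if not has_opt_einsum:
            warnings.warn(
                "It is highly recommended to have 'opt_einsum' installed for this optimizer."
            )

The "auto-hq" strategy string is forwarded by torch.backends.opt_einsum to opt_einsum's path search, trading a little extra planning time for higher-quality contraction orders.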
