1 file changed: +16 -14 lines
@@ -14,7 +14,15 @@
 
 import numpy as np
 import torch
-
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
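The block added above configures PyTorch's opt_einsum integration once at import time. A minimal standalone sketch of what the two backend switches do follows; it assumes the opt_einsum package is installed, and the tensor shapes are illustrative, not from the patch:

    import torch
    import torch.backends.opt_einsum

    # When the opt_einsum package is importable, torch.einsum can ask it
    # to plan a contraction order for 3+ operand expressions instead of
    # contracting strictly left to right; "auto-hq" spends more planning
    # time in exchange for a cheaper path.
    torch.backends.opt_einsum.enabled = True
    torch.backends.opt_einsum.strategy = "auto-hq"

    a = torch.randn(8, 1024)
    b = torch.randn(1024, 1024)
    c = torch.randn(1024, 8)

    # A badly ordered three-operand contraction can materialize large
    # intermediates; an optimized path keeps peak memory down, which is
    # the "blowing up memory" concern in the NOTE above.
    out = torch.einsum("ij,jk,kl->il", a, b, c)
    print(out.shape)  # torch.Size([8, 8])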
@@ -26,11 +34,11 @@
 
 
 def precond_update_prob_schedule(
-    n: float,
-    max_prob: float = 1.0,
-    min_prob: float = 0.03,
-    decay: float = 0.001,
-    flat_start: float = 500,
+    n: float,
+    max_prob: float = 1.0,
+    min_prob: float = 0.03,
+    decay: float = 0.001,
+    flat_start: float = 500,
 ) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
 
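In the hunk above, the removed and re-added signature lines carry identical tokens, so the change is a whitespace reflow. The function body is not shown in this diff; the following is a hedged sketch of a schedule consistent with the signature and docstring, where the exponential-decay body is an assumption, not the file's actual code:

    import torch

    def precond_update_prob_schedule(
        n: float,
        max_prob: float = 1.0,
        min_prob: float = 0.03,
        decay: float = 0.001,
        flat_start: float = 500,
    ) -> torch.Tensor:
        """Anneal preconditioner update probability during beginning of training."""
        # Assumed body: hold max_prob for the first `flat_start` steps,
        # then decay exponentially toward min_prob. Only the signature
        # and docstring appear in the diff.
        step = torch.as_tensor(n, dtype=torch.float32)
        prob = max_prob * torch.exp(-decay * torch.clamp(step - flat_start, min=0.0))
        return prob.clamp(min=min_prob, max=max_prob)

    print(precond_update_prob_schedule(0))       # tensor(1.) during the flat start
    print(precond_update_prob_schedule(10_000))  # tensor(0.0300) once fully annealed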
@@ -92,14 +100,8 @@ def __init__(
         flatten_dim: bool = False,
         deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
 
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
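Two things change in the final hunk: the probe now runs once at import time instead of on every optimizer construction, and the enabled/strategy assignments now land on torch.backends.opt_einsum, which torch.einsum actually consults; the removed code set those attributes on the opt_einsum package itself, where they had no effect on PyTorch. Reduced to its essentials, the pattern looks like this (SomeOptimizer is a hypothetical stand-in; the real class name is not visible in the diff):

    import warnings

    try:
        import opt_einsum  # noqa: F401  # probe once, at import time
        has_opt_einsum = True
    except ImportError:
        has_opt_einsum = False

    class SomeOptimizer:  # hypothetical stand-in for the class being patched
        def __init__(self):
            # Each instance only checks the flag; no repeated import attempts.
            if not has_opt_einsum:
                warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")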