import torch.nn as nn
import torch.nn.functional as F

-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+
+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm
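
For context: `torch.nn.functional.rms_norm` only exists in newer PyTorch releases, so the try/except falls back to the package's own implementation. A minimal sketch of what such a fallback computes (names here are illustrative; the real `.fast_norm` code may differ):

# Illustrative only -- not the actual .fast_norm implementation.
import torch

def rms_norm_ref(x, normalized_shape, weight=None, eps=1e-6):
    # Normalize by the root-mean-square over the trailing dims; no mean subtraction.
    dims = tuple(range(-len(normalized_shape), 0))
    x = x * torch.rsqrt(x.pow(2).mean(dims, keepdim=True) + eps)
    if weight is not None:
        x = x * weight
    return x
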


class GroupNorm(nn.GroupNorm):
+    _fast_norm: torch.jit.Final[bool]
+
    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x):
-        if self.fast_norm:
+        if self._fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
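
The class-level `torch.jit.Final[bool]` annotation tells TorchScript that `_fast_norm` is a constant, so the branch in `forward` can be resolved at script time; the flag is captured per instance because globals can't be scripted. A quick scripting check, as a sketch (assumes the `GroupNorm` above is importable):

import torch

gn = GroupNorm(64)                        # _fast_norm captured at construction
gn_scripted = torch.jit.script(gn)        # Final[bool] lets the branch be folded
y = gn_scripted(torch.randn(2, 64, 8, 8))
print(y.shape)                            # torch.Size([2, 64, 8, 8])
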
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """
+    _fast_norm: torch.jit.Final[bool]

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.fast_norm:
+        if self._fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class LayerNorm(nn.LayerNorm):
    """ LayerNorm w/ fast norm option
    """
+    _fast_norm: torch.jit.Final[bool]
+
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
+    _fast_norm: torch.jit.Final[bool]
+
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
class RmsNorm(nn.Module):
    """ RmsNorm w/ fast (apex) norm if available
    """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
+    _fast_norm: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
@@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
        return x


class RmsNorm2d(nn.Module):
    """ RmsNorm w/ fast (apex) norm if available
    """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
+    _fast_norm: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
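
With the flag checked in `forward`, the plain `rms_norm` path is exercised whenever fast norms are disabled. A usage sketch (assuming this is timm's layers/norm.py and that `set_fast_norm` is the companion toggle exported by `.fast_norm`):

import torch
from timm.layers.fast_norm import set_fast_norm

set_fast_norm(False)          # disable apex/fused paths globally
norm = RmsNorm(64)            # flag is captured once, at construction time
x = torch.randn(2, 196, 64)
print(norm(x).shape)          # torch.Size([2, 196, 64])
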
@@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
@@ -187,18 +208,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x


class SimpleNorm(nn.Module):
    """ SimpleNorm (x / std(x))
    """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
+    _fast_norm: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
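
The 2d variants, as in the `RmsNorm2d.forward` hunk above, keep the same dispatch but wrap it in NCHW-to-NHWC permutes so the norm runs over the channel dim in the trailing position. A quick equivalence sketch (hypothetical shapes; assumes the classes above are importable):

import torch

x = torch.randn(2, 64, 7, 7)                 # NCHW input
norm2d = RmsNorm2d(64)
norm1d = RmsNorm(64)
norm1d.weight = norm2d.weight                # share the affine weight
y = norm2d(x)                                # permute -> norm -> permute back
y_ref = norm1d(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(y, y_ref))              # True
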
@@ -210,6 +235,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
@@ -222,17 +249,21 @@ def reset_parameters(self) -> None:
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
        return x


class SimpleNorm2d(nn.Module):
    """ SimpleNorm for NCHW tensors
    """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool
+    _fast_norm: bool

    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
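
SimpleNorm's fallback `simple_norm` comes from `.fast_norm` via the updated import at the top. Per the docstring (x / std(x)), it divides by the standard deviation without centering the input; a rough sketch of that math (illustrative only, the real implementation may differ in details such as the variance flavor):

import torch

def simple_norm_ref(x, weight=None, eps=1e-6):
    # Scale by 1/std over the last dim; unlike LayerNorm, the mean is not subtracted from x.
    x = x * torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    if weight is not None:
        x = x * weight
    return x
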
@@ -244,6 +275,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
@@ -257,6 +290,9 @@ def reset_parameters(self) -> None:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x
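
Taken together, every norm class now carries its own `_fast_norm` flag, declared either via `torch.jit.Final` or `__constants__` so scripting keeps working. A final sanity sweep, as a sketch (assumes all classes are importable from this module):

import torch

for norm in (GroupNorm(64), GroupNorm1(64), LayerNorm(64), LayerNorm2d(64),
             RmsNorm(64), RmsNorm2d(64), SimpleNorm(64), SimpleNorm2d(64)):
    scripted = torch.jit.script(norm)    # would fail if the flag lived in a global
    print(type(norm).__name__, 'scripted OK')
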