
Commit 5809c2f

Use torch F.rms_norm when possible, select fast vs normal paths appropriately and test with torchscript
1 parent e0cacbf

File tree

2 files changed (+64, -18 lines)

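The "select fast vs normal paths appropriately" part of this change comes down to each norm layer snapshotting the global fast-norm flag into `self._fast_norm` at construction time (TorchScript cannot read module-level globals in `forward`). Below is a minimal usage sketch of that behaviour, not part of the commit itself; it assumes a timm build containing this change, that `RmsNorm` is exported from `timm.layers`, and that the existing `set_fast_norm` / `is_fast_norm` toggles live in `timm.layers.fast_norm`.

```python
# Sketch only: the fast/normal choice is captured per layer at construction,
# so the global toggle must be set before the layer is built.
import torch
from timm.layers.fast_norm import set_fast_norm, is_fast_norm
from timm.layers import RmsNorm

set_fast_norm(True)            # enable the fast (apex / low-precision) path globally
fast_layer = RmsNorm(32)       # records _fast_norm=True in __init__

set_fast_norm(False)           # flipping the flag later does not affect existing layers
normal_layer = RmsNorm(32)     # records _fast_norm=False

x = torch.randn(1, 8, 32)
print(fast_layer(x).shape, normal_layer(x).shape)  # both: torch.Size([1, 8, 32])
```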

timm/layers/fast_norm.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -24,6 +24,8 @@
 has_apex_rmsnorm = False
 
 
+has_torch_rms_norm = hasattr(F, 'rms_norm')
+
 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now
 
@@ -75,7 +77,6 @@ def fast_group_norm(
     if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
-        # FIXME what to do re CPU autocast?
         dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
@@ -101,14 +102,12 @@ def fast_layer_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
-
 def rms_norm(
     x: torch.Tensor,
     normalized_shape: List[int],
@@ -148,8 +147,19 @@ def fast_rms_norm(
         else:
             return fused_rms_norm_affine(x, weight, normalized_shape, eps)
 
-    # fallback
-    return rms_norm(x, normalized_shape, weight, eps)
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        if has_torch_rms_norm:
+            x = F.rms_norm(x, normalized_shape, weight, eps)
+        else:
+            x = rms_norm(x, normalized_shape, weight, eps)
+
+    return x
 
 
 def simple_norm(
```
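Distilled from the `fast_rms_norm` change above, here is a self-contained sketch of the dispatch pattern: detect `F.rms_norm` once at import time and fall back to a hand-rolled RMS norm on older PyTorch releases. The `manual_rms_norm` and `rms_norm_dispatch` names are illustrative, not timm's, and the fallback math is a straightforward RMS norm rather than timm's exact implementation.

```python
# Hedged sketch (not the timm source): prefer the built-in F.rms_norm when the
# installed torch provides it, otherwise fall back to a hand-rolled RMS norm.
from typing import List, Optional

import torch
import torch.nn.functional as F

_HAS_TORCH_RMS_NORM = hasattr(F, 'rms_norm')  # added to torch.nn.functional in newer releases


def manual_rms_norm(
        x: torch.Tensor,
        normalized_shape: List[int],
        weight: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
) -> torch.Tensor:
    # normalize over the trailing `normalized_shape` dims, like F.rms_norm
    dims = tuple(range(-len(normalized_shape), 0))
    x = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    if weight is not None:
        x = x * weight
    return x


def rms_norm_dispatch(
        x: torch.Tensor,
        normalized_shape: List[int],
        weight: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
) -> torch.Tensor:
    if _HAS_TORCH_RMS_NORM:
        return F.rms_norm(x, normalized_shape, weight, eps)
    return manual_rms_norm(x, normalized_shape, weight, eps)


if __name__ == '__main__':
    x = torch.randn(2, 4, 8)
    w = torch.ones(8)
    print(rms_norm_dispatch(x, [8], w).shape)  # torch.Size([2, 4, 8])
```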

timm/layers/norm.py

Lines changed: 49 additions & 13 deletions
```diff
@@ -11,17 +11,24 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+
+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm
 
 
 class GroupNorm(nn.GroupNorm):
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x):
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
     """ Group Normalization with 1 group.
     Input: tensor in shape [B, C, *]
     """
+    _fast_norm: torch.jit.Final[bool]
 
     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LayerNorm(nn.LayerNorm):
     """ LayerNorm w/ fast norm option
     """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
 class RmsNorm(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
 
 
 class RmsNorm2d(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -187,18 +208,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
 
 
 class SimpleNorm(nn.Module):
     """ SimpleNorm (x / std(x))
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -210,6 +235,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -222,17 +249,21 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
 
 
 class SimpleNorm2d(nn.Module):
     """ SimpleNorm for NCHW tensors
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool
 
     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -244,6 +275,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -257,6 +290,9 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
```
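The `torch.jit.Final[bool]` annotations and the `_fast_norm` entries added to `__constants__` are what keep these layers compatible with `torch.jit.script`, which is what the "test with torchscript" part of the commit message refers to. A quick hedged check of that, assuming a timm build that includes this commit and that `RmsNorm` is exported from `timm.layers`:

```python
# Sketch only: script the updated RmsNorm and confirm it matches eager execution.
import torch
from timm.layers import RmsNorm

norm = RmsNorm(64).eval()
scripted = torch.jit.script(norm)   # exercises the TorchScript path this commit tests

x = torch.randn(2, 196, 64)
with torch.no_grad():
    eager_out = norm(x)
    scripted_out = scripted(x)

torch.testing.assert_close(scripted_out, eager_out)
print(scripted_out.shape)  # torch.Size([2, 196, 64])
```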
