Skip to content

Commit 1a4ffe4

Browse files
authored
fix: refactor layer norm converter with INormalization Layer (#2755)
1 parent 4fe5feb commit 1a4ffe4

File tree

2 files changed

+81
-102
lines changed

2 files changed

+81
-102
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 20 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1010
from torch_tensorrt.dynamo.conversion.converter_utils import (
1111
cast_trt_tensor,
12+
get_axes_for_reduce_op,
1213
get_positive_dim,
1314
get_trt_tensor,
1415
to_numpy,
@@ -105,102 +106,30 @@ def layer_norm(
105106
cudnn_enable: bool,
106107
return_mean_rstd: bool,
107108
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
108-
if weight is None:
109-
weight = to_numpy(1.0)
110-
111-
if bias is None:
112-
bias = to_numpy(0.0)
113-
114-
shape = weight.shape
115-
gamma = to_numpy(weight).reshape(shape)
116-
beta = to_numpy(bias).reshape(shape)
117-
118-
dims = list(range(len(input.shape) - len(shape), len(input.shape)))
119-
120-
# E[x]
121-
mean_expected_trt = impl.reduce.mean(
122-
ctx, target, source_ir, f"{name}_mean_expected", input, dims, True
123-
)
124-
125-
# X-E[x]
126-
sub_trt = impl.elementwise.sub(
127-
ctx,
128-
target,
129-
source_ir,
130-
f"{name}_sub",
131-
input,
132-
mean_expected_trt,
133-
)
134-
135-
# Variance = mean(pow(x_sub_mean, 2))
136-
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
137-
pow_var = impl.elementwise.pow(
138-
ctx,
139-
target,
140-
source_ir,
141-
f"{name}_pow_var",
142-
sub_trt,
143-
pow_trt,
144-
)
145-
mean_trt = impl.reduce.mean(
146-
ctx, target, source_ir, f"{name}_mean", pow_var, dims, True
147-
)
148-
149-
# sqrt((var + eps))
150-
eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
151-
add_trt = impl.elementwise.add(
152-
ctx,
153-
target,
154-
source_ir,
155-
f"{name}_add",
156-
mean_trt,
157-
eps_trt,
158-
)
159-
sqrt_trt = impl.unary.sqrt(
160-
ctx,
161-
target,
162-
source_ir,
163-
f"{name}_sqrt",
164-
add_trt,
165-
)
166-
167-
# (X - E[X]) / sqrt((var + eps))
168-
div_trt = impl.elementwise.div(
169-
ctx,
170-
target,
171-
source_ir,
172-
f"{name}_div",
173-
sub_trt,
174-
sqrt_trt,
175-
)
176-
177-
gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
178-
beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")
179-
180-
# y * gamma + beta
181-
scaled_y = impl.elementwise.mul(
182-
ctx,
183-
target,
184-
source_ir,
185-
f"{name}_mul_gamma",
186-
div_trt,
187-
gamma_trt,
188-
)
109+
dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
110+
axes = get_axes_for_reduce_op(dims)
111+
112+
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
113+
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
114+
if tuple(input.shape) != tuple(weight.shape):
115+
weight = impl.slice.expand(
116+
ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
117+
)
118+
if tuple(input.shape) != tuple(bias.shape):
119+
bias = impl.slice.expand(
120+
ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
121+
)
189122

190-
output = impl.elementwise.add(
191-
ctx,
192-
target,
193-
source_ir,
194-
f"{name}_add_beta",
195-
scaled_y,
196-
beta_trt,
197-
)
123+
layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
124+
layer_norm.epsilon = eps
125+
layer_norm.compute_precision = input.dtype
126+
set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)
198127

199128
if return_mean_rstd:
200129
# return fake mean and rstd for now
201-
return output, None, None
130+
return layer_norm.get_output(0), None, None
202131

203-
return output
132+
return layer_norm.get_output(0)
204133

205134

206135
def native_group_norm(

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 61 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,75 @@
11
import torch
2+
from parameterized import parameterized
23
from torch.testing._internal.common_utils import run_tests
34
from torch_tensorrt import Input
45

56
from .harness import DispatchTestCase
67

78

89
class TestLayerNormConverter(DispatchTestCase):
9-
def test_layer_norm(self):
10+
@parameterized.expand(
11+
[
12+
(
13+
(5, 3, 2, 4),
14+
[
15+
4,
16+
],
17+
),
18+
((5, 3, 2, 4), [2, 4]),
19+
((5, 3, 2, 4), [3, 2, 4]),
20+
((5, 3, 2, 4), [5, 3, 2, 4]),
21+
]
22+
)
23+
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
1024
class LayerNorm(torch.nn.Module):
1125
def forward(self, x):
1226
return torch.ops.aten.layer_norm.default(
1327
x,
14-
torch.tensor([3, 224, 224]),
15-
torch.ones((3, 224, 224)),
16-
torch.zeros((3, 224, 224)),
17-
1e-05,
18-
True,
28+
normalized_shape,
29+
torch.randn(normalized_shape),
30+
torch.randn(normalized_shape),
31+
eps,
1932
)
2033

21-
inputs = [torch.randn(1, 3, 224, 224)]
34+
inputs = [torch.randn(input_shape)]
2235
self.run_test(
2336
LayerNorm(),
2437
inputs,
2538
)
2639

2740

2841
class TestNativeLayerNormConverter(DispatchTestCase):
29-
def test_layer_norm(self):
42+
@parameterized.expand(
43+
[
44+
(
45+
(5, 3, 2, 4),
46+
[
47+
4,
48+
],
49+
),
50+
((5, 3, 2, 4), [2, 4]),
51+
((5, 3, 2, 4), [3, 2, 4]),
52+
((5, 3, 2, 4), [5, 3, 2, 4]),
53+
]
54+
)
55+
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
56+
class LayerNorm(torch.nn.Module):
57+
def forward(self, x):
58+
return torch.ops.aten.native_layer_norm.default(
59+
x,
60+
normalized_shape,
61+
torch.randn(normalized_shape),
62+
torch.randn(normalized_shape),
63+
eps,
64+
)[0]
65+
66+
inputs = [torch.randn(input_shape)]
67+
self.run_test(
68+
LayerNorm(),
69+
inputs,
70+
)
71+
72+
def test_layernorm_with_dynamic_shape(self):
3073
class LayerNorm(torch.nn.Module):
3174
def forward(self, x):
3275
return torch.ops.aten.native_layer_norm.default(
@@ -37,10 +80,17 @@ def forward(self, x):
3780
1e-05,
3881
)[0]
3982

40-
inputs = [torch.randn(1, 3, 224, 224)]
41-
self.run_test(
83+
input_specs = [
84+
Input(
85+
shape=(-1, 3, 224, 224),
86+
dtype=torch.float32,
87+
shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
88+
),
89+
]
90+
91+
self.run_test_with_dynamic_shape(
4292
LayerNorm(),
43-
inputs,
93+
input_specs,
4494
)
4595

4696

0 commit comments

Comments
 (0)