Commit d9ef0bc

Enable linear+sigmoid, linear+silu, linear+swish fusion (#541)

1 parent: f6cf30a

5 files changed: +210 −4 lines

intel_extension_for_pytorch/csrc/jit/cpu/kernels/LinearPacked.cpp

Lines changed: 19 additions & 1 deletion
@@ -57,6 +57,24 @@ at::Tensor linear_gelu_run(
   return op_context->run(input, ideep::attr_t::fuse_gelu());
 }
 
+at::Tensor linear_sigmoid_run(
+    const at::Tensor& input,
+    const c10::intrusive_ptr<LinearOpContext>& op_context) {
+  IPEX_RECORD_FUNCTION(
+      "ipex_prepack::linear_sigmoid_run", std::vector<c10::IValue>({}));
+
+  return op_context->run(input, ideep::attr_t::fuse_sigmoid());
+}
+
+at::Tensor linear_swish_run(
+    const at::Tensor& input,
+    const c10::intrusive_ptr<LinearOpContext>& op_context) {
+  IPEX_RECORD_FUNCTION(
+      "ipex_prepack::linear_swish_run", std::vector<c10::IValue>({}));
+
+  return op_context->run(input, ideep::attr_t::fuse_swish());
+}
+
 at::Tensor linear_add_run(
     const at::Tensor& input,
     at::Tensor& accumu,
@@ -125,4 +143,4 @@ at::Tensor& run(
 } // namespace linear
 } // namespace detail
 } // namespace cpu
-} // namespace torch_ipex
+} // namespace torch_ipex
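Both new kernels reuse the prepacked-linear path and differ only in the oneDNN post-op attribute passed to the op context. As a reference for what the fused ops compute, here is a minimal PyTorch sketch (an illustration of the math, not the committed code); note that silu/swish is exactly x * sigmoid(x):

import torch
import torch.nn.functional as F

# Reference semantics for the fused ops (illustration only):
#   linear_sigmoid_run(x, ctx) ~ sigmoid(linear(x))
#   linear_swish_run(x, ctx)   ~ linear(x) * sigmoid(linear(x)) == silu(linear(x))
x = torch.randn(32, 3)
linear = torch.nn.Linear(3, 32)
y = linear(x)
assert torch.allclose(y * torch.sigmoid(y), F.silu(y), atol=1e-6)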

intel_extension_for_pytorch/csrc/jit/cpu/kernels/LinearPacked.h

Lines changed: 9 additions & 1 deletion
@@ -29,6 +29,14 @@ at::Tensor linear_gelu_run(
     const at::Tensor& input,
     const c10::intrusive_ptr<LinearOpContext>& op_context);
 
+at::Tensor linear_sigmoid_run(
+    const at::Tensor& input,
+    const c10::intrusive_ptr<LinearOpContext>& op_context);
+
+at::Tensor linear_swish_run(
+    const at::Tensor& input,
+    const c10::intrusive_ptr<LinearOpContext>& op_context);
+
 at::Tensor linear_add_run(
     const at::Tensor& input,
     at::Tensor& accumu,
@@ -57,4 +65,4 @@ at::Tensor& run(
 } // namespace linear
 } // namespace detail
 } // namespace cpu
-} // namespace torch_ipex
+} // namespace torch_ipex

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_linear.cpp

Lines changed: 60 additions & 1 deletion
@@ -98,8 +98,12 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
 }
 
 void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
-  SubgraphRewriter rewriter_relu, rewriter_gelu;
+  SubgraphRewriter rewriter_relu, rewriter_gelu, rewriter_silu,
+      rewriter_sigmoid, rewriter_swish;
   std::array<std::string, 2> relu_operators = {"relu", "relu_"};
+  std::array<std::string, 2> sigmoid_operators = {"sigmoid", "sigmoid_"};
+  std::array<std::string, 2> silu_operators = {"silu", "silu_"};
+  std::array<std::string, 2> mul_operators = {"mul", "mul_"};
 
   auto linear_relu_rstring = CodeTemplate(R"(
     graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
@@ -127,13 +131,68 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
       %res = ipex_prepack::linear_gelu_run(%input, %packed_weight)
       return (%res))";
 
+  auto linear_sigmoid_rstring = CodeTemplate(R"(
+    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
+      %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
+      %x = ipex_prepack::linear_run(%input, %packed_weight)
+      %res = aten::${sigmoid}(%x)
+      return (%res))");
+
+  auto linear_silu_rstring = CodeTemplate(R"(
+    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
+      %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
+      %x = ipex_prepack::linear_run(%input, %packed_weight)
+      %res = aten::${silu}(%x)
+      return (%res))");
+
+  auto linear_sigmoid_mul_rstring = CodeTemplate(R"(
+    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
+      %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
+      %x = ipex_prepack::linear_run(%input, %packed_weight)
+      %y = aten::${sigmoid}(%x)
+      %res = aten::${mul}(%x, %y)
+      return (%res))");
+
+  std::string linear_swish_fused = R"(
+    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
+      %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
+      %res = ipex_prepack::linear_swish_run(%input, %packed_weight)
+      return (%res))";
+
+  std::string linear_sigmoid_fused = R"(
+    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
+      %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
+      %res = ipex_prepack::linear_sigmoid_run(%input, %packed_weight)
+      return (%res))";
+
   for (const auto& relu : relu_operators) {
     TemplateEnv env;
     env.s("relu", relu);
     rewriter_relu.RegisterRewritePattern(
         linear_relu_rstring.format(env), linear_relu_fused);
   }
 
+  for (const auto& silu : silu_operators) {
+    TemplateEnv env;
+    env.s("silu", silu);
+    rewriter_silu.RegisterRewritePattern(
+        linear_silu_rstring.format(env), linear_swish_fused);
+  }
+
+  for (const auto& sigmoid : sigmoid_operators) {
+    TemplateEnv env;
+    env.s("sigmoid", sigmoid);
+    rewriter_sigmoid.RegisterRewritePattern(
+        linear_sigmoid_rstring.format(env), linear_sigmoid_fused);
+    for (const auto& mul : mul_operators) {
+      env.s("mul", mul);
+      rewriter_swish.RegisterRewritePattern(
+          linear_sigmoid_mul_rstring.format(env), linear_swish_fused);
+    }
+  }
+  rewriter_silu.runOnGraph(graph);
+  rewriter_sigmoid.runOnGraph(graph);
+  rewriter_swish.runOnGraph(graph);
   rewriter_gelu.RegisterRewritePattern(linear_gelu, linear_gelu_fused);
 
   rewriter_relu.runOnGraph(graph);
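Because silu(x) = x * sigmoid(x), both the explicit aten::silu pattern and the manual sigmoid-then-mul pattern rewrite to the same linear_swish_fused target. The nested loop expands ${sigmoid} x ${mul} over the out-of-place and in-place variants, registering four swish patterns in total. A rough Python analogy of that template expansion, using string.Template in place of torch::jit::CodeTemplate (illustration only, not the committed code):

from string import Template

# Stand-in for CodeTemplate/TemplateEnv: expand ${sigmoid} and ${mul}
# over both operator spellings, yielding 2 x 2 = 4 swish patterns.
pattern = Template("""graph(%input, ...):
  %y = aten::${sigmoid}(%x)
  %res = aten::${mul}(%x, %y)
  return (%res)""")

for sig in ("sigmoid", "sigmoid_"):
    for mul in ("mul", "mul_"):
        print(pattern.substitute(sigmoid=sig, mul=mul))
        # each expansion would be registered against linear_swish_fused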

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 33 additions & 1 deletion
@@ -17,7 +17,6 @@
 #include "csrc/jit/cpu/kernels/Shuffle.h"
 #include "csrc/jit/cpu/kernels/Softmax.h"
 
-
 namespace torch {
 namespace jit {
 
@@ -365,6 +364,39 @@ RegisterOperators op({
           };
         },
         aliasAnalysisFromSchema()),
+    Operator(
+        "ipex_prepack::linear_sigmoid_run(Tensor input, "
+        "__torch__.torch.classes.ipex_prepack.LinearOpContext W_prepack) "
+        "-> Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = linear_sigmoid_run(
+                (std::move(peek(stack, 0, 2))).toTensor(),
+                (std::move(peek(stack, 1, 2)))
+                    .toCustomClass<LinearOpContext>());
+            drop(stack, 2);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "ipex_prepack::linear_swish_run(Tensor input, "
+        "__torch__.torch.classes.ipex_prepack.LinearOpContext W_prepack) "
+        "-> Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = linear_swish_run(
+                (std::move(peek(stack, 0, 2))).toTensor(),
+                (std::move(peek(stack, 1, 2)))
+                    .toCustomClass<LinearOpContext>());
+            drop(stack, 2);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex_prepack::linear_add_run(Tensor input, Tensor(a!) accumu, *, "
         "Scalar? alpha, "

tests/cpu/test_jit.py

Lines changed: 89 additions & 0 deletions
@@ -351,6 +351,38 @@ def __init__(self, in_channels, out_channels, **kwargs):
     def forward(self, x):
         return F.gelu(self.linear(x))
 
+class LinearSigmoid(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(LinearSigmoid, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+
+    def forward(self, x):
+        return F.sigmoid(self.linear(x))
+
+class LinearSwish(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(LinearSwish, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+
+    def forward(self, x):
+        linear_res = self.linear(x)
+        return F.silu(linear_res)
+
+class LinearSwish_v1(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(LinearSwish_v1, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+
+    def forward(self, x):
+        linear_res = self.linear(x)
+        return torch.mul(linear_res, F.sigmoid(linear_res))
+
 class LinearAdd(nn.Module):
     def __init__(self, in_channels, out_channels, **kwargs):
         super(LinearAdd, self).__init__()
@@ -2152,6 +2184,63 @@ def test_output_linear_gelu(self):
             torch.rand(32, 3),
             kind_in_graph="ipex_prepack::linear_gelu_run",
             prec=5e-3)
+
+    def test_output_linear_swish(self):
+        self._test_output(
+            LinearSwish_v1(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="aten::linear")
+        self._test_output_bf16(
+            LinearSwish_v1(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex_prepack::linear_swish_run",
+            prec=5e-3)
+        self._test_output(
+            LinearSwish_v1(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="aten::linear")
+        self._test_output_bf16(
+            LinearSwish_v1(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex_prepack::linear_swish_run",
+            prec=5e-3)
+        self._test_output(
+            LinearSwish(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="aten::linear")
+        self._test_output_bf16(
+            LinearSwish(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex_prepack::linear_swish_run",
+            prec=5e-3)
+        self._test_output(
+            LinearSwish(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="aten::linear")
+        self._test_output_bf16(
+            LinearSwish(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex_prepack::linear_swish_run",
+            prec=5e-3)
+
+    def test_output_linear_sigmoid(self):
+        self._test_output(
+            LinearSigmoid(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="aten::linear")
+        self._test_output_bf16(
+            LinearSigmoid(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex_prepack::linear_sigmoid_run",
+            prec=5e-3)
+        self._test_output(
+            LinearSigmoid(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="aten::linear")
+        self._test_output_bf16(
+            LinearSigmoid(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex_prepack::linear_sigmoid_run",
+            prec=5e-3)
 
     def test_channel_shuffle(self):
         self._test_output(
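The tests assert that the fused ops appear only on the bf16 path (fp32 keeps aten::linear). For completeness, a hedged end-to-end sketch of how a user model would exercise these kernels, following the ipex.optimize + tracing flow the test harness is built around (exact API surface and graph output vary by IPEX release):

import torch
import intel_extension_for_pytorch as ipex

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.linear = torch.nn.Linear(3, 32)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

model = M().eval()
x = torch.rand(32, 3)
# bf16 is where the prepacked fused kernels kick in, per the tests above
model = ipex.optimize(model, dtype=torch.bfloat16)
with torch.no_grad(), torch.cpu.amp.autocast():
    traced = torch.jit.trace(model, x)
    traced = torch.jit.freeze(traced)
    traced(x)  # warm-up run so the fusion passes execute
    # after fusion, the graph should contain ipex_prepack::linear_sigmoid_run
    print(traced.graph_for(x))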