
Commit 6ad4be0

chunyuanw authored and EikanWang committed
enable the fusion of linear and GeLU
1 parent d759c85 commit 6ad4be0

6 files changed (+72 -7 lines)
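The change itself is compact: it teaches the IPEX JIT fusion pass to collapse a `torch_ipex::linear` node followed by `aten::gelu` into a single `ipex::linear_gelu` node, so the GeLU runs as a oneDNN eltwise post-op of the linear kernel instead of a separate pass over the output. Below is a minimal sketch, in plain PyTorch with illustrative shapes, of the eager-mode pattern the fuser targets; the fused node only appears once the model is scripted or traced and run through IPEX's auto-dnnl path:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearGelu(nn.Module):
    """Linear whose output feeds directly into GeLU -- the pair the fuser matches."""
    def __init__(self, in_features, out_features):
        super(LinearGelu, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        # aten::linear followed by aten::gelu; under auto-dnnl these two
        # nodes are rewritten into one ipex::linear_gelu node
        return F.gelu(self.linear(x))

model = LinearGelu(3, 32).eval()
y = model(torch.rand(32, 3))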

tests/cpu/test_jit.py
Lines changed: 31 additions & 0 deletions

@@ -149,6 +149,16 @@ def __init__(self, in_channels, out_channels, **kwargs):
 
     def forward(self, x):
         return F.relu(self.linear(x), inplace=True)
+
+class LinearGelu(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(LinearGelu, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+
+    def forward(self, x):
+        return F.gelu(self.linear(x))
 
 class ConvSumInDiffBlock(nn.Module):
     def __init__(self, dim, in_channels, out_channels, **kwargs):

@@ -544,6 +554,27 @@ def test_output_linear_relu(self):
             kind_in_graph="ipex::linear_relu")
 
 
+    def test_output_linear_gelu(self):
+        self._test_output(
+            LinearGelu(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu")
+        self._test_output_bf16(
+            LinearGelu(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu",
+            prec=5e-3)
+        self._test_output(
+            LinearGelu(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu")
+        self._test_output_bf16(
+            LinearGelu(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu",
+            prec=5e-3)
+
+
     def test_channel_shuffle(self):
         self._test_output(
             ChannelShuffle(10, 16, 50, 50, 4),
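The new `test_output_linear_gelu` covers Linear both with and without bias, in fp32 and bf16; the bf16 cases loosen the tolerance to `prec=5e-3`, consistent with bfloat16's roughly 8-bit mantissa. As a hedged sketch of what a `kind_in_graph` check boils down to (the real `_test_output` harness in this file also compares numerics against the eager reference; `assert_fused` is a hypothetical helper, not IPEX code):

import torch

def assert_fused(model, x, kind_in_graph):
    model = model.eval()
    traced = torch.jit.trace(model, x)
    traced(x)  # warm up so the profiling/fusion passes have run
    # graph_for is TorchScript's debug hook for the optimized graph
    graph = traced.graph_for(x)
    # The fused kernel should show up as a single node, e.g. ipex::linear_gelu
    assert any(n.kind() == kind_in_graph for n in graph.nodes()), \
        "expected %s in the optimized graph" % kind_in_graph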

torch_ipex/csrc/cpu/FusionOPs.cpp
Lines changed: 5 additions & 4 deletions

@@ -386,12 +386,13 @@ at::Tensor& AtenIpexJITDev::dil_convolution_sum_relu(
       "Convolution_Sum_Relu");
 }
 
-at::Tensor AtenIpexJITDev::dil_linear_fuse_relu(
+at::Tensor AtenIpexJITDev::dil_linear_fuse_eltwise(
     const at::Tensor& self,
     const at::Tensor& weight,
-    const at::Tensor& bias) {
+    const at::Tensor& bias,
+    const dil::attr_t& attr) {
 #if defined(IPEX_PROFILE_OP)
-  RECORD_FUNCTION("AtenIpexJITDev::dil_linear_fuse_relu", std::vector<c10::IValue>({self, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
+  RECORD_FUNCTION("AtenIpexJITDev::dil_linear_fuse_eltwise", std::vector<c10::IValue>({self, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
 #endif
   IPEX_CHECK(self.dim() >= 2,
       "dil_linear: input needs to has dim at least 2, input dim ", self.dim());

@@ -413,7 +414,7 @@ at::Tensor AtenIpexJITDev::dil_linear_fuse_relu(
     b = try_gen_dil_tensor(bias_contiguous);
   }
 
-  dil::tensor y = dbl::linear::linear_impl(x, w, b, /* dst_scale */ dil::scale_t(), dil::attr_t::fuse_relu());
+  dil::tensor y = dbl::linear::linear_impl(x, w, b, /* dst_scale */ dil::scale_t(), attr);
 
   auto input_size = self.sizes();
   std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);

torch_ipex/csrc/cpu/FusionOPs.h
Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@ namespace ipex {
 // static auto conv3d_relu_sum = Symbol::fromQualString("ipex::conv3d_relu_sum");
 static auto conv2d_sum_relu = Symbol::fromQualString("ipex::conv2d_sum_relu");
 static auto linear_relu = Symbol::fromQualString("ipex::linear_relu");
+static auto linear_gelu = Symbol::fromQualString("ipex::linear_gelu");
 
 // 3d ops
 static auto conv3d_relu = Symbol::fromQualString("ipex::conv3d_relu");

@@ -48,7 +49,7 @@ class AtenIpexJITDev {
 
   static at::Tensor& dil_convolution_sum_relu( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor& accumu, at::Scalar alpha);
 
-  static at::Tensor dil_linear_fuse_relu(const at::Tensor& self, const at::Tensor& weight, const at::Tensor& bias);
+  static at::Tensor dil_linear_fuse_eltwise(const at::Tensor& self, const at::Tensor& weight, const at::Tensor& bias, const dil::attr_t& attr);
 
 };

torch_ipex/csrc/cpu/dil/dil/attributes.hpp
Lines changed: 9 additions & 0 deletions

@@ -69,6 +69,15 @@ struct attr_t : public dnnl::primitive_attr {
     return attr;
   }
 
+  static attr_t fuse_gelu(float scale = 1.0, float alpha = 0.f,
+                          float beta = 0.f) {
+    attr_t attr;
+    post_ops po;
+    po.append_eltwise(scale, algorithm::eltwise_gelu_tanh, alpha, beta);
+    attr.set_post_ops(po);
+    return attr;
+  }
+
   static attr_t fuse_elu(float scale = 1.0f, float alpha = 0.f, float beta = 1.0f) {
     attr_t attr;
     post_ops po;
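Note that the helper appends `algorithm::eltwise_gelu_tanh`, oneDNN's tanh approximation of GeLU, whereas eager `F.gelu` computes the exact erf form. The two agree to within a few 1e-4, well inside the bf16 test tolerance above. A quick numerical check in plain PyTorch (not part of the commit):

import math
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=1001)
exact = F.gelu(x)  # 0.5 * x * (1 + erf(x / sqrt(2)))
# tanh approximation, as used by oneDNN's eltwise_gelu_tanh
tanh_approx = 0.5 * x * (1 + torch.tanh(
    math.sqrt(2 / math.pi) * (x + 0.044715 * x.pow(3))))
print((exact - tanh_approx).abs().max())  # a few 1e-4 over this range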

torch_ipex/csrc/jit/fusion_pass.cpp
Lines changed: 1 addition & 0 deletions

@@ -288,6 +288,7 @@ OpFuser::RuleTab OpFuser::dnnlRules = {
   {{ipex::conv2d_sum, Symbol::fromQualString("aten::relu_")}, ipex::conv2d_sum_relu},
 
   {{Symbol::fromQualString("torch_ipex::linear"), aten::relu}, ipex::linear_relu},
+  {{Symbol::fromQualString("torch_ipex::linear"), aten::gelu}, ipex::linear_gelu},
   {{Symbol::fromQualString("torch_ipex::linear"), Symbol::fromQualString("aten::relu_")}, ipex::linear_relu},
 
   // 3d ops
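`dnnlRules` is a lookup table from (producer op, consumer op) pairs to the fused symbol, and the one-line addition registers linear followed by gelu alongside the existing linear/relu rules (only the out-of-place `aten::gelu` is matched, where relu also gets an `aten::relu_` entry). A toy sketch of the mechanism, with illustrative names rather than the real OpFuser classes:

# Toy rule-table fuser: scan producer/consumer pairs and rewrite matches
# into one fused op. Purely illustrative; the real pass works on the
# TorchScript graph, not a flat list.
RULES = {
    ("torch_ipex::linear", "aten::relu"): "ipex::linear_relu",
    ("torch_ipex::linear", "aten::gelu"): "ipex::linear_gelu",
}

def fuse(nodes):
    out, i = [], 0
    while i < len(nodes):
        pair = tuple(nodes[i:i + 2])
        if pair in RULES:
            out.append(RULES[pair])  # replace the pair with the fused op
            i += 2
        else:
            out.append(nodes[i])
            i += 1
    return out

print(fuse(["torch_ipex::linear", "aten::gelu"]))  # ['ipex::linear_gelu']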

torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp
Lines changed: 24 additions & 2 deletions

@@ -314,10 +314,11 @@ RegisterOperators op({
       [] (const Node* node) ->Operation {
         if (torch_ipex::check_auto_dnnl()) {
           return [] (Stack& stack) {
-            auto result = AtenIpexJITDev::dil_linear_fuse_relu(
+            auto result = AtenIpexJITDev::dil_linear_fuse_eltwise(
                 (std::move(peek(stack, 0, 3))).toTensor(),
                 (std::move(peek(stack, 1, 3))).toTensor(),
-                toOptionalTensor(std::move(peek(stack, 2, 3)))
+                toOptionalTensor(std::move(peek(stack, 2, 3))),
+                dil::attr_t::fuse_relu()
             );
             drop(stack, 3);
             pack(stack, std::move(result));

@@ -328,6 +329,27 @@ RegisterOperators op({
         }
       },
       aliasAnalysisFromSchema()
+    ),
+    Operator(
+      "ipex::linear_gelu(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor",
+      [] (const Node* node) ->Operation {
+        if (torch_ipex::check_auto_dnnl()) {
+          return [] (Stack& stack) {
+            auto result = AtenIpexJITDev::dil_linear_fuse_eltwise(
+                (std::move(peek(stack, 0, 3))).toTensor(),
+                (std::move(peek(stack, 1, 3))).toTensor(),
+                toOptionalTensor(std::move(peek(stack, 2, 3))),
+                dil::attr_t::fuse_gelu()
+            );
+            drop(stack, 3);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        } else {
+          TORCH_CHECK(false, "PyTorch native path not support linear gelu fusion now");
+        }
+      },
+      aliasAnalysisFromSchema()
     )
   });
 }
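The registered `ipex::linear_gelu` operator follows the interpreter's stack convention: peek the last three IValues (input, weight, optional bias), call the shared `dil_linear_fuse_eltwise` kernel with `dil::attr_t::fuse_gelu()`, drop the inputs, and push the result; if auto-dnnl is disabled, the op refuses to run rather than falling back. A reference-semantics sketch in plain PyTorch (`linear_gelu_op` is an illustrative stand-in, not the real C++ operator):

import torch
import torch.nn.functional as F

def linear_gelu_op(stack):
    # The last three stack entries are (input, weight, bias), bottom to top.
    bias = stack.pop()
    weight = stack.pop()
    x = stack.pop()
    # Reference semantics of ipex::linear_gelu; the real kernel runs the
    # GeLU as a oneDNN post-op instead of a second elementwise pass.
    stack.append(F.gelu(F.linear(x, weight, bias)))

stack = [torch.rand(32, 3), torch.rand(16, 3), torch.rand(16)]
linear_gelu_op(stack)
print(stack[0].shape)  # torch.Size([32, 16])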
