
Commit b9b1fd9

ZhiweiYan-96 authored and guangyey committed
[Intel GPU] qlinear.pointwise with mixed dtype support (pytorch#136753)
# Motivation

This PR aims to add mixed-data-type (AMP) support for the `qlinear_pointwise` op. With this PR, `qlinear` kernels are allowed to output a BF16 Tensor, rather than only FP32/INT8.

# UT verification

```bash
DNNL_VERBOSE=1 python test/inductor/test_mkldnn_pattern_matcher.py -v \
  -k test_qlinear_int8_mixed_bf16_xpu \
  -k test_qlinear_relu_int8_mixed_bf16_xpu \
  -k test_qlinear_add_int8_mixed_bf16_xpu
```

# Runtime exemplification

```bash
# qlinear + bf16 output
onednn_verbose,primitive,exec,gpu:0,matmul,ocl:gemm_with_po:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 bia_bf16::blocked:ab::f0_mask2 dst_bf16::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32,,4x4:4x4,0.0698242
# qlinear_add + bf16 output
onednn_verbose,primitive,exec,gpu:0,matmul,ocl:gemm_with_po:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 bia_bf16::blocked:ab::f0_mask2 dst_bf16::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_linear:1:-0.677141+sum:0.0132773,,4x4:4x4,0.0419922
# qlinear_add_relu + bf16 output
onednn_verbose,primitive,exec,gpu:0,matmul,ocl:gemm_with_po:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 bia_bf16::blocked:ab::f0_mask2 dst_bf16::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_linear:1:0.533096+sum:0.00416481+eltwise_relu,,4x4:4x4,0.0759277
```

As shown in the oneDNN verbose output, the destination memory descriptor `dst_bf16::blocked:ab::f0` demonstrates that we can successfully output a bf16 tensor from the int8 GEMM.

Pull Request resolved: pytorch#136753
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire, https://github.com/jerryzh168
ghstack dependencies: pytorch#133307, pytorch#135189, pytorch#135337, pytorch#135465
Co-authored-by: guangyey <[email protected]>
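The `attr-post-ops` field in the verbose lines above encodes the post-op chain that oneDNN fuses into the int8 GEMM. A rough sketch of how such a chain is built follows; this is not code from this PR (the actual attributes are assembled inside the shared `quantized_matmul` path), it assumes the oneDNN v3.x C++ API, and it reads the numeric fields of `eltwise_linear:1:-0.677141+sum:0.0132773` as alpha/beta and the sum scale.

```cpp
#include <oneapi/dnnl/dnnl.hpp>

int main() {
  // Sketch of a chain matching "eltwise_linear:1:-0.677141+sum:0.0132773"
  // (the qlinear_add case): a linear transform alpha*x + beta applied to
  // the GEMM accumulator, then a scaled sum with the destination tensor.
  dnnl::post_ops po;
  po.append_eltwise(dnnl::algorithm::eltwise_linear,
                    /*alpha=*/1.0f, /*beta=*/-0.677141f);
  po.append_sum(/*scale=*/0.0132773f);
  // For qlinear_add_relu, the verbose shows a trailing "+eltwise_relu":
  // po.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);

  // The chain is attached to the matmul primitive via its attributes.
  dnnl::primitive_attr attr;
  attr.set_post_ops(po);
  return 0;
}
```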
1 parent 075b91b commit b9b1fd9

File tree

2 files changed: +315 −30 lines changed


aten/src/ATen/native/mkldnn/xpu/qlinear.cpp

Lines changed: 22 additions & 12 deletions
```diff
@@ -7,6 +7,18 @@ using namespace at::native::onednn;
 
 namespace at::native::xpu {
 
+static inline c10::ScalarType qlinear_decide_out_dtype(
+    const at::Tensor& act,
+    const std::optional<c10::ScalarType> output_dtype) {
+  bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat);
+  bool bfloat16_output =
+      output_dtype.has_value() && (output_dtype == c10::kBFloat16);
+  auto dst_dtype = fp32_output
+      ? c10::kFloat
+      : (bfloat16_output ? c10::kBFloat16 : act.scalar_type());
+  return dst_dtype;
+}
+
 Tensor q_linear_pointwise(
     Tensor act,
     double act_scale,
@@ -37,9 +49,9 @@ Tensor q_linear_pointwise(
 
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  auto out_dtype =
-      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
-  Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
+
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
+  Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
       act.contiguous(),
@@ -96,9 +108,9 @@ Tensor q_linear_pointwise_tensor(
 
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  auto out_dtype =
-      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
-  Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
+
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
+  Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
       act.contiguous(),
@@ -159,9 +171,8 @@ Tensor q_linear_pointwise_binary(
 
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  auto out_dtype =
-      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
-  Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
+  Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
       act.contiguous(),
@@ -222,9 +233,8 @@ Tensor q_linear_pointwise_binary_tensor(
 
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  auto out_dtype =
-      output_dtype.has_value() ? output_dtype.value() : act.scalar_type();
-  Tensor qout = at::empty(dst_dims, act.options().dtype(out_dtype));
+  auto dst_dtype = qlinear_decide_out_dtype(act, output_dtype);
+  Tensor qout = at::empty(dst_dims, act.options().dtype(dst_dtype));
 
   quantized_matmul(
       act.contiguous(),
```
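The new `qlinear_decide_out_dtype` helper centralizes the output-dtype decision that was previously duplicated at all four call sites: an explicit fp32 or bf16 request wins, otherwise the output keeps the activation's dtype. A minimal standalone restatement of that rule, using a simplified stand-in enum rather than the real `c10::ScalarType` (illustration only, not the actual implementation):

```cpp
#include <cassert>
#include <optional>

// Simplified stand-in for c10::ScalarType, for illustration only.
enum class ScalarType { Char /* int8 */, Float, BFloat16 };

// Same rule as qlinear_decide_out_dtype: an explicit fp32/bf16 request
// wins; otherwise the output keeps the activation's dtype.
ScalarType decide_out_dtype(ScalarType act_dtype,
                            std::optional<ScalarType> output_dtype) {
  const bool fp32_output =
      output_dtype.has_value() && *output_dtype == ScalarType::Float;
  const bool bf16_output =
      output_dtype.has_value() && *output_dtype == ScalarType::BFloat16;
  return fp32_output
      ? ScalarType::Float
      : (bf16_output ? ScalarType::BFloat16 : act_dtype);
}

int main() {
  // The mixed int8-in, bf16-out case this PR enables.
  assert(decide_out_dtype(ScalarType::Char, ScalarType::BFloat16) ==
         ScalarType::BFloat16);
  // No explicit request: output stays at the activation dtype.
  assert(decide_out_dtype(ScalarType::Char, std::nullopt) == ScalarType::Char);
  // Explicit fp32 request.
  assert(decide_out_dtype(ScalarType::Char, ScalarType::Float) ==
         ScalarType::Float);
  return 0;
}
```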
