@@ -7,6 +7,18 @@ using namespace at::native::onednn;
7
7
8
8
namespace at ::native::xpu {
9
9
10
+ static inline c10::ScalarType qlinear_decide_out_dtype (
11
+ const at::Tensor& act,
12
+ const std::optional<c10::ScalarType> output_dtype) {
13
+ bool fp32_output = output_dtype.has_value () && (output_dtype == c10::kFloat );
14
+ bool bfloat16_output =
15
+ output_dtype.has_value () && (output_dtype == c10::kBFloat16 );
16
+ auto dst_dtype = fp32_output
17
+ ? c10::kFloat
18
+ : (bfloat16_output ? c10::kBFloat16 : act.scalar_type ());
19
+ return dst_dtype;
20
+ }
21
+
10
22
Tensor q_linear_pointwise (
11
23
Tensor act,
12
24
double act_scale,
@@ -37,9 +49,9 @@ Tensor q_linear_pointwise(
37
49
38
50
std::vector<int64_t > src_dims = {M, K};
39
51
std::vector<int64_t > dst_dims = {M, N};
40
- auto out_dtype =
41
- output_dtype. has_value () ? output_dtype. value () : act. scalar_type ( );
42
- Tensor qout = at::empty (dst_dims, act.options ().dtype (out_dtype ));
52
+
53
+ auto dst_dtype = qlinear_decide_out_dtype (act, output_dtype);
54
+ Tensor qout = at::empty (dst_dims, act.options ().dtype (dst_dtype ));
43
55
44
56
quantized_matmul (
45
57
act.contiguous (),
@@ -96,9 +108,9 @@ Tensor q_linear_pointwise_tensor(
96
108
97
109
std::vector<int64_t > src_dims = {M, K};
98
110
std::vector<int64_t > dst_dims = {M, N};
99
- auto out_dtype =
100
- output_dtype. has_value () ? output_dtype. value () : act. scalar_type ( );
101
- Tensor qout = at::empty (dst_dims, act.options ().dtype (out_dtype ));
111
+
112
+ auto dst_dtype = qlinear_decide_out_dtype (act, output_dtype);
113
+ Tensor qout = at::empty (dst_dims, act.options ().dtype (dst_dtype ));
102
114
103
115
quantized_matmul (
104
116
act.contiguous (),
@@ -159,9 +171,8 @@ Tensor q_linear_pointwise_binary(
159
171
160
172
std::vector<int64_t > src_dims = {M, K};
161
173
std::vector<int64_t > dst_dims = {M, N};
162
- auto out_dtype =
163
- output_dtype.has_value () ? output_dtype.value () : act.scalar_type ();
164
- Tensor qout = at::empty (dst_dims, act.options ().dtype (out_dtype));
174
+ auto dst_dtype = qlinear_decide_out_dtype (act, output_dtype);
175
+ Tensor qout = at::empty (dst_dims, act.options ().dtype (dst_dtype));
165
176
166
177
quantized_matmul (
167
178
act.contiguous (),
@@ -222,9 +233,8 @@ Tensor q_linear_pointwise_binary_tensor(
222
233
223
234
std::vector<int64_t > src_dims = {M, K};
224
235
std::vector<int64_t > dst_dims = {M, N};
225
- auto out_dtype =
226
- output_dtype.has_value () ? output_dtype.value () : act.scalar_type ();
227
- Tensor qout = at::empty (dst_dims, act.options ().dtype (out_dtype));
236
+ auto dst_dtype = qlinear_decide_out_dtype (act, output_dtype);
237
+ Tensor qout = at::empty (dst_dims, act.options ().dtype (dst_dtype));
228
238
229
239
quantized_matmul (
230
240
act.contiguous (),
0 commit comments