
Commit 6ac5df2

Revert "Delete opt_mul_scalar_out (#12145)" (#12321)
This triggered internal failures; kernels/optimized's tests don't build with Buck in OSS because they use ATen.
1 parent 93e9fcd commit 6ac5df2

File tree

2 files changed: +62 -0 lines changed


kernels/optimized/cpu/op_mul.cpp

Lines changed: 57 additions & 0 deletions
@@ -210,6 +210,63 @@ Tensor& opt_mul_out(
   return out;
 }
 
+Tensor& opt_mul_scalar_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    Tensor& out) {
+  (void)ctx;
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType common_type =
+      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(common_type == out_type);
+
+  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
+    common_type = ScalarType::Float;
+  }
+
+  // Resize for dynamic shape
+  auto error = resize_tensor(out, a.sizes());
+  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+
+  if (a_type == common_type && a_type == out_type &&
+      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
+      CTYPE b_casted = utils::scalar_to<CTYPE>(b);
+
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map<CTYPE>(
+          [b_casted](Vec x) { return x * Vec(b_casted); },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          out.numel());
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
+      ET_SWITCH_REALB_TYPES(
+          common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
+            ET_SWITCH_REALHBBF16_TYPES(
+                out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
+                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
+
+                  const size_t n = a.numel();
+                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
+                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                  for (auto i = 0; i < n; ++i) {
+                    out_data[i] = static_cast<CTYPE_OUT>(
+                        static_cast<CTYPE_IN>(a_data[i]) * b_casted);
+                  }
+                });
+          });
+    });
+  }
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
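
The vectorized fast path in the restored kernel is the part that pulls in ATen: at::vec::Vectorized and at::vec::map come from ATen's CPU SIMD helpers, which is why the kernels/optimized tests cannot build with Buck in OSS, as the commit message notes. Below is a minimal, standalone sketch of that pattern; the buffer names, sizes, and main() wrapper are illustrative only and not taken from this commit.

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

#include <cstdint>
#include <vector>

int main() {
  using Vec = at::vec::Vectorized<float>;

  // Illustrative buffers; the kernel operates on Tensor data pointers instead.
  std::vector<float> input(1024, 2.0f);
  std::vector<float> output(input.size());
  const float scalar = 3.0f;

  // Multiply every element by `scalar`, one SIMD vector at a time;
  // at::vec::map handles the tail that does not fill a full vector width.
  at::vec::map<float>(
      [scalar](Vec x) { return x * Vec(scalar); },
      output.data(),
      input.data(),
      static_cast<int64_t>(input.size()));
  return 0;
}

The else branch of the restored kernel, by contrast, is plain scalar C++: each element is promoted to the common type, multiplied by the casted scalar, and narrowed to the output dtype.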

kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions
@@ -82,6 +82,11 @@
 - arg_meta: null
   kernel_name: torch::executor::opt_mul_out
 
+- op: mul.Scalar_out
+  kernels:
+  - arg_meta: null
+    kernel_name: torch::executor::opt_mul_scalar_out
+
 - op: native_layer_norm.out
   kernels:
   - arg_meta: null
