@@ -210,6 +210,63 @@ Tensor& opt_mul_out(
210
210
return out;
211
211
}
212
212
213
+ Tensor& opt_mul_scalar_out (
214
+ KernelRuntimeContext& ctx,
215
+ const Tensor& a,
216
+ const Scalar& b,
217
+ Tensor& out) {
218
+ (void )ctx;
219
+
220
+ ScalarType a_type = a.scalar_type ();
221
+ ScalarType common_type =
222
+ utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
223
+ ScalarType out_type = out.scalar_type ();
224
+
225
+ ET_CHECK (common_type == out_type);
226
+
227
+ if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
228
+ common_type = ScalarType::Float;
229
+ }
230
+
231
+ // Resize for dynamic shape
232
+ auto error = resize_tensor (out, a.sizes ());
233
+ ET_CHECK_MSG (error == Error::Ok, " Failed to resize output tensor." );
234
+
235
+ if (a_type == common_type && a_type == out_type &&
236
+ a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
237
+ ET_SWITCH_REALB_TYPES (a_type, ctx, " mul.Scalar_out" , CTYPE, [&]() {
238
+ CTYPE b_casted = utils::scalar_to<CTYPE>(b);
239
+
240
+ using Vec = at::vec::Vectorized<CTYPE>;
241
+ at::vec::map<CTYPE>(
242
+ [b_casted](Vec x) { return x * Vec (b_casted); },
243
+ out.mutable_data_ptr <CTYPE>(),
244
+ a.const_data_ptr <CTYPE>(),
245
+ out.numel ());
246
+ });
247
+ } else {
248
+ ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, " mul.Scalar_out" , CTYPE_A, [&]() {
249
+ ET_SWITCH_REALB_TYPES (
250
+ common_type, ctx, " mul.Scalar_out" , CTYPE_IN, [&]() {
251
+ ET_SWITCH_REALHBBF16_TYPES (
252
+ out_type, ctx, " mul.Scalar_out" , CTYPE_OUT, [&]() {
253
+ CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
254
+
255
+ const size_t n = a.numel ();
256
+ const CTYPE_A* a_data = a.const_data_ptr <CTYPE_A>();
257
+ CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
258
+ for (auto i = 0 ; i < n; ++i) {
259
+ out_data[i] = static_cast <CTYPE_OUT>(
260
+ static_cast <CTYPE_IN>(a_data[i]) * b_casted);
261
+ }
262
+ });
263
+ });
264
+ });
265
+ }
266
+
267
+ return out;
268
+ }
269
+
213
270
} // namespace native
214
271
} // namespace executor
215
272
} // namespace torch
0 commit comments