@@ -448,10 +448,7 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
448448 auto & encoder = cu::get_command_encoder (s);
449449
450450 const size_t total = mask.size ();
451- const CopyType copy_type = (total == 1 )
452- ? CopyType::Scalar
453- : (dst.flags ().row_contiguous ? CopyType::Vector : CopyType::General);
454- copy_gpu (dst, out, copy_type, s);
451+ out.set_data (cu::malloc_async (out.nbytes (), encoder));
455452 if (total == 0 ) {
456453 return ;
457454 }
@@ -478,23 +475,27 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
478475 mask_flat, scatter_offsets, static_cast <int64_t >(mask_batch_size), s);
479476
480477 std::string module_name =
481- fmt::format (" masked_scatter_assign_ {}" , dtype_to_string (out.dtype ()));
478+ fmt::format (" masked_scatter_fused_ {}" , dtype_to_string (out.dtype ()));
482479 cu::JitModule& mod = cu::get_jit_module (s.device , module_name, [&]() {
483480 std::vector<std::string> kernel_names;
484481 for (int src_contiguous = 0 ; src_contiguous <= 1 ; ++src_contiguous) {
485- for (int use_large = 0 ; use_large <= 1 ; ++use_large) {
486- kernel_names.push_back (
487- fmt::format (
488- " mlx::core::cu::masked_scatter_assign<{}, {}, {}>" ,
489- dtype_to_cuda_type (out.dtype ()),
490- src_contiguous ? " true" : " false" ,
491- use_large ? " int64_t" : " int32_t" ));
482+ for (int dst_contiguous = 0 ; dst_contiguous <= 1 ; ++dst_contiguous) {
483+ for (int use_large = 0 ; use_large <= 1 ; ++use_large) {
484+ kernel_names.push_back (
485+ fmt::format (
486+ " mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>" ,
487+ dtype_to_cuda_type (out.dtype ()),
488+ src_contiguous ? " true" : " false" ,
489+ dst_contiguous ? " true" : " false" ,
490+ use_large ? " int64_t" : " int32_t" ));
491+ }
492492 }
493493 }
494494 return std::make_tuple (false , jit_source_scatter, std::move (kernel_names));
495495 });
496496
497497 cu::KernelArgs args;
498+ args.append (dst);
498499 args.append (mask_flat);
499500 args.append (scatter_offsets);
500501 args.append (src);
@@ -508,19 +509,24 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
508509 args.append <int32_t >(src_batch_size);
509510 args.append <int32_t >(mask_batch_size);
510511 }
512+ args.append_ndim (dst.shape ());
513+ args.append_ndim (dst.strides ());
514+ args.append <int32_t >(dst.ndim ());
511515 args.append_ndim (src.shape ());
512516 args.append_ndim (src.strides ());
513517 args.append <int32_t >(src.ndim ());
514518
519+ encoder.set_input_array (dst);
515520 encoder.set_input_array (mask_flat);
516521 encoder.set_input_array (scatter_offsets);
517522 encoder.set_input_array (src);
518523 encoder.set_output_array (out);
519524
520525 std::string kernel_name = fmt::format (
521- " mlx::core::cu::masked_scatter_assign< {}, {}, {}>" ,
526+ " mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>" ,
522527 dtype_to_cuda_type (out.dtype ()),
523528 src.flags ().row_contiguous ? " true" : " false" ,
529+ dst.flags ().row_contiguous ? " true" : " false" ,
524530 large ? " int64_t" : " int32_t" );
525531 auto kernel = mod.get_kernel (kernel_name);
526532 auto [num_blocks, block_dims] = get_launch_args (mask_flat, large);
0 commit comments