Skip to content

Commit e70c1fc

Browse files
committed
perf: fuse masked scatter copy and assign
1 parent 608cfdd commit e70c1fc

File tree

2 files changed

+47
-30
lines changed


mlx/backend/cuda/device/scatter.cuh

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,19 @@ __global__ void scatter(
6565
Op{}(out + out_idx, upd[upd_loc]);
6666
}
6767

68-
template <typename T, bool SrcContiguous, typename IdxT>
69-
__global__ void masked_scatter_assign(
68+
template <typename T, bool SrcContiguous, bool DstContiguous, typename IdxT>
69+
__global__ void masked_scatter_fused(
70+
const T* dst,
7071
const bool* mask,
7172
const int32_t* scatter_offsets,
7273
const T* src,
7374
T* out,
7475
IdxT size,
7576
IdxT src_batch_size,
7677
IdxT mask_batch_size,
78+
const __grid_constant__ Shape dst_shape,
79+
const __grid_constant__ Strides dst_strides,
80+
int32_t dst_ndim,
7781
const __grid_constant__ Shape src_shape,
7882
const __grid_constant__ Strides src_strides,
7983
int32_t src_ndim) {
@@ -82,25 +86,32 @@ __global__ void masked_scatter_assign(
8286
return;
8387
}
8488

85-
if (!mask[index]) {
86-
return;
89+
T dst_val;
90+
if constexpr (DstContiguous) {
91+
dst_val = dst[index];
92+
} else {
93+
IdxT dst_loc =
94+
elem_to_loc(index, dst_shape.data(), dst_strides.data(), dst_ndim);
95+
dst_val = dst[dst_loc];
8796
}
8897

89-
IdxT src_index = static_cast<IdxT>(scatter_offsets[index]);
90-
if (src_index >= src_batch_size) {
91-
// Match Metal backend behavior by skipping out-of-range source reads.
92-
return;
98+
if (mask[index]) {
99+
IdxT src_index = static_cast<IdxT>(scatter_offsets[index]);
100+
if (src_index < src_batch_size) {
101+
IdxT batch_idx = index / mask_batch_size;
102+
if constexpr (SrcContiguous) {
103+
out[index] = src[batch_idx * src_batch_size + src_index];
104+
} else {
105+
IdxT src_elem = batch_idx * src_batch_size + src_index;
106+
IdxT src_loc = elem_to_loc(
107+
src_elem, src_shape.data(), src_strides.data(), src_ndim);
108+
out[index] = src[src_loc];
109+
}
110+
return;
111+
}
93112
}
94113

95-
IdxT batch_idx = index / mask_batch_size;
96-
if constexpr (SrcContiguous) {
97-
out[index] = src[batch_idx * src_batch_size + src_index];
98-
} else {
99-
IdxT src_elem = batch_idx * src_batch_size + src_index;
100-
IdxT src_loc =
101-
elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim);
102-
out[index] = src[src_loc];
103-
}
114+
out[index] = dst_val;
104115
}
105116

106117
} // namespace mlx::core::cu

mlx/backend/cuda/indexing.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -448,10 +448,7 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
448448
auto& encoder = cu::get_command_encoder(s);
449449

450450
const size_t total = mask.size();
451-
const CopyType copy_type = (total == 1)
452-
? CopyType::Scalar
453-
: (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);
454-
copy_gpu(dst, out, copy_type, s);
451+
out.set_data(cu::malloc_async(out.nbytes(), encoder));
455452
if (total == 0) {
456453
return;
457454
}
@@ -478,23 +475,27 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
478475
mask_flat, scatter_offsets, static_cast<int64_t>(mask_batch_size), s);
479476

480477
std::string module_name =
481-
fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype()));
478+
fmt::format("masked_scatter_fused_{}", dtype_to_string(out.dtype()));
482479
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
483480
std::vector<std::string> kernel_names;
484481
for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) {
485-
for (int use_large = 0; use_large <= 1; ++use_large) {
486-
kernel_names.push_back(
487-
fmt::format(
488-
"mlx::core::cu::masked_scatter_assign<{}, {}, {}>",
489-
dtype_to_cuda_type(out.dtype()),
490-
src_contiguous ? "true" : "false",
491-
use_large ? "int64_t" : "int32_t"));
482+
for (int dst_contiguous = 0; dst_contiguous <= 1; ++dst_contiguous) {
483+
for (int use_large = 0; use_large <= 1; ++use_large) {
484+
kernel_names.push_back(
485+
fmt::format(
486+
"mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>",
487+
dtype_to_cuda_type(out.dtype()),
488+
src_contiguous ? "true" : "false",
489+
dst_contiguous ? "true" : "false",
490+
use_large ? "int64_t" : "int32_t"));
491+
}
492492
}
493493
}
494494
return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
495495
});
496496

497497
cu::KernelArgs args;
498+
args.append(dst);
498499
args.append(mask_flat);
499500
args.append(scatter_offsets);
500501
args.append(src);
@@ -508,19 +509,24 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
508509
args.append<int32_t>(src_batch_size);
509510
args.append<int32_t>(mask_batch_size);
510511
}
512+
args.append_ndim(dst.shape());
513+
args.append_ndim(dst.strides());
514+
args.append<int32_t>(dst.ndim());
511515
args.append_ndim(src.shape());
512516
args.append_ndim(src.strides());
513517
args.append<int32_t>(src.ndim());
514518

519+
encoder.set_input_array(dst);
515520
encoder.set_input_array(mask_flat);
516521
encoder.set_input_array(scatter_offsets);
517522
encoder.set_input_array(src);
518523
encoder.set_output_array(out);
519524

520525
std::string kernel_name = fmt::format(
521-
"mlx::core::cu::masked_scatter_assign<{}, {}, {}>",
526+
"mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>",
522527
dtype_to_cuda_type(out.dtype()),
523528
src.flags().row_contiguous ? "true" : "false",
529+
dst.flags().row_contiguous ? "true" : "false",
524530
large ? "int64_t" : "int32_t");
525531
auto kernel = mod.get_kernel(kernel_name);
526532
auto [num_blocks, block_dims] = get_launch_args(mask_flat, large);

0 commit comments

Comments (0)