Skip to content

Commit 608cfdd

Browse files
committed
perf: replace mask prefix scan with CUB segmented exclusive scan
1 parent 4c318ef commit 608cfdd

File tree

3 files changed

+128
-56
lines changed

3 files changed

+128
-56
lines changed

mlx/backend/cuda/indexing.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -469,20 +469,14 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
469469
scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder));
470470
encoder.add_temporary(scatter_offsets);
471471

472-
scan_gpu_inplace(
473-
mask_flat,
474-
scatter_offsets,
475-
Scan::Sum,
476-
/* axis= */ 1,
477-
/* reverse= */ false,
478-
/* inclusive= */ false,
479-
s);
480-
481472
const size_t batch_count = mask.shape(0);
482473
const size_t mask_batch_size = mask_flat.size() / batch_count;
483474
const size_t src_batch_size = src.size() / src.shape(0);
484475
bool large = total > INT32_MAX || src.size() > INT32_MAX;
485476

477+
segmented_exclusive_mask_scan_gpu(
478+
mask_flat, scatter_offsets, static_cast<int64_t>(mask_batch_size), s);
479+
486480
std::string module_name =
487481
fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype()));
488482
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {

mlx/backend/cuda/scan.cu

Lines changed: 122 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
#include "mlx/backend/cuda/kernel_utils.cuh"
66
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
77
#include "mlx/backend/cuda/scan.h"
8+
#include "mlx/backend/cuda/utils.h"
89
#include "mlx/backend/gpu/copy.h"
910
#include "mlx/dtype_utils.h"
1011
#include "mlx/primitives.h"
1112

1213
#include <cooperative_groups.h>
1314
#include <cooperative_groups/scan.h>
1415
#include <nvtx3/nvtx3.hpp>
16+
#include <thrust/iterator/counting_iterator.h>
17+
#include <thrust/iterator/transform_iterator.h>
18+
#include <cub/device/device_scan.cuh>
19+
#include <cuda/std/functional>
1520

1621
#include <cassert>
1722

@@ -363,38 +368,139 @@ constexpr bool supports_scan_op() {
363368
}
364369
}
365370

366-
void scan_gpu_inplace(
367-
array in,
371+
namespace {

// Functor that widens one mask element (bool) to a summable 0/1 int32.
// Used as the value transform feeding cub::DeviceScan.
struct BoolToInt32 {
  __host__ __device__ int32_t operator()(bool value) const {
    return value ? int32_t{1} : int32_t{0};
  }
};

// Functor that maps a flat element index to its segment id via integer
// division. Produces the (monotonically non-decreasing) key sequence that
// cub::DeviceScan::ExclusiveSumByKey uses to reset the scan per segment.
template <typename IdxT>
struct MaskSegmentKey {
  IdxT segment_size;

  __host__ __device__ IdxT operator()(IdxT index) const {
    return index / segment_size;
  }
};

} // namespace
389+
390+
// Computes an exclusive prefix sum of the boolean mask `in` independently
// within each consecutive run of `segment_size` elements, writing int32
// results into `out`.
//
// Preconditions (enforced or assumed):
//   - `segment_size` > 0 and divides `in.size()` exactly (checked).
//   - `in` holds contiguous bool data; `out` holds int32 data of the same
//     length — NOTE(review): dtypes are implied by the gpu_ptr casts below;
//     confirm callers guarantee this.
//
// When the whole input is a single segment, the cheaper unkeyed
// cub::DeviceScan::ExclusiveSum is used; otherwise ExclusiveSumByKey with an
// index/segment_size key iterator resets the scan at each segment boundary.
void segmented_exclusive_mask_scan_gpu(
    const array& in,
    array& out,
    int64_t segment_size,
    const Stream& s) {
  if (segment_size <= 0) {
    throw std::runtime_error("segment_size must be positive.");
  }
  if (static_cast<int64_t>(in.size()) % segment_size != 0) {
    throw std::runtime_error(
        "segment_size must evenly divide the input size.");
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  // 64-bit indices so inputs larger than INT32_MAX elements are handled.
  using CubIdx = int64_t;
  auto num_items = static_cast<CubIdx>(in.size());
  auto value_iter =
      thrust::make_transform_iterator(gpu_ptr<bool>(in), BoolToInt32{});

  // CUB two-phase protocol: first call with a null workspace queries the
  // required temporary-storage size, second call performs the scan.
  size_t workspace_size = 0;
  if (segment_size == num_items) {
    // Single segment: plain exclusive sum, no keys needed.
    CHECK_CUDA_ERROR(
        cub::DeviceScan::ExclusiveSum(
            nullptr,
            workspace_size,
            value_iter,
            gpu_ptr<int32_t>(out),
            num_items,
            encoder.stream()));

    void* workspace = allocate_workspace(encoder, workspace_size);
    auto capture = encoder.capture_context();
    CHECK_CUDA_ERROR(
        cub::DeviceScan::ExclusiveSum(
            workspace,
            workspace_size,
            value_iter,
            gpu_ptr<int32_t>(out),
            num_items,
            encoder.stream()));
  } else {
    // Multiple segments: synthesize the segment id of each element on the
    // fly (i / segment_size) so no key array has to be materialized.
    auto count_iter = thrust::counting_iterator<CubIdx>(0);
    auto key_iter = thrust::make_transform_iterator(
        count_iter, MaskSegmentKey<CubIdx>{static_cast<CubIdx>(segment_size)});

    CHECK_CUDA_ERROR(
        cub::DeviceScan::ExclusiveSumByKey(
            nullptr,
            workspace_size,
            key_iter,
            value_iter,
            gpu_ptr<int32_t>(out),
            num_items,
            cuda::std::equal_to<>{},
            encoder.stream()));

    void* workspace = allocate_workspace(encoder, workspace_size);
    auto capture = encoder.capture_context();
    CHECK_CUDA_ERROR(
        cub::DeviceScan::ExclusiveSumByKey(
            workspace,
            workspace_size,
            key_iter,
            value_iter,
            gpu_ptr<int32_t>(out),
            num_items,
            cuda::std::equal_to<>{},
            encoder.stream()));
  }
}
458+
459+
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
460+
nvtx3::scoped_range r("Scan::eval_gpu");
461+
assert(inputs.size() == 1);
462+
auto in = inputs[0];
463+
auto& s = stream();
464+
auto& encoder = cu::get_command_encoder(s);
465+
466+
if (in.flags().contiguous && in.strides()[axis_] != 0) {
467+
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
468+
out.copy_shared_buffer(in);
469+
} else {
470+
out.set_data(
471+
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
472+
in.data_size(),
473+
in.strides(),
474+
in.flags());
475+
}
476+
} else {
477+
in = contiguous_copy_gpu(in, s);
478+
out.copy_shared_buffer(in);
479+
}
480+
375481
constexpr int N_READS = 4;
376-
int32_t axis_size = in.shape(axis);
377-
bool contiguous = in.strides()[axis] == 1;
482+
int32_t axis_size = in.shape(axis_);
483+
bool contiguous = in.strides()[axis_] == 1;
378484

379485
encoder.set_input_array(in);
380486
encoder.set_output_array(out);
381487

382488
dispatch_all_types(in.dtype(), [&](auto type_tag) {
383489
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
384-
dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) {
490+
dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
385491
using Op = MLX_GET_TYPE(scan_op_tag);
386492
if constexpr (supports_scan_op<Op, T>()) {
387493
using U = typename cu::ScanResult<Op, T>::type;
388-
dispatch_bool(inclusive, [&](auto inclusive_tag) {
389-
dispatch_bool(reverse, [&](auto reverse_tag) {
494+
dispatch_bool(inclusive_, [&](auto inclusive) {
495+
dispatch_bool(reverse_, [&](auto reverse) {
390496
if (contiguous) {
391497
auto kernel = cu::contiguous_scan<
392498
T,
393499
U,
394500
Op,
395501
N_READS,
396-
inclusive_tag.value,
397-
reverse_tag.value>;
502+
inclusive.value,
503+
reverse.value>;
398504
int block_dim = cuda::ceil_div(axis_size, N_READS);
399505
block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
400506
block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
@@ -415,9 +521,9 @@ void scan_gpu_inplace(
415521
N_READS,
416522
BM,
417523
BN,
418-
inclusive_tag.value,
419-
reverse_tag.value>;
420-
int64_t stride = in.strides()[axis];
524+
inclusive.value,
525+
reverse.value>;
526+
int64_t stride = in.strides()[axis_];
421527
int64_t stride_blocks = cuda::ceil_div(stride, BN);
422528
dim3 num_blocks = get_2d_grid_dims(
423529
in.shape(), in.strides(), axis_size * stride);
@@ -451,29 +557,4 @@ void scan_gpu_inplace(
451557
});
452558
}
453559

454-
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
455-
nvtx3::scoped_range r("Scan::eval_gpu");
456-
assert(inputs.size() == 1);
457-
auto in = inputs[0];
458-
auto& s = stream();
459-
auto& encoder = cu::get_command_encoder(s);
460-
461-
if (in.flags().contiguous && in.strides()[axis_] != 0) {
462-
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
463-
out.copy_shared_buffer(in);
464-
} else {
465-
out.set_data(
466-
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
467-
in.data_size(),
468-
in.strides(),
469-
in.flags());
470-
}
471-
} else {
472-
in = contiguous_copy_gpu(in, s);
473-
out.copy_shared_buffer(in);
474-
}
475-
476-
scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s);
477-
}
478-
479560
} // namespace mlx::core

mlx/backend/cuda/scan.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,10 @@
88

99
namespace mlx::core {

// Exclusive prefix sum of the boolean mask `in`, computed independently
// within each consecutive run of `segment_size` elements, with int32 results
// written to `out`. `segment_size` must be positive.
// NOTE(review): the implementation reads `in` as contiguous bool data and
// writes `out` as int32 of the same length — confirm callers allocate
// accordingly.
void segmented_exclusive_mask_scan_gpu(
    const array& in,
    array& out,
    int64_t segment_size,
    const Stream& s);

} // namespace mlx::core

0 commit comments

Comments
 (0)