55#include " mlx/backend/cuda/kernel_utils.cuh"
66#include " mlx/backend/cuda/reduce/reduce_ops.cuh"
77#include " mlx/backend/cuda/scan.h"
8+ #include " mlx/backend/cuda/utils.h"
89#include " mlx/backend/gpu/copy.h"
910#include " mlx/dtype_utils.h"
1011#include " mlx/primitives.h"
1112
1213#include < cooperative_groups.h>
1314#include < cooperative_groups/scan.h>
1415#include < nvtx3/nvtx3.hpp>
16+ #include < thrust/iterator/counting_iterator.h>
17+ #include < thrust/iterator/transform_iterator.h>
18+ #include < cub/device/device_scan.cuh>
19+ #include < cuda/std/functional>
1520
1621#include < cassert>
1722
@@ -363,38 +368,139 @@ constexpr bool supports_scan_op() {
363368 }
364369}
365370
366- void scan_gpu_inplace (
367- array in,
namespace {

// Widens one boolean mask element to a 32-bit integer so that CUB's
// sum-based scans accumulate a running count rather than a logical value.
struct BoolToInt32 {
  __host__ __device__ int32_t operator()(bool v) const {
    return v ? int32_t{1} : int32_t{0};
  }
};

// Maps a flat element index to the index of the segment containing it.
// Used as the key generator for ExclusiveSumByKey, which restarts the scan
// whenever consecutive keys differ (i.e. at every segment boundary).
template <typename IdxT>
struct MaskSegmentKey {
  IdxT segment_size;

  __host__ __device__ IdxT operator()(IdxT i) const {
    return i / segment_size;
  }
};

} // namespace
389+
// Computes a per-segment exclusive prefix sum of a boolean mask on the GPU.
//
// `in` is read as a flat bool buffer of in.size() elements partitioned into
// consecutive segments of `segment_size`; `out` receives the int32 exclusive
// running count of true values, restarting at zero at each segment boundary.
// When one segment covers the whole input, the cheaper unsegmented
// cub::DeviceScan::ExclusiveSum is used; otherwise ExclusiveSumByKey with a
// computed key iterator keeps segments independent without materializing the
// key array.
//
// Preconditions (not checked here): in holds bools, out holds int32 and is
// already allocated with at least in.size() elements.
// NOTE(review): a trailing partial segment (in.size() % segment_size != 0)
// falls out naturally from the keyed path — confirm callers rely on this.
//
// Throws std::runtime_error if segment_size is not positive.
void segmented_exclusive_mask_scan_gpu(
    const array& in,
    array& out,
    int64_t segment_size,
    const Stream& s) {
  if (segment_size <= 0) {
    throw std::runtime_error("segment_size must be positive.");
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);

  using CubIdx = int64_t;
  auto num_items = static_cast<CubIdx>(in.size());
  auto count_iter = thrust::counting_iterator<CubIdx>(0);
  auto key_iter = thrust::make_transform_iterator(
      count_iter, MaskSegmentKey<CubIdx>{static_cast<CubIdx>(segment_size)});
  auto value_iter =
      thrust::make_transform_iterator(gpu_ptr<bool>(in), BoolToInt32{});

  // CUB's two-phase protocol: a null workspace pointer queries the required
  // temp-storage size, then the identical call executes for real. A single
  // `run` lambda per branch keeps the two invocations from drifting apart.
  size_t workspace_size = 0;
  if (segment_size == num_items) {
    // Single segment covering the whole input: plain exclusive sum.
    auto run = [&](void* workspace) {
      return cub::DeviceScan::ExclusiveSum(
          workspace,
          workspace_size,
          value_iter,
          gpu_ptr<int32_t>(out),
          num_items,
          encoder.stream());
    };
    CHECK_CUDA_ERROR(run(nullptr));
    void* workspace = allocate_workspace(encoder, workspace_size);
    auto capture = encoder.capture_context();
    CHECK_CUDA_ERROR(run(workspace));
  } else {
    // Multiple segments: the keyed scan restarts the sum whenever the
    // computed segment key changes.
    auto run = [&](void* workspace) {
      return cub::DeviceScan::ExclusiveSumByKey(
          workspace,
          workspace_size,
          key_iter,
          value_iter,
          gpu_ptr<int32_t>(out),
          num_items,
          cuda::std::equal_to<>{},
          encoder.stream());
    };
    CHECK_CUDA_ERROR(run(nullptr));
    void* workspace = allocate_workspace(encoder, workspace_size);
    auto capture = encoder.capture_context();
    CHECK_CUDA_ERROR(run(workspace));
  }
}
458+
459+ void Scan::eval_gpu (const std::vector<array>& inputs, array& out) {
460+ nvtx3::scoped_range r (" Scan::eval_gpu" );
461+ assert (inputs.size () == 1 );
462+ auto in = inputs[0 ];
463+ auto & s = stream ();
464+ auto & encoder = cu::get_command_encoder (s);
465+
466+ if (in.flags ().contiguous && in.strides ()[axis_] != 0 ) {
467+ if (in.is_donatable () && in.itemsize () == out.itemsize ()) {
468+ out.copy_shared_buffer (in);
469+ } else {
470+ out.set_data (
471+ cu::malloc_async (in.data_size () * out.itemsize (), encoder),
472+ in.data_size (),
473+ in.strides (),
474+ in.flags ());
475+ }
476+ } else {
477+ in = contiguous_copy_gpu (in, s);
478+ out.copy_shared_buffer (in);
479+ }
480+
375481 constexpr int N_READS = 4 ;
376- int32_t axis_size = in.shape (axis );
377- bool contiguous = in.strides ()[axis ] == 1 ;
482+ int32_t axis_size = in.shape (axis_ );
483+ bool contiguous = in.strides ()[axis_ ] == 1 ;
378484
379485 encoder.set_input_array (in);
380486 encoder.set_output_array (out);
381487
382488 dispatch_all_types (in.dtype (), [&](auto type_tag) {
383489 using T = cuda_type_t <MLX_GET_TYPE (type_tag)>;
384- dispatch_scan_ops (reduce_type , [&](auto scan_op_tag) {
490+ dispatch_scan_ops (reduce_type_ , [&](auto scan_op_tag) {
385491 using Op = MLX_GET_TYPE (scan_op_tag);
386492 if constexpr (supports_scan_op<Op, T>()) {
387493 using U = typename cu::ScanResult<Op, T>::type;
388- dispatch_bool (inclusive , [&](auto inclusive_tag ) {
389- dispatch_bool (reverse , [&](auto reverse_tag ) {
494+ dispatch_bool (inclusive_ , [&](auto inclusive ) {
495+ dispatch_bool (reverse_ , [&](auto reverse ) {
390496 if (contiguous) {
391497 auto kernel = cu::contiguous_scan<
392498 T,
393499 U,
394500 Op,
395501 N_READS,
396- inclusive_tag .value ,
397- reverse_tag .value >;
502+ inclusive .value ,
503+ reverse .value >;
398504 int block_dim = cuda::ceil_div (axis_size, N_READS);
399505 block_dim = cuda::ceil_div (block_dim, WARP_SIZE) * WARP_SIZE;
400506 block_dim = std::min (block_dim, WARP_SIZE * WARP_SIZE);
@@ -415,9 +521,9 @@ void scan_gpu_inplace(
415521 N_READS,
416522 BM,
417523 BN,
418- inclusive_tag .value ,
419- reverse_tag .value >;
420- int64_t stride = in.strides ()[axis ];
524+ inclusive .value ,
525+ reverse .value >;
526+ int64_t stride = in.strides ()[axis_ ];
421527 int64_t stride_blocks = cuda::ceil_div (stride, BN);
422528 dim3 num_blocks = get_2d_grid_dims (
423529 in.shape (), in.strides (), axis_size * stride);
@@ -451,29 +557,4 @@ void scan_gpu_inplace(
451557 });
452558}
453559
454- void Scan::eval_gpu (const std::vector<array>& inputs, array& out) {
455- nvtx3::scoped_range r (" Scan::eval_gpu" );
456- assert (inputs.size () == 1 );
457- auto in = inputs[0 ];
458- auto & s = stream ();
459- auto & encoder = cu::get_command_encoder (s);
460-
461- if (in.flags ().contiguous && in.strides ()[axis_] != 0 ) {
462- if (in.is_donatable () && in.itemsize () == out.itemsize ()) {
463- out.copy_shared_buffer (in);
464- } else {
465- out.set_data (
466- cu::malloc_async (in.data_size () * out.itemsize (), encoder),
467- in.data_size (),
468- in.strides (),
469- in.flags ());
470- }
471- } else {
472- in = contiguous_copy_gpu (in, s);
473- out.copy_shared_buffer (in);
474- }
475-
476- scan_gpu_inplace (in, out, reduce_type_, axis_, reverse_, inclusive_, s);
477- }
478-
479560} // namespace mlx::core
0 commit comments