[CUDA][Performance] Add radix select implementation for efficient partition operations #3117
Lyxot wants to merge 18 commits into ml-explore:main from
Conversation
I got the following benchmark results on the 4070 Super:

Performance is mostly OK, but some dtypes (notably float32) still need further optimization.

Benchmark results may vary with hardware; further testing is required.
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
@zcbenz Could you please review this PR?
@Lyxot Thanks for your contributions. I'm not familiar with GPU radix sort, so I'll have to do some homework before I can review, and I'm currently stuck on a hard problem, so it will take some time before I can look into this. Maybe other maintainers can take a look before I do.
Hi @Lyxot, thanks for the PR! I think part of my comment on #3069 also applies here. In short, I think a PR that addresses a smaller use case but is consistently better than the fallback would be much better. It would be shorter, the code would be simpler, and more importantly we wouldn't have to either accept regressions or build some complicated heuristic for routing to the fallback. My suggestion to begin with is to only tackle the use case of small axes that fit in shared memory. That would cover, for instance, MoE expert selection, since the number of tokens can vary from 1 to tens of thousands but the axis is fairly small, 8 to a few hundred. This is also a use case where the current implementation is slow (as is also the case in #3069).
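The MoE expert-selection use case described above can be illustrated with a plain-Python sketch (an illustration of the shape regime, not MLX code; the function name and shapes are made up):

```python
def select_experts(logits, k):
    """For each token (row), return indices of the k largest logits.

    The partitioned axis (number of experts) is small, while the number
    of rows (tokens) can range from 1 to tens of thousands -- exactly
    the shape regime a small-axis shared-memory kernel targets.
    """
    return [
        sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        for row in logits
    ]

# One token, four experts, pick the top 2:
select_experts([[0.1, 2.0, -1.0, 3.0]], 2)  # -> [[3, 1]]
```

Unlike this ordered sketch, `argpartition` leaves the order of the selected k elements unspecified, which is what makes a select (rather than a full sort) sufficient.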
@angeloskath I tuned the small kernel, which fits in shared memory. If you prefer a simpler scope for this PR, I can remove the large-kernel path and keep only the small kernel, with a fallback to merge sort for the rest.
Current performance of the small kernel is:
Fix two correctness issues in CUDA radix partition/argpartition:
- In the large contiguous radix path, stop deriving row bases from `row * min(non-axis stride)` and compute row offsets with `elem_to_loc(...)` using non-axis shape/strides (matching merge-sort indexing behavior).
- Keep stride arguments 64-bit end-to-end in radix-select kernels and launches (remove narrowing to `int` and the related `INT32_MAX` guard).

This fixes incorrect row addressing for valid contiguous non-linear layouts (e.g. column-major with axis=0) and avoids silent misindexing on large strides.
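The indexing fix in this commit can be illustrated with a plain-Python model of `elem_to_loc`-style addressing (a sketch, not the actual CUDA helper; the example layout is made up):

```python
def elem_to_loc(elem, shape, strides):
    """Map a linear element index to a memory offset via shape/strides."""
    loc = 0
    for dim in range(len(shape) - 1, -1, -1):
        loc += (elem % shape[dim]) * strides[dim]
        elem //= shape[dim]
    return loc

# Row-major (2, 3, 4) tensor partitioned along axis=1: the non-axis
# metadata is shape (2, 4) with strides (12, 1). Row 5 (i=1, k=1)
# lives at offset 12*1 + 1*1 = 13, while a naive base of
# `row * min(non-axis stride)` would give 5 * 1 = 5 -- the wrong row.
elem_to_loc(5, (2, 4), (12, 1))  # -> 13
```

This is why the commit switches row addressing to the shape/stride walk instead of scaling by the minimum non-axis stride.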
Eliminate MAX_NDIM-based rank limits in CUDA radix partition/argpartition by switching radix kernels from fixed-size __grid_constant__ shape/stride params to dynamic device pointers for non-axis metadata. Changes: - Update radix kernels to take dynamic NC metadata pointers: - radix_select_small_nc_kernel - radix_select_large_streaming_kernel - radix_select_large_streaming_nc_kernel - In gpu_radix_partition_small/gpu_radix_partition_large: - allocate device buffers for nc_shape/in_nc_strides/out_nc_strides - copy host metadata with cudaMemcpyAsync - pass pointers into kernel launches - Remove MAX_NDIM-dependent routing so high-rank tensors can still use radix partition path. - Keep stride handling 64-bit end-to-end in radix launches/kernels. Also slightly widens fallback-model threshold range (without changing max_rows).
remove fallback strategy based on estimated shared-memory usage and device limits
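A routing decision like the one this commit removes could, in principle, look like the following sketch (the function name, the 48 KiB default, and the per-element cost model are assumptions for illustration, not the PR's actual code):

```python
def fits_in_shared_memory(axis_size, value_bytes, index_bytes=4,
                          smem_limit=48 * 1024):
    """Hypothetical check: can one row's values plus their indices be
    staged in shared memory? (48 KiB is a common default per-block
    shared-memory limit on NVIDIA GPUs.)"""
    needed = axis_size * (value_bytes + index_bytes)
    return needed <= smem_limit

fits_in_shared_memory(256, 4)      # small axis: True
fits_in_shared_memory(1 << 20, 4)  # huge axis: False
```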
@angeloskath I've narrowed the scope of this PR. This version now only targets the small-axis case that fits in shared memory and falls back to merge sort for the remaining cases. I also removed the large-kernel path to keep the implementation smaller and avoid extra routing/heuristic complexity. Could you please take another look?
Proposed changes
This adds a CUDA radix-select based path for `argpartition`/`partition` and introduces multi-block-per-row and multi-row-per-block strategies for shapes where normal radix select underperforms. #3064

What changed
- `mlx/backend/cuda/device/radix_select.cuh`: radix select kernels (`blocks_per_row`, `rows_per_block`)
- `mlx/backend/cuda/sort.cu`: `ArgPartition::eval_gpu`/`Partition::eval_gpu` now call the radix partition path
- `benchmarks/python/radix_select_bench.py`: correctness checks, determinism checks, and performance sweep utilities

Checklist
Put an x in the boxes that apply.
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
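For readers unfamiliar with the technique, the core of radix select can be sketched in a few lines of Python (a radix-2 CPU model for clarity; the actual CUDA kernels presumably use wider digits, per-block histograms, and shared memory):

```python
def radix_select(values, k, bits=32):
    """Return the k-th smallest (0-indexed) of non-negative integers.

    Scan bits from most to least significant; at each step, the count
    of elements with a 0 in the current bit tells us whether the
    answer lies in the 0-bucket or the 1-bucket, so we keep only that
    bucket. Unlike a sort, each pass discards most candidates.
    """
    candidates = list(values)
    for b in range(bits - 1, -1, -1):
        zeros = [v for v in candidates if not (v >> b) & 1]
        if k < len(zeros):
            candidates = zeros
        else:
            k -= len(zeros)
            candidates = [v for v in candidates if (v >> b) & 1]
    return candidates[0]
```

Floating-point keys can be handled by a standard order-preserving bit transform, which is presumably why some dtypes in the benchmarks behave differently than integers.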