Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions benchmarks/python/masked_scatter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import os
import platform
import subprocess
import time
from copy import copy
Expand All @@ -17,21 +18,43 @@
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)

DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")

TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)


def get_device_name():
    """Return a human-readable name for the active benchmark device.

    CUDA devices are named via ``nvidia-smi`` and Apple (mps) devices via
    ``sysctl``; each probe degrades to a fixed placeholder string when the
    tool is missing or fails. Anything else falls back to generic platform
    information, or ``"CPU"`` when that is empty.
    """
    kind = TORCH_DEVICE.type
    probes = {
        "cuda": (
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            "CUDA_GPU",
        ),
        "mps": (
            ["sysctl", "-n", "machdep.cpu.brand_string"],
            "Apple_Silicon",
        ),
    }
    if kind in probes:
        command, fallback = probes[kind]
        try:
            raw = subprocess.check_output(command, stderr=subprocess.DEVNULL)
            text = raw.decode("utf-8")
            if kind == "cuda":
                # nvidia-smi prints one name per GPU; keep only the first.
                return text.splitlines()[0].strip()
            return text.strip()
        except Exception:
            return fallback
    return platform.processor() or platform.machine() or "CPU"


DEVICE_NAME = get_device_name()


N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20

VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
VECTOR_LENGTHS = [4096 * (2**i) for i in range(12)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")

Expand Down Expand Up @@ -202,9 +225,10 @@ def main():
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.png",
)
fig.savefig(output_path)
print(f"Saved benchmark image: {output_path}")
plt.close(fig)


Expand Down
183 changes: 183 additions & 0 deletions mlx/backend/cuda/device/scatter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,187 @@ __global__ void scatter(
Op{}(out + out_idx, upd[upd_loc]);
}

// Fused masked_scatter: one thread per output element. Each thread copies
// its destination element through unchanged unless the mask is set, in which
// case it pulls the element of `src` selected by `scatter_offsets` for its
// batch row. `SrcContiguous` / `DstContiguous` select direct indexing versus
// shape/stride translation via elem_to_loc.
template <typename T, bool SrcContiguous, bool DstContiguous, typename IdxT>
__global__ void masked_scatter_fused(
    const T* dst,
    const bool* mask,
    const int32_t* scatter_offsets,
    const T* src,
    T* out,
    IdxT size,
    IdxT src_batch_size,
    IdxT mask_batch_size,
    const __grid_constant__ Shape dst_shape,
    const __grid_constant__ Strides dst_strides,
    int32_t dst_ndim,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int32_t src_ndim) {
  IdxT idx = cg::this_grid().thread_rank();
  if (idx >= size) {
    return;
  }

  // Default result: the destination value at this position.
  T value;
  if constexpr (DstContiguous) {
    value = dst[idx];
  } else {
    IdxT loc =
        elem_to_loc(idx, dst_shape.data(), dst_strides.data(), dst_ndim);
    value = dst[loc];
  }

  // Masked positions take the source element for this batch row, provided
  // the per-batch offset is in range; out-of-range offsets keep `value`.
  if (mask[idx]) {
    IdxT offset = static_cast<IdxT>(scatter_offsets[idx]);
    if (offset < src_batch_size) {
      IdxT batch = idx / mask_batch_size;
      IdxT flat = batch * src_batch_size + offset;
      if constexpr (SrcContiguous) {
        value = src[flat];
      } else {
        IdxT loc =
            elem_to_loc(flat, src_shape.data(), src_strides.data(), src_ndim);
        value = src[loc];
      }
    }
  }

  // Exactly one store per thread on every path.
  out[idx] = value;
}

// Per-tile mask popcount. Launch: one block per tile; each tile covers
// blockDim.x * ITEMS_PER_THREAD consecutive mask elements within a single
// batch row (tiles never straddle batch boundaries — the tail tile of a
// batch is clipped at batch_end). Writes one int32 count per tile.
//
// Requires blockDim.x to be a multiple of WARP_SIZE (every lane of every
// warp participates in the shuffle reductions below). Uses up to
// WARP_SIZE warps per block (blockDim.x <= WARP_SIZE * WARP_SIZE).
template <typename IdxT, int ITEMS_PER_THREAD>
__global__ void masked_scatter_tile_count(
    const bool* mask,
    int32_t* tile_counts,
    IdxT mask_batch_size,
    int32_t num_tiles_per_batch) {
  IdxT tile = cg::this_grid().block_rank();
  IdxT batch_idx = tile / num_tiles_per_batch;
  IdxT tile_in_batch = tile - batch_idx * num_tiles_per_batch;
  IdxT tile_items = static_cast<IdxT>(blockDim.x) * ITEMS_PER_THREAD;
  IdxT tile_start = batch_idx * mask_batch_size + tile_in_batch * tile_items;
  IdxT batch_end = (batch_idx + 1) * mask_batch_size;
  IdxT tile_end = tile_start + tile_items;
  if (tile_end > batch_end) {
    tile_end = batch_end;
  }

  // Each thread counts its strided share of the tile.
  int32_t local_count = 0;
  IdxT index = tile_start + threadIdx.x;
#pragma unroll
  for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
    if (index < tile_end) {
      local_count += static_cast<int32_t>(mask[index]);
    }
    index += blockDim.x;
  }

  int lane = threadIdx.x & (WARP_SIZE - 1);
  int warp = threadIdx.x / WARP_SIZE;
  int nwarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

  // All threads reach this point unconditionally, so the participant set is
  // the full warp. Use an explicit full mask rather than __activemask(),
  // which under independent thread scheduling may report only the lanes
  // that happen to be converged at this instant.
  constexpr unsigned int full_mask = 0xffffffffu;
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    local_count += __shfl_down_sync(full_mask, local_count, offset);
  }

  // Stage one partial sum per warp, then reduce them in warp 0.
  __shared__ int32_t warp_sums[WARP_SIZE];
  if (lane == 0) {
    warp_sums[warp] = local_count;
  }
  __syncthreads();

  if (warp == 0) {
    int32_t block_sum = (lane < nwarps) ? warp_sums[lane] : 0;
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
      block_sum += __shfl_down_sync(full_mask, block_sum, offset);
    }
    if (lane == 0) {
      tile_counts[tile] = block_sum;
    }
  }
}

// Fused masked_scatter for contiguous src/dst/mask. Launch: one block per
// tile (same tiling as masked_scatter_tile_count: blockDim.x *
// ITEMS_PER_THREAD elements per tile, clipped at the batch boundary).
// `tile_offsets` holds, per tile, the number of masked elements preceding
// the tile within its batch row; combined with an in-block ballot scan this
// yields each masked element's position in `src` without a separate
// compaction pass.
//
// Requires blockDim.x to be a multiple of WARP_SIZE and at most
// WARP_SIZE * WARP_SIZE. The per-iteration __syncthreads() calls are safe:
// every thread executes all ITEMS_PER_THREAD iterations (inactivity at the
// tile tail is tracked by the `active` flag, not by divergent control flow).
template <typename T, typename IdxT, int ITEMS_PER_THREAD>
__global__ void masked_scatter_fused_vec_contiguous(
    const T* dst,
    const bool* mask,
    const int32_t* tile_offsets,
    const T* src,
    T* out,
    IdxT src_batch_size,
    IdxT mask_batch_size,
    int32_t num_tiles_per_batch) {
  IdxT tile = cg::this_grid().block_rank();
  IdxT batch_idx = tile / num_tiles_per_batch;
  IdxT tile_in_batch = tile - batch_idx * num_tiles_per_batch;
  IdxT tile_items = static_cast<IdxT>(blockDim.x) * ITEMS_PER_THREAD;
  IdxT tile_start = batch_idx * mask_batch_size + tile_in_batch * tile_items;
  IdxT batch_end = (batch_idx + 1) * mask_batch_size;
  IdxT tile_end = tile_start + tile_items;
  if (tile_end > batch_end) {
    tile_end = batch_end;
  }

  IdxT src_base = batch_idx * src_batch_size;
  // Masked elements before this tile within the batch row.
  IdxT tile_prefix = static_cast<IdxT>(tile_offsets[tile]);
  // Masked elements consumed by earlier iterations of the loop below.
  IdxT iter_prefix = 0;

  int lane = threadIdx.x & (WARP_SIZE - 1);
  int warp = threadIdx.x / WARP_SIZE;
  int nwarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

  __shared__ int32_t warp_counts[WARP_SIZE];
  __shared__ int32_t warp_offsets[WARP_SIZE];
  __shared__ int32_t iter_count;

  IdxT index = tile_start + threadIdx.x;

#pragma unroll
  for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
    bool active = index < tile_end;
    bool mask_value = active ? mask[index] : false;
    T out_value = active ? dst[index] : static_cast<T>(0);

    // Every thread executes this line unconditionally, so the participant
    // set is the full warp. Use an explicit full mask instead of
    // __activemask(), whose result under independent thread scheduling is
    // only the lanes that happen to be converged here.
    unsigned int ballots = __ballot_sync(0xffffffffu, mask_value);
    unsigned int lane_mask = (lane == 0) ? 0u : ((1u << lane) - 1u);
    // Exclusive count of masked lanes below this lane, and the warp total.
    int32_t warp_exclusive = __popc(ballots & lane_mask);
    int32_t warp_count = __popc(ballots);

    if (lane == 0) {
      warp_counts[warp] = warp_count;
    }
    __syncthreads();

    // Thread 0 turns per-warp counts into exclusive per-warp offsets and
    // records this iteration's block-wide total.
    if (threadIdx.x == 0) {
      int32_t offset = 0;
      for (int w = 0; w < nwarps; ++w) {
        warp_offsets[w] = offset;
        offset += warp_counts[w];
      }
      iter_count = offset;
    }
    __syncthreads();

    if (active && mask_value) {
      // Global rank of this masked element within the batch row = tile
      // prefix + earlier-iteration prefix + in-block exclusive scan.
      IdxT src_index = tile_prefix + iter_prefix +
          static_cast<IdxT>(warp_offsets[warp] + warp_exclusive);
      if (src_index < src_batch_size) {
        out_value = src[src_base + src_index];
      }
    }

    if (active) {
      out[index] = out_value;
    }

    iter_prefix += static_cast<IdxT>(iter_count);
    index += blockDim.x;
    // Protect warp_counts/warp_offsets/iter_count from being overwritten
    // by the next iteration before all threads have read them.
    __syncthreads();
  }
}

} // namespace mlx::core::cu
Loading