
Commit aa064b2

CUDA: add mean operation (#14313)
* CUDA: add mean operation
* add back sum_rows_f32_cuda
* Review: early exit if col != 0
1 parent aa0ef5c commit aa064b2

File tree

7 files changed, +54 -19 lines changed


ggml/src/ggml-cuda/common.cuh

Lines changed: 20 additions & 0 deletions

@@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions

@@ -37,6 +37,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
@@ -2357,6 +2358,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cuda_op_ssm_conv(ctx, dst);
             break;
@@ -3260,6 +3264,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;
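For context, the new dispatch case only fires when a graph contains a GGML_OP_MEAN node. A minimal, hypothetical host-side sketch that produces one via the public ggml_mean API (illustration only, not part of this diff):

#include "ggml.h"

// hypothetical sketch: build a graph with a GGML_OP_MEAN node for the CUDA backend to dispatch
struct ggml_init_params params = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
struct ggml_context * ctx = ggml_init(params);

struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 768);  // 256 columns, 768 rows
struct ggml_tensor * m = ggml_mean(ctx, a);                                  // per-row mean, result shape [1, 768]

struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, m);
// ... computing gf on a CUDA backend now routes the GGML_OP_MEAN node to ggml_cuda_op_mean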

ggml/src/ggml-cuda/mean.cu

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}
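The launch uses one block of WARP_SIZE threads per row: for the perf-test shape {256, 256, 3, 1} added below, ncols = 256 and nrows = 256*3*1 = 768, so 768 blocks each average one row of 256 floats. A hypothetical CPU reference for what reduce_rows_f32</*norm=*/true> computes (illustration only, not part of the diff):

#include <stdint.h>

// hypothetical CPU reference: mean of each contiguous row of an F32 buffer
static void mean_rows_f32_ref(const float * x, float * dst, int64_t ncols, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {
        double sum = 0.0;
        for (int64_t c = 0; c < ncols; ++c) {
            sum += x[r*ncols + c];
        }
        dst[r] = (float)(sum / ncols);
    }
}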

ggml/src/ggml-cuda/mean.cuh

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 5 additions & 18 deletions

@@ -1,25 +1,9 @@
 #include "sumrows.cuh"
 
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }

ggml/src/ggml-cuda/sumrows.cuh

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 #include "common.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
-
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions

@@ -4652,6 +4652,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
 
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
     return test_cases;
 }