Skip to content

Commit 0142961

Browse files
CUDA/HIP: optimize mmv paths taken for HIP devices (#14324)
Co-authored-by: Johannes Gäßler <[email protected]>
1 parent ce82bd0 commit 0142961

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,11 @@ static bool fp16_mma_hardware_available(const int cc) {
263263
}
264264

265265
static bool bf16_mma_hardware_available(const int cc) {
266-
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
266+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
267+
}
268+
269+
static bool fp32_mma_hardware_available(const int cc) {
270+
return GGML_CUDA_CC_IS_CDNA(cc);
267271
}
268272

269273
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

ggml/src/ggml-cuda/mmv.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,11 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
456456
return ne11 <= 4;
457457
}
458458
return ne11 <= 3;
459+
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
460+
if (fp32_mma_hardware_available(cc)) {
461+
return ne11 <= 3;
462+
}
463+
return ne11 <= 8;
459464
}
460465
return ne11 <= 8;
461466
case GGML_TYPE_F16:
@@ -468,6 +473,14 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
468473
return src0_small && ne11 <= 3;
469474
}
470475
return ne11 <= 8;
476+
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
477+
if (fp16_mma_hardware_available(cc)) {
478+
if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
479+
return ne11 <= 5;
480+
}
481+
return ne11 <= 2;
482+
}
483+
return ne11 <= 8;
471484
}
472485
return ne11 <= 8;
473486
case GGML_TYPE_BF16:
@@ -480,6 +493,11 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
480493
return src0_small && ne11 <= 3;
481494
}
482495
return ne11 <= 8;
496+
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
497+
if (bf16_mma_hardware_available(cc)) {
498+
return ne11 <= 3;
499+
}
500+
return ne11 <= 8;
483501
}
484502
return ne11 <= 8;
485503
default:

0 commit comments

Comments
 (0)