Skip to content

Commit af3373f

Browse files
authored
HIP: enable vec fattn on RDNA4 (#14323)
1 parent 5d5c066 commit af3373f

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
241241
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
242242
return false;
243243
#else
244-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
244+
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
246+
return true;
247+
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
248+
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
249+
return true;
250+
#else
251+
return false;
252+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
253+
} else {
254+
return false;
255+
}
246256
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
247257
}
248258

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ int ggml_cuda_get_device() {
100100
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
101101
ggml_cuda_set_device(device);
102102
cudaError_t err;
103-
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
104-
{
103+
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
105104
err = cudaMallocManaged(ptr, size);
106105
#if defined(GGML_USE_HIP)
107106
if (err == hipSuccess) {
@@ -119,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
119118
err = cudaMalloc(ptr, size);
120119
}
121120
#endif // defined(GGML_USE_HIP)
122-
}
123-
else
124-
{
121+
} else {
125122
err = cudaMalloc(ptr, size);
126123
}
127124
return err;

0 commit comments

Comments
 (0)