Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit 3adb52c

Browse files
Enabled half precision for GEMM (#495)
* Enabled axpy and scal with half precision as well * Enabled the relevant tests and benchmarks with half precision
1 parent cb69d68 commit 3adb52c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+775
-394
lines changed

CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,14 @@ if(IMGDNN_DIR)
106106
endif()
107107

108108
option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON)
109-
option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON)
109+
option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for GEMM" OFF)
110+
option(BLAS_ENABLE_HALF "Whether to enable sycl::half data type for supported operators" OFF)
111+
112+
if(((NOT INSTALL_HEADER_ONLY) AND (TUNING_TARGET STREQUAL "DEFAULT_CPU"))
113+
OR (INSTALL_HEADER_ONLY AND (NOT TUNING_TARGET)))
114+
set(BLAS_ENABLE_HALF OFF)
115+
message(STATUS "FP16 operations are not supported for CPU targets. BLAS_ENABLE_HALF is disabled")
116+
endif()
110117

111118
# CmakeFunctionHelper has to be included after any options that it depends on are declared.
112119
# These include:
@@ -117,6 +124,8 @@ option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported op
117124
# * BLAS_INDEX_TYPES
118125
# * NAIVE_GEMM
119126
# * BLAS_ENABLE_COMPLEX
127+
# * BLAS_ENABLE_HALF
128+
120129
include(CmakeFunctionHelper)
121130

122131
if (INSTALL_HEADER_ONLY)

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,9 +462,10 @@ Some of the supported options are:
462462
| `BLAS_MEMPOOL_BENCHMARK` | `ON`/`OFF` | Determines whether to enable the scratchpad memory pool for benchmark execution. `OFF` by default |
463463
| `BLAS_ENABLE_CONST_INPUT` | `ON`/`OFF` | Determines whether to enable kernel instantiation with const input buffer (`ON` by default) |
464464
| `BLAS_ENABLE_EXTENSIONS` | `ON`/`OFF` | Determines whether to enable portBLAS extensions (`ON` by default) |
465-
| `BLAS_DATA_TYPES` | `half;float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` |
465+
| `BLAS_DATA_TYPES` | `float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` |
466466
| `BLAS_INDEX_TYPES` | `int32_t;int64_t` | Determines the type(s) to use for `index_t` and `increment_t`. Default is `int` |
467-
| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Operators only)* (`ON` by default) |
467+
| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Operators only)* (`OFF` by default) |
468+
| `BLAS_ENABLE_HALF` | `ON`/`OFF` | Determines whether to enable Half data type support *(Support is limited to some Level 1 operators and Gemm)* (`OFF` by default) |
468469

469470
## ComputeCpp Compilation *(Deprecated)*
470471

benchmark/cublas/CMakeLists.txt

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,27 @@ set(sources
7575
)
7676

7777
# Operators supporting COMPLEX types benchmarking
78-
set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided")
78+
set(CPLX_OPS "gemm"
79+
"gemm_batched"
80+
"gemm_batched_strided")
81+
82+
# Operators supporting HALF type benchmarking
83+
set(HALF_DATA_OPS "gemm"
84+
"gemm_batched"
85+
"gemm_batched_strided"
86+
)
7987

8088
# Add individual benchmarks for each method
8189
foreach(cublas_bench ${sources})
8290
get_filename_component(bench_cublas_exec ${cublas_bench} NAME_WE)
8391
add_executable(bench_cublas_${bench_cublas_exec} ${cublas_bench} main.cpp)
8492
target_link_libraries(bench_cublas_${bench_cublas_exec} PRIVATE benchmark CUDA::toolkit CUDA::cublas CUDA::cudart portblas Clara::Clara bench_info)
8593
target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE})
86-
if(${BLAS_ENABLE_COMPLEX})
87-
if("${bench_cublas_exec}" IN_LIST CPLX_OPS)
88-
target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE BLAS_ENABLE_COMPLEX=1)
89-
endif()
94+
if((${BLAS_ENABLE_COMPLEX}) AND ("${bench_cublas_exec}" IN_LIST CPLX_OPS))
95+
target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE BLAS_ENABLE_COMPLEX=1)
96+
endif()
97+
if((${BLAS_ENABLE_HALF}) AND ("${bench_cublas_exec}" IN_LIST HALF_DATA_OPS))
98+
target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE BLAS_ENABLE_HALF=1)
9099
endif()
91100
add_sycl_to_target(
92101
TARGET bench_cublas_${bench_cublas_exec}

benchmark/cublas/blas3/gemm.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ static inline void cublas_routine(args_t&&... args) {
3434
CUBLAS_CHECK(cublasSgemm(std::forward<args_t>(args)...));
3535
} else if constexpr (std::is_same_v<scalar_t, double>) {
3636
CUBLAS_CHECK(cublasDgemm(std::forward<args_t>(args)...));
37+
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
38+
CUBLAS_CHECK(cublasHgemm(std::forward<args_t>(args)...));
3739
}
3840
return;
3941
}
@@ -54,6 +56,10 @@ template <typename scalar_t>
5456
void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
5557
int t2, index_t m, index_t k, index_t n, scalar_t alpha, scalar_t beta,
5658
bool* success) {
59+
// scalar_t if scalar_t != sycl::half, CUDA's __half otherwise
60+
using cuda_scalar_t =
61+
typename blas_benchmark::utils::CudaType<scalar_t>::type;
62+
5763
// initialize the state label
5864
blas_benchmark::utils::set_benchmark_label<scalar_t>(state);
5965

@@ -80,24 +86,31 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
8086
std::vector<scalar_t> c =
8187
blas_benchmark::utils::const_data<scalar_t>(m * n, 0);
8288

83-
blas_benchmark::utils::CUDAVector<scalar_t> a_gpu(m * k, a.data());
84-
blas_benchmark::utils::CUDAVector<scalar_t> b_gpu(k * n, b.data());
85-
blas_benchmark::utils::CUDAVector<scalar_t> c_gpu(n * m, c.data());
89+
blas_benchmark::utils::CUDAVector<cuda_scalar_t> a_gpu(
90+
m * k, reinterpret_cast<cuda_scalar_t*>(a.data()));
91+
blas_benchmark::utils::CUDAVector<cuda_scalar_t> b_gpu(
92+
k * n, reinterpret_cast<cuda_scalar_t*>(b.data()));
93+
blas_benchmark::utils::CUDAVector<cuda_scalar_t> c_gpu(
94+
n * m, reinterpret_cast<cuda_scalar_t*>(c.data()));
8695

8796
cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
8897
cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
8998

99+
cuda_scalar_t alpha_cuda = *reinterpret_cast<cuda_scalar_t*>(&alpha);
100+
cuda_scalar_t beta_cuda = *reinterpret_cast<cuda_scalar_t*>(&beta);
101+
90102
#ifdef BLAS_VERIFY_BENCHMARK
91103
// Run a first time with a verification of the results
92104
std::vector<scalar_t> c_ref = c;
93105
reference_blas::gemm(t_a, t_b, m, n, k, alpha, a.data(), lda, b.data(), ldb,
94106
beta, c_ref.data(), ldc);
95107
std::vector<scalar_t> c_temp = c;
96108
{
97-
blas_benchmark::utils::CUDAVector<scalar_t, true> c_temp_gpu(m * n,
98-
c_temp.data());
99-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha, a_gpu,
100-
lda, b_gpu, ldb, &beta, c_temp_gpu, ldc);
109+
blas_benchmark::utils::CUDAVector<cuda_scalar_t, true> c_temp_gpu(
110+
m * n, reinterpret_cast<cuda_scalar_t*>(c_temp.data()));
111+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
112+
a_gpu, lda, b_gpu, ldb, &beta_cuda, c_temp_gpu,
113+
ldc);
101114
}
102115

103116
std::ostringstream err_stream;
@@ -107,9 +120,10 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
107120
*success = false;
108121
};
109122
#endif
123+
110124
auto blas_warmup = [&]() -> void {
111-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha, a_gpu,
112-
lda, b_gpu, ldb, &beta, c_gpu, ldc);
125+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
126+
a_gpu, lda, b_gpu, ldb, &beta_cuda, c_gpu, ldc);
113127
return;
114128
};
115129

@@ -120,8 +134,8 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
120134

121135
auto blas_method_def = [&]() -> std::vector<cudaEvent_t> {
122136
CUDA_CHECK(cudaEventRecord(start));
123-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha, a_gpu,
124-
lda, b_gpu, ldb, &beta, c_gpu, ldc);
137+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
138+
a_gpu, lda, b_gpu, ldb, &beta_cuda, c_gpu, ldc);
125139
CUDA_CHECK(cudaEventRecord(stop));
126140
CUDA_CHECK(cudaEventSynchronize(stop));
127141
return std::vector{start, stop};

benchmark/cublas/blas3/gemm_batched.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ static inline void cublas_routine(args_t&&... args) {
3434
CUBLAS_CHECK(cublasSgemmBatched(std::forward<args_t>(args)...));
3535
} else if constexpr (std::is_same_v<scalar_t, double>) {
3636
CUBLAS_CHECK(cublasDgemmBatched(std::forward<args_t>(args)...));
37+
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
38+
CUBLAS_CHECK(cublasHgemmBatched(std::forward<args_t>(args)...));
3739
}
3840
return;
3941
}
@@ -54,6 +56,10 @@ template <typename scalar_t>
5456
void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1,
5557
index_t t2, index_t m, index_t k, index_t n, scalar_t alpha,
5658
scalar_t beta, index_t batch_count, int batch_type_i, bool* success) {
59+
// scalar_t if scalar_t != sycl::half, CUDA's __half otherwise
60+
using cuda_scalar_t =
61+
typename blas_benchmark::utils::CudaType<scalar_t>::type;
62+
5763
// initialize the state label
5864
blas_benchmark::utils::set_benchmark_label<scalar_t>(state);
5965

@@ -84,17 +90,19 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1,
8490
std::vector<scalar_t> c =
8591
blas_benchmark::utils::const_data<scalar_t>(m * n * batch_count, 0);
8692

87-
blas_benchmark::utils::CUDAVectorBatched<scalar_t> d_A_array(m * k,
88-
batch_count, a);
89-
blas_benchmark::utils::CUDAVectorBatched<scalar_t> d_B_array(k * n,
90-
batch_count, b);
91-
blas_benchmark::utils::CUDAVectorBatched<scalar_t> d_C_array(m * n,
92-
batch_count);
93+
blas_benchmark::utils::CUDAVectorBatched<cuda_scalar_t> d_A_array(
94+
m * k, batch_count, reinterpret_cast<cuda_scalar_t*>(a.data()));
95+
blas_benchmark::utils::CUDAVectorBatched<cuda_scalar_t> d_B_array(
96+
k * n, batch_count, reinterpret_cast<cuda_scalar_t*>(b.data()));
97+
blas_benchmark::utils::CUDAVectorBatched<cuda_scalar_t> d_C_array(
98+
m * n, batch_count);
9399

94100
cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
95-
96101
cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
97102

103+
cuda_scalar_t alpha_cuda = *reinterpret_cast<cuda_scalar_t*>(&alpha);
104+
cuda_scalar_t beta_cuda = *reinterpret_cast<cuda_scalar_t*>(&beta);
105+
98106
#ifdef BLAS_VERIFY_BENCHMARK
99107
// Run a first time with a verification of the results
100108
{
@@ -110,13 +118,12 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1,
110118
}
111119

112120
std::vector<scalar_t> c_temp(m * n * batch_count);
113-
114121
{
115-
blas_benchmark::utils::CUDAVectorBatched<scalar_t, true> c_temp_gpu(
116-
n * m, batch_count, c_temp);
117-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha,
122+
blas_benchmark::utils::CUDAVectorBatched<cuda_scalar_t, true> c_temp_gpu(
123+
n * m, batch_count, reinterpret_cast<cuda_scalar_t*>(c_temp.data()));
124+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
118125
d_A_array.get_batch_array(), lda,
119-
d_B_array.get_batch_array(), ldb, &beta,
126+
d_B_array.get_batch_array(), ldb, &beta_cuda,
120127
c_temp_gpu.get_batch_array(), ldc, batch_count);
121128
}
122129

@@ -128,14 +135,13 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1,
128135
*success = false;
129136
};
130137
}
131-
132138
} // close scope for verify benchmark
133139
#endif
134140

135141
auto blas_warmup = [&]() -> void {
136-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha,
142+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
137143
d_A_array.get_batch_array(), lda,
138-
d_B_array.get_batch_array(), ldb, &beta,
144+
d_B_array.get_batch_array(), ldb, &beta_cuda,
139145
d_C_array.get_batch_array(), ldc, batch_count);
140146
return;
141147
};
@@ -146,9 +152,9 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1,
146152

147153
auto blas_method_def = [&]() -> std::vector<cudaEvent_t> {
148154
CUDA_CHECK(cudaEventRecord(start));
149-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha,
155+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
150156
d_A_array.get_batch_array(), lda,
151-
d_B_array.get_batch_array(), ldb, &beta,
157+
d_B_array.get_batch_array(), ldb, &beta_cuda,
152158
d_C_array.get_batch_array(), ldc, batch_count);
153159
CUDA_CHECK(cudaEventRecord(stop));
154160
CUDA_CHECK(cudaEventSynchronize(stop));

benchmark/cublas/blas3/gemm_batched_strided.cpp

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ static inline void cublas_routine(args_t&&... args) {
3434
CUBLAS_CHECK(cublasSgemmStridedBatched(std::forward<args_t>(args)...));
3535
} else if constexpr (std::is_same_v<scalar_t, double>) {
3636
CUBLAS_CHECK(cublasDgemmStridedBatched(std::forward<args_t>(args)...));
37+
} else if constexpr (std::is_same_v<scalar_t, cl::sycl::half>) {
38+
CUBLAS_CHECK(cublasHgemmStridedBatched(std::forward<args_t>(args)...));
3739
}
3840
return;
3941
}
@@ -55,6 +57,10 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
5557
int t2, index_t m, index_t k, index_t n, scalar_t alpha, scalar_t beta,
5658
index_t batch_size, index_t stride_a_mul, index_t stride_b_mul,
5759
index_t stride_c_mul, bool* success) {
60+
// scalar_t if scalar_t != sycl::half, CUDA's __half otherwise
61+
using cuda_scalar_t =
62+
typename blas_benchmark::utils::CudaType<scalar_t>::type;
63+
5864
// initialize the state label
5965
blas_benchmark::utils::set_benchmark_label<scalar_t>(state);
6066

@@ -103,14 +109,19 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
103109
std::vector<scalar_t> c =
104110
blas_benchmark::utils::const_data<scalar_t>(size_c_batch, 0);
105111

106-
blas_benchmark::utils::CUDAVector<scalar_t> a_gpu(size_a_batch, a.data());
107-
blas_benchmark::utils::CUDAVector<scalar_t> b_gpu(size_b_batch, b.data());
108-
blas_benchmark::utils::CUDAVector<scalar_t> c_gpu(size_c_batch, c.data());
112+
blas_benchmark::utils::CUDAVector<cuda_scalar_t> a_gpu(
113+
size_a_batch, reinterpret_cast<cuda_scalar_t*>(a.data()));
114+
blas_benchmark::utils::CUDAVector<cuda_scalar_t> b_gpu(
115+
size_b_batch, reinterpret_cast<cuda_scalar_t*>(b.data()));
116+
blas_benchmark::utils::CUDAVector<cuda_scalar_t> c_gpu(
117+
size_c_batch, reinterpret_cast<cuda_scalar_t*>(c.data()));
109118

110119
cublasOperation_t c_t_a = trA ? CUBLAS_OP_N : CUBLAS_OP_T;
111-
112120
cublasOperation_t c_t_b = trB ? CUBLAS_OP_N : CUBLAS_OP_T;
113121

122+
cuda_scalar_t alpha_cuda = *reinterpret_cast<cuda_scalar_t*>(&alpha);
123+
cuda_scalar_t beta_cuda = *reinterpret_cast<cuda_scalar_t*>(&beta);
124+
114125
#ifdef BLAS_VERIFY_BENCHMARK
115126
// Run a first time with a verification of the results
116127
std::vector<scalar_t> c_ref = c;
@@ -123,11 +134,11 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
123134

124135
std::vector<scalar_t> c_temp = c;
125136
{
126-
blas_benchmark::utils::CUDAVector<scalar_t, true> c_temp_gpu(size_c_batch,
127-
c_temp.data());
128-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha, a_gpu,
129-
lda, stride_a, b_gpu, ldb, stride_b, &beta,
130-
c_temp_gpu, ldc, stride_c, batch_size);
137+
blas_benchmark::utils::CUDAVector<cuda_scalar_t, true> c_temp_gpu(
138+
size_c_batch, reinterpret_cast<cuda_scalar_t*>(c_temp.data()));
139+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
140+
a_gpu, lda, stride_a, b_gpu, ldb, stride_b,
141+
&beta_cuda, c_temp_gpu, ldc, stride_c, batch_size);
131142
}
132143

133144
std::ostringstream err_stream;
@@ -140,9 +151,9 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
140151
#endif
141152

142153
auto blas_warmup = [&]() -> void {
143-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha, a_gpu,
144-
lda, stride_a, b_gpu, ldb, stride_b, &beta, c_gpu,
145-
ldc, stride_c, batch_size);
154+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
155+
a_gpu, lda, stride_a, b_gpu, ldb, stride_b,
156+
&beta_cuda, c_gpu, ldc, stride_c, batch_size);
146157
return;
147158
};
148159

@@ -152,9 +163,9 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1,
152163

153164
auto blas_method_def = [&]() -> std::vector<cudaEvent_t> {
154165
CUDA_CHECK(cudaEventRecord(start));
155-
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha, a_gpu,
156-
lda, stride_a, b_gpu, ldb, stride_b, &beta, c_gpu,
157-
ldc, stride_c, batch_size);
166+
cublas_routine<scalar_t>(cuda_handle, c_t_a, c_t_b, m, n, k, &alpha_cuda,
167+
a_gpu, lda, stride_a, b_gpu, ldb, stride_b,
168+
&beta_cuda, c_gpu, ldc, stride_c, batch_size);
158169
CUDA_CHECK(cudaEventRecord(stop));
159170
CUDA_CHECK(cudaEventSynchronize(stop));
160171
return std::vector{start, stop};

benchmark/cublas/utils.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
#include <cuComplex.h>
3737
#include <cublas_v2.h>
3838
#include <cuda.h>
39+
#include <cuda_fp16.h>
3940
#include <cuda_runtime.h>
41+
4042
// Forward declare methods that we use in `benchmark.cpp`, but define in
4143
// `main.cpp`
4244

@@ -274,6 +276,21 @@ static inline std::tuple<double, double> timef_cuda(function_t func,
274276
return std::make_tuple(overall_time, static_cast<double>(elapsed_time) * 1E6);
275277
}
276278

279+
/**
280+
* Reference type of the underlying benchmark data aimed to match the
281+
* cuda/cuBLAS scalar types.
282+
*/
283+
template <typename T, typename Enable = void>
284+
struct CudaType {
285+
using type = T;
286+
};
287+
288+
// When T is sycl::half, use CUDA's __half as the type.
289+
template <typename T>
290+
struct CudaType<T, std::enable_if_t<std::is_same_v<T, cl::sycl::half>>> {
291+
using type = __half;
292+
};
293+
277294
} // namespace utils
278295
} // namespace blas_benchmark
279296

0 commit comments

Comments
 (0)