@@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
1749
1749
}
1750
1750
1751
1751
static __global__ void k_compute_batched_ptrs (
1752
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
1752
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
1753
1753
const void ** ptrs_src, void ** ptrs_dst,
1754
1754
int64_t ne12, int64_t ne13,
1755
1755
int64_t ne23,
@@ -1772,91 +1772,139 @@ static __global__ void k_compute_batched_ptrs(
1772
1772
ptrs_dst[0 *ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
1773
1773
}
1774
1774
1775
- static void ggml_cuda_mul_mat_batched_cublas (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1775
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
1776
+ template <ggml_type T>
1777
+ struct batched_mul_mat_traits ;
1778
+
1779
+ template <>
1780
+ struct batched_mul_mat_traits <GGML_TYPE_F32> {
1781
+ using cuda_type = float ;
1782
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1783
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
1784
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
1785
+ static inline const float alpha = 1 .0f ;
1786
+ static inline const float beta = 0 .0f ;
1787
+ static inline const void * get_alpha () { static const float val = alpha; return &val; }
1788
+ static inline const void * get_beta () { static const float val = beta; return &val; }
1789
+ static inline auto get_nc_converter (ggml_type src_type) { return ggml_get_to_fp32_nc_cuda (src_type); }
1790
+ };
1791
+
1792
+ template <>
1793
+ struct batched_mul_mat_traits <GGML_TYPE_BF16> {
1794
+ using cuda_type = nv_bfloat16;
1795
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
1796
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
1797
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
1798
+ static inline const float alpha = 1 .0f ;
1799
+ static inline const float beta = 0 .0f ;
1800
+ static inline const void * get_alpha () { static const float val = alpha; return &val; }
1801
+ static inline const void * get_beta () { static const float val = beta; return &val; }
1802
+ static inline auto get_nc_converter (ggml_type src_type) { return ggml_get_to_bf16_nc_cuda (src_type); }
1803
+ };
1804
+
1805
+ template <>
1806
+ struct batched_mul_mat_traits <GGML_TYPE_F16> {
1807
+ using cuda_type = half;
1808
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
1809
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
1810
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
1811
+ static inline const half alpha = 1.0 ;
1812
+ static inline const half beta = 0.0 ;
1813
+ static inline const void * get_alpha () { static const half val = alpha; return &val; }
1814
+ static inline const void * get_beta () { static const half val = beta; return &val; }
1815
+ static inline auto get_nc_converter (ggml_type src_type) { return ggml_get_to_fp16_nc_cuda (src_type); }
1816
+ };
1817
+
1818
+ template <ggml_type src0_type>
1819
+ static void ggml_cuda_mul_mat_batched_cublas_impl (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1820
+ using traits = batched_mul_mat_traits<src0_type>;
1821
+ using cuda_t = typename traits::cuda_type;
1822
+
1776
1823
GGML_ASSERT (!ggml_is_transposed (src0));
1777
1824
GGML_ASSERT (!ggml_is_transposed (src1));
1778
-
1779
1825
GGML_ASSERT (!ggml_backend_buft_is_cuda_split (src0->buffer ->buft ));
1780
- GGML_ASSERT (src0->type == GGML_TYPE_F16);
1826
+ GGML_ASSERT (src0->type == src0_type);
1827
+ GGML_ASSERT (ggml_is_contiguous (dst));
1781
1828
1782
1829
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
1783
1830
// As long as dst is contiguous this does not matter though.
1784
- GGML_ASSERT (ggml_is_contiguous (dst));
1785
1831
1786
1832
GGML_TENSOR_BINARY_OP_LOCALS
1787
1833
1788
1834
const int64_t ne_dst = ggml_nelements (dst);
1789
-
1790
1835
cudaStream_t main_stream = ctx.stream ();
1791
-
1792
1836
CUBLAS_CHECK (cublasSetStream (ctx.cublas_handle (), main_stream));
1793
1837
1794
- const half * src0_f16 = (const half *) src0->data ;
1795
1838
float * dst_ddf = (float *) dst->data ;
1796
-
1797
- const half * src1_f16 = (const half *) src1->data ;
1798
1839
const size_t ts_src1 = ggml_type_size (src1->type );
1799
1840
GGML_ASSERT (nb10 == ts_src1);
1800
1841
int64_t s11 = nb11 / ts_src1;
1801
1842
int64_t s12 = nb12 / ts_src1;
1802
1843
int64_t s13 = nb13 / ts_src1;
1803
- ggml_cuda_pool_alloc<half> src1_f16_alloc (ctx.pool ());
1804
1844
1805
- // convert src1 to fp16
1806
- if (src1->type != GGML_TYPE_F16) {
1807
- const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda (src1->type );
1808
- const int64_t ne_src1 = ggml_nelements (src1);
1809
- src1_f16_alloc.alloc (ne_src1);
1810
- GGML_ASSERT (to_fp16_cuda != nullptr );
1845
+ const cuda_t * src0_ptr = nullptr ;
1846
+ const cuda_t * src1_ptr = nullptr ;
1811
1847
1812
- to_fp16_cuda (src1_f16, src1_f16_alloc.get (), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1848
+ ggml_cuda_pool_alloc<cuda_t > src0_alloc (ctx.pool ());
1849
+ ggml_cuda_pool_alloc<cuda_t > src1_alloc (ctx.pool ());
1850
+
1851
+ // Handle src0
1852
+ src0_ptr = (const cuda_t *) src0->data ;
1853
+
1854
+ // Handle src1 - convert if necessary
1855
+ if (src1->type == src0_type) {
1856
+ src1_ptr = (const cuda_t *) src1->data ;
1857
+ } else {
1858
+ // Convert src1 to target type using traits conversion functions
1859
+ const int64_t ne_src1 = ggml_nelements (src1);
1860
+ src1_alloc.alloc (ne_src1);
1813
1861
1814
- src1_f16 = src1_f16_alloc.get ();
1862
+ const auto convert_func = traits::get_nc_converter (src1->type );
1863
+ GGML_ASSERT (convert_func != nullptr );
1864
+ convert_func (src1->data , src1_alloc.get (), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1865
+ src1_ptr = src1_alloc.get ();
1815
1866
s11 = ne10;
1816
1867
s12 = ne11*s11;
1817
1868
s13 = ne12*s12;
1818
1869
}
1819
1870
1820
- ggml_cuda_pool_alloc<half> dst_f16 (ctx.pool ());
1871
+ // Setup destination buffer
1872
+ ggml_cuda_pool_alloc<cuda_t > dst_temp (ctx.pool ());
1821
1873
char * dst_t ;
1822
-
1823
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1824
- cudaDataType_t cu_data_type = CUDA_R_16F;
1825
-
1826
- // dst strides
1827
1874
size_t nbd2 = dst->nb [2 ];
1828
1875
size_t nbd3 = dst->nb [3 ];
1829
1876
1830
- const half alpha_f16 = 1 .0f ;
1831
- const half beta_f16 = 0 .0f ;
1832
-
1877
+ cublasComputeType_t cu_compute_type = traits::compute_type;
1878
+ cudaDataType_t cu_data_type = traits::data_type;
1879
+ cudaDataType_t cu_data_type_a = traits::data_type;
1880
+ cudaDataType_t cu_data_type_b = traits::data_type;
1881
+ const void * alpha = traits::get_alpha ();
1882
+ const void * beta = traits::get_beta ();
1833
1883
const float alpha_f32 = 1 .0f ;
1834
- const float beta_f32 = 0 .0f ;
1835
-
1836
- const void * alpha = &alpha_f16;
1837
- const void * beta = &beta_f16;
1884
+ const float beta_f32 = 0 .0f ;
1838
1885
1839
1886
if (dst->op_params [0 ] == GGML_PREC_DEFAULT) {
1840
- dst_t = (char *) dst_f16.alloc (ne_dst);
1841
-
1842
- nbd2 /= sizeof (float ) / sizeof (half);
1843
- nbd3 /= sizeof (float ) / sizeof (half);
1887
+ if constexpr (src0_type == GGML_TYPE_F32) {
1888
+ dst_t = (char *) dst_ddf; // Direct F32 output
1889
+ } else {
1890
+ dst_t = (char *) dst_temp.alloc (ne_dst);
1891
+ nbd2 /= sizeof (float ) / sizeof (cuda_t );
1892
+ nbd3 /= sizeof (float ) / sizeof (cuda_t );
1893
+ }
1844
1894
} else {
1845
1895
dst_t = (char *) dst_ddf;
1846
-
1847
1896
cu_compute_type = CUBLAS_COMPUTE_32F;
1848
- cu_data_type = CUDA_R_32F;
1849
-
1897
+ cu_data_type = CUDA_R_32F;
1850
1898
alpha = &alpha_f32;
1851
- beta = &beta_f32;
1899
+ beta = &beta_f32;
1852
1900
}
1853
1901
1854
1902
int id = ggml_cuda_get_device ();
1855
1903
const int cc = ggml_cuda_info ().devices [id].cc ;
1856
1904
if (GGML_CUDA_CC_IS_CDNA (cc) || GGML_CUDA_CC_IS_RDNA4 (cc)) {
1857
1905
cu_compute_type = CUBLAS_COMPUTE_32F;
1858
1906
alpha = &alpha_f32;
1859
- beta = &beta_f32;
1907
+ beta = &beta_f32;
1860
1908
}
1861
1909
1862
1910
GGML_ASSERT (ne12 % ne02 == 0 );
@@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1866
1914
const int64_t r2 = ne12/ne02;
1867
1915
const int64_t r3 = ne13/ne03;
1868
1916
1869
- #if 0
1870
- // use cublasGemmEx
1871
- {
1872
- for (int i13 = 0; i13 < ne13; ++i13) {
1873
- for (int i12 = 0; i12 < ne12; ++i12) {
1874
- int i03 = i13 / r3;
1875
- int i02 = i12 / r2;
1876
-
1877
- CUBLAS_CHECK(
1878
- cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1879
- ne01, ne11, ne10,
1880
- alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
1881
- src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
1882
- beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
1883
- cu_compute_type,
1884
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1885
- }
1886
- }
1887
- }
1888
- #else
1889
1917
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
1890
1918
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
1891
1919
// use cublasGemmStridedBatchedEx
1892
1920
CUBLAS_CHECK (
1893
1921
cublasGemmStridedBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
1894
1922
ne01, ne11, ne10,
1895
- alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
1896
- src1_f16, CUDA_R_16F, s11, s12, // strideB
1897
- beta, dst_t , cu_data_type, ne0, ne1*ne0, // strideC
1923
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1924
+ src1_ptr, cu_data_type_b, s11, s12, // strideB
1925
+ beta, dst_t , cu_data_type, ne0, ne1*ne0, // strideC
1898
1926
ne12*ne13,
1899
1927
cu_compute_type,
1900
1928
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1905
1933
ggml_cuda_pool_alloc<const void *> ptrs_src (ctx.pool (), 2 *ne23);
1906
1934
ggml_cuda_pool_alloc< void *> ptrs_dst (ctx.pool (), 1 *ne23);
1907
1935
1936
+ size_t src1_stride_size = sizeof (cuda_t );
1937
+
1908
1938
dim3 block_dims (ne13, ne12);
1909
1939
k_compute_batched_ptrs<<<1 , block_dims, 0 , main_stream>>> (
1910
- src0_f16, src1_f16 , dst_t ,
1940
+ src0_ptr, src1_ptr , dst_t ,
1911
1941
ptrs_src.get (), ptrs_dst.get (),
1912
1942
ne12, ne13,
1913
1943
ne23,
1914
1944
nb02, nb03,
1915
- src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof (half) ,
1916
- src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof (half) ,
1945
+ ( src1->type == src0_type) ? nb12 : s12*src1_stride_size ,
1946
+ ( src1->type == src0_type) ? nb13 : s13*src1_stride_size ,
1917
1947
nbd2, nbd3,
1918
1948
r2, r3);
1949
+
1919
1950
CUDA_CHECK (cudaGetLastError ());
1920
1951
1921
1952
CUBLAS_CHECK (
1922
1953
cublasGemmBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
1923
1954
ne01, ne11, ne10,
1924
- alpha, (const void **) (ptrs_src.get () + 0 *ne23), CUDA_R_16F, nb01/nb00,
1925
- (const void **) (ptrs_src.get () + 1 *ne23), CUDA_R_16F, s11,
1926
- beta, ( void **) (ptrs_dst.get () + 0 *ne23), cu_data_type, ne0,
1955
+ alpha, (const void **) (ptrs_src.get () + 0 *ne23), cu_data_type_a, nb01/nb00,
1956
+ (const void **) (ptrs_src.get () + 1 *ne23), cu_data_type_b, s11,
1957
+ beta, ( void **) (ptrs_dst.get () + 0 *ne23), cu_data_type, ne0,
1927
1958
ne23,
1928
1959
cu_compute_type,
1929
1960
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1930
1961
}
1931
- #endif
1932
1962
1933
- if (dst->op_params [0 ] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
1934
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda (GGML_TYPE_F16);
1935
- to_fp32_cuda (dst_f16.get (), dst_ddf, ne_dst, main_stream);
1963
+ // Convert output back to F32 if needed
1964
+ if (dst->op_params [0 ] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
1965
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda (traits::ggml_type_val);
1966
+ to_fp32_cuda (dst_temp.get (), dst_ddf, ne_dst, main_stream);
1967
+ }
1968
+ }
1969
+
1970
+ static void ggml_cuda_mul_mat_batched_cublas (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1971
+ GGML_ASSERT (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
1972
+
1973
+ switch (src0->type ) {
1974
+ case GGML_TYPE_F32:
1975
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
1976
+ break ;
1977
+ case GGML_TYPE_BF16:
1978
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
1979
+ break ;
1980
+ case GGML_TYPE_F16:
1981
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
1982
+ break ;
1983
+ default :
1984
+ GGML_ABORT (" Unsupported type" );
1936
1985
}
1937
1986
}
1938
1987
@@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1984
2033
// printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
1985
2034
// printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
1986
2035
2036
+ // TODO update for generic tensor parallelism
2037
+ const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
2038
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2039
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available (cc);
2040
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2041
+
1987
2042
if (!split && use_mul_mat_vec) {
1988
2043
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
1989
2044
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1992
2047
ggml_cuda_mul_mat_vec_q (ctx, src0, src1, nullptr , dst);
1993
2048
} else if (!split && use_mul_mat_q) {
1994
2049
ggml_cuda_mul_mat_q (ctx, src0, src1, nullptr , dst);
1995
- } else if (!split && src0-> type == GGML_TYPE_F16 && (src1-> type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1996
- !ggml_is_transposed (src0) && !ggml_is_transposed (src1) && src1->ne [2 ]*src1->ne [3 ] > 1 ) {
2050
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
2051
+ && !ggml_is_transposed (src0) && !ggml_is_transposed (src1) && src1->ne [2 ]*src1->ne [3 ] > 1 ) {
1997
2052
// general KQ + KQV multi-batch without FlashAttention
1998
2053
ggml_cuda_mul_mat_batched_cublas (ctx, src0, src1, dst);
1999
2054
} else if (use_mul_mat_vec) {
0 commit comments