@@ -1200,7 +1200,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+            (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) &&
+            src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1230,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+            (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
+            GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
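
Note: both hunks above gate the cuBLAS BF16/FP16 paths on the device vendor and compute capability (the FP16 branch additionally requires Volta or newer on NVIDIA). The standalone sketch below is illustration only, not part of the patch: it mirrors the BF16 branch's vendor gate using stand-in constants and helper functions, since the real `GGML_CUDA_CC_*` values and macros live in ggml-cuda/common.cuh and differ numerically.

```cpp
// Standalone sketch of the capability gate added above. All numeric values and
// helper names here are stand-ins for illustration; they are NOT the real
// definitions from ggml-cuda/common.cuh.
#include <cstdio>

enum : int {
    CC_OFFSET_MTHREADS = 0x0100000,          // stand-in vendor offsets
    CC_OFFSET_AMD      = 0x1000000,
    CC_QY1             = CC_OFFSET_MTHREADS + 210,
    CC_QY2             = CC_OFFSET_MTHREADS + 220,
};

static bool is_nvidia(int cc)   { return cc < CC_OFFSET_MTHREADS; }
static bool is_mthreads(int cc) { return cc >= CC_OFFSET_MTHREADS && cc < CC_OFFSET_AMD; }
static bool is_amd(int cc)      { return cc >= CC_OFFSET_AMD; }

// Mirrors the new condition: the BF16 cuBLAS GEMM path is taken on NVIDIA,
// AMD, or MooreThreads QY2 and newer.
static bool bf16_gemm_gate(int cc) {
    return is_nvidia(cc) || is_amd(cc) || (is_mthreads(cc) && cc >= CC_QY2);
}

int main() {
    printf("QY1: %d, QY2: %d\n", bf16_gemm_gate(CC_QY1), bf16_gemm_gate(CC_QY2)); // 0, 1
    return 0;
}
```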
@@ -1872,13 +1876,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+              void ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -1888,15 +1903,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaDeviceSynchronize());
+        const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
     }
 #endif
 
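
Note: the MUSA-only branches in the two hunks above swap the pool-allocated pointer tables for plain cudaMalloc/cudaFree and add a cudaDeviceSynchronize() before the batched GEMM. The self-contained FP32 sketch below is illustration only (not from this patch); it shows the underlying cuBLAS contract those branches have to satisfy: cublasGemmBatchedEx reads the A/B/C pointer tables from device memory, so they must be fully populated on the GPU before the call. Error checking is omitted for brevity.

```cpp
// Minimal batched-GEMM sketch: 3 independent 2x2 products C_i = A_i * B_i.
// A is the identity, so each C_i should come back equal to B.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    const int n = 2, batch = 3;
    const float hA[4] = {1, 0, 0, 1};   // identity (column-major)
    const float hB[4] = {1, 2, 3, 4};

    // per-matrix device buffers
    std::vector<float *> dA(batch), dB(batch), dC(batch);
    for (int i = 0; i < batch; ++i) {
        cudaMalloc(&dA[i], 4*sizeof(float)); cudaMemcpy(dA[i], hA, 4*sizeof(float), cudaMemcpyHostToDevice);
        cudaMalloc(&dB[i], 4*sizeof(float)); cudaMemcpy(dB[i], hB, 4*sizeof(float), cudaMemcpyHostToDevice);
        cudaMalloc(&dC[i], 4*sizeof(float));
    }

    // the pointer tables themselves must be device-resident
    const void **Aarray; const void **Barray; void **Carray;
    cudaMalloc((void **)&Aarray, batch*sizeof(void *));
    cudaMalloc((void **)&Barray, batch*sizeof(void *));
    cudaMalloc((void **)&Carray, batch*sizeof(void *));
    cudaMemcpy(Aarray, dA.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(Barray, dB.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(Carray, dC.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        n, n, n,
                        &alpha, Aarray, CUDA_R_32F, n,
                                Barray, CUDA_R_32F, n,
                        &beta,  Carray, CUDA_R_32F, n,
                        batch,
                        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

    float hC[4];
    cudaMemcpy(hC, dC[0], 4*sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0] = %.0f %.0f %.0f %.0f\n", hC[0], hC[1], hC[2], hC[3]); // 1 2 3 4
    cublasDestroy(handle);
    return 0;
}
```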
@@ -1926,6 +1957,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool any_gpus_with_slow_fp16   = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_cublas_gemm = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +1968,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc              = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q             = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            const int cc                 = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q                = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16      = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma    = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
-        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q             = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        const int cc                 = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q                = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16      = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma    = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
@@ -1964,8 +1998,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-               !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && !any_gpus_without_cublas_gemm && src0->type == GGML_TYPE_F16 &&
+               (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+               !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
@@ -3005,9 +3040,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
 #ifdef GGML_USE_MUSA
-                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (GGML_CUDA_CC_IS_MTHREADS(cc) && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                    return false;
+                    if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+                        a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+                        return false;
+                    }
+                    if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                        a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                        return false;
+                    }
                 }
 #endif // GGML_USE_MUSA
                 switch (a->type) {
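
Note: the hunk above narrows the old blanket MUSA rejection of batched F16 mul_mat to two specific cases: F16 x F16 GGML_OP_MUL_MAT on QY1, and Q2_K x F32 GGML_OP_MUL_MAT_ID on QY2. The host-only sketch below is illustration only, not part of the patch; it restates the remaining rejections with stand-in enums so they can be eyeballed and unit-tested (the batch-size > 1 and non-transposed preconditions from the real code are elided).

```cpp
// Stand-in restatement of the MUSA supports_op rules introduced above.
#include <cassert>

enum class Arch { QY1, QY2 };
enum class Op   { MUL_MAT, MUL_MAT_ID };
enum class Type { F16, F32, Q2_K, OTHER };

// returns false only for the two combinations the patch still rejects on MUSA
static bool musa_batched_mul_mat_supported(Arch arch, Op op, Type a, Type b) {
    if (arch == Arch::QY1 && op == Op::MUL_MAT    && a == Type::F16  && b == Type::F16) return false;
    if (arch == Arch::QY2 && op == Op::MUL_MAT_ID && a == Type::Q2_K && b == Type::F32) return false;
    return true;
}

int main() {
    assert(!musa_batched_mul_mat_supported(Arch::QY1, Op::MUL_MAT,    Type::F16,  Type::F16));
    assert(!musa_batched_mul_mat_supported(Arch::QY2, Op::MUL_MAT_ID, Type::Q2_K, Type::F32));
    assert( musa_batched_mul_mat_supported(Arch::QY2, Op::MUL_MAT,    Type::F16,  Type::F16)); // now allowed
    return 0;
}
```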
@@ -3034,11 +3077,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                        if (a->type == GGML_TYPE_Q3_K) {
-                            return false;
-                        }
-#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;