musa: enable MMA · ggml-org/llama.cpp@d91bdb3

Commit d91bdb3

Author: ZhouYu (committed)

musa: enable MMA

1 parent 27aa259 · commit d91bdb3

File tree

3 files changed: +53 -16 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 1 deletion

@@ -215,6 +215,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
+#if defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_MUSA) && !GGML_CUDA_MUSA_ARCH_IS_QY1
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -237,7 +241,7 @@ static bool fast_fp16_available(const int cc) {
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return cc >= GGML_CUDA_CC_PASCAL && cc != 610 && cc != GGML_CUDA_CC_QY1;
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -246,13 +250,15 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
          GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }

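The two layers above work together: the new preprocessor gate decides whether the MMA code paths are compiled at all for a MUSA target, while fp16_mma_available()/fp16_mma_hardware_available() decide at run time whether a given device should use them. A minimal sketch of that pattern follows (simplified, hypothetical names; this is not the actual ggml dispatch and only presumes the helpers from ggml/src/ggml-cuda/common.cuh shown above are on the include path):

#include "common.cuh"  // assumed available: FP16_MMA_AVAILABLE, fp16_mma_available(), GGML_CUDA_CC_*

// Hypothetical kernel: the MMA-specific body is only compiled where the macro is defined.
__global__ void attn_kernel_sketch() {
#ifdef FP16_MMA_AVAILABLE
    // tensor-core (WMMA/MMA) path; with this commit it is also compiled for
    // MUSA targets that are not QY1
#else
    // scalar/vector fallback for architectures without FP16 MMA
#endif // FP16_MMA_AVAILABLE
}

// Hypothetical host-side selection for a device with compute capability cc.
static bool use_mma_path(const int cc) {
    // the new MTHREADS clause makes this return true for QY2 and newer MUSA GPUs
    return fp16_mma_available(cc);
}
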
ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions

@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>

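The alias matters because the kernel bodies in this file are written against the wmma:: interface and never name the vendor namespace directly. Below is a minimal sketch (not a ggml kernel) of the kind of code this keeps portable; it assumes mtmusa::wmma mirrors the nvcuda::wmma API, which is the premise of the hunk above:

#include <cuda_fp16.h>
#include <mma.h>

#ifdef GGML_USE_MUSA
namespace wmma = mtmusa::wmma;   // MUSA's WMMA implementation
#else
namespace wmma = nvcuda::wmma;   // CUDA's WMMA implementation
#endif

// One warp computes a single 16x16x16 FP16 matrix multiply-accumulate.
__global__ void mma_16x16x16_sketch(const half * a, const half * b, float * c) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, a, 16);   // leading dimension 16
    wmma::load_matrix_sync(b_frag, b, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}

Launched with a single warp, e.g. mma_16x16x16_sketch<<<1, 32>>>(a, b, c). Because only the namespace alias differs, the rest of fattn-wmma-f16.cu compiles unchanged for both backends.
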
ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 42 additions & 15 deletions

@@ -1865,13 +1865,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+        void ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -1881,15 +1892,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        cudaDeviceSynchronize();
+        const void **Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void **Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void **Barray = (const void **) (ptrs_src.get() + 1*ne23);
+              void **Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
     }
 #endif

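cublasGemmBatchedEx consumes arrays of device pointers (Aarray/Barray/Carray) that must themselves live in device memory, which is why ptrs_src/ptrs_dst are device allocations filled by k_compute_batched_ptrs; the MUSA branch simply switches from the ggml pool allocator to plain cudaMalloc/cudaFree and synchronizes before the call. A self-contained sketch of the same pointer-array pattern outside ggml follows (sizes, the FP16-in/FP32-out setup, and the CUDA 11+ cublasComputeType_t signature are illustrative assumptions; error checking omitted):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

int main() {
    const int m = 16, n = 16, k = 16, batch = 4;

    cublasHandle_t handle;
    cublasCreate(&handle);

    // One flat device allocation per operand; each batch entry is a contiguous matrix.
    half  *A, *B;
    float *C;
    cudaMalloc(&A, sizeof(half)  * m * k * batch);
    cudaMalloc(&B, sizeof(half)  * k * n * batch);
    cudaMalloc(&C, sizeof(float) * m * n * batch);

    // Build the per-batch pointer tables on the host ...
    std::vector<const void *> hA(batch), hB(batch);
    std::vector<void *>       hC(batch);
    for (int i = 0; i < batch; ++i) {
        hA[i] = A + (size_t) i * m * k;
        hB[i] = B + (size_t) i * k * n;
        hC[i] = C + (size_t) i * m * n;
    }

    // ... and copy them to device memory: the GEMM reads the pointer arrays on the GPU,
    // mirroring the role of ptrs_src/ptrs_dst in the hunk above.
    const void **dA; const void **dB; void **dC;
    cudaMalloc((void **) &dA, sizeof(void *) * batch);
    cudaMalloc((void **) &dB, sizeof(void *) * batch);
    cudaMalloc((void **) &dC, sizeof(void *) * batch);
    cudaMemcpy(dA, hA.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice);
    cudaMemcpy(dC, hC.data(), sizeof(void *) * batch, cudaMemcpyHostToDevice);

    // FP16 inputs, FP32 accumulation/output: all batch GEMMs in a single call.
    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        m, n, k,
                        &alpha, dA, CUDA_R_16F, m,
                                dB, CUDA_R_16F, k,
                        &beta,  dC, CUDA_R_32F, m,
                        batch,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    cudaDeviceSynchronize();

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    cudaFree(A);  cudaFree(B);  cudaFree(C);
    cublasDestroy(handle);
    return 0;
}
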
@@ -2989,12 +3016,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                 return false;
             }
-#ifdef GGML_USE_MUSA
-            if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
-                !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                return false;
-            }
-#endif // GGML_USE_MUSA
+            // #ifdef GGML_USE_MUSA
+            // if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+            //     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+            //     return false;
+            // }
+            // #endif // GGML_USE_MUSA
             switch (a->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
@@ -3019,11 +3046,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_TYPE_IQ4_NL:
                 case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                    if (a->type == GGML_TYPE_Q3_K) {
-                        return false;
-                    }
-#endif // GGML_USE_MUSA
+                    // #ifdef GGML_USE_MUSA
+                    // if (a->type == GGML_TYPE_Q3_K) {
+                    //     return false;
+                    // }
+                    // #endif // GGML_USE_MUSA
                     return true;
                 default:
                     return false;
