musa: enable muBlas and MMA · ggml-org/llama.cpp@f7ef983
Commit f7ef983

musa: enable muBlas and MMA
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent 3398305 commit f7ef983

3 files changed, +85 -36 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 20 additions & 13 deletions
@@ -76,17 +76,17 @@
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
-#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
-
-#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
-#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
 
 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
 #define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
 #define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
 
+#define GGML_CUDA_CC_IS_QY1_OR_EARLIER (__MUSA_ARCH__ < 220)
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
@@ -199,9 +199,9 @@ typedef float2 dfloat2;
 #define FP16_AVAILABLE
 #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 
-#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != GGML_CUDA_CC_DP4A
 #define FAST_FP16_AVAILABLE
-#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != GGML_CUDA_CC_DP4A
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #define FP16_MMA_AVAILABLE
@@ -211,6 +211,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
+#if defined(GGML_USE_MUSA)
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_MUSA)
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -219,21 +223,22 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_CC_IS_QY1_OR_EARLIER)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_CC_IS_QY1_OR_EARLIER)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
 }
 
 static bool fast_fp16_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != GGML_CUDA_CC_DP4A) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) ||
+           GGML_CUDA_CC_IS_AMD(cc) || GGML_CUDA_CC_IS_MTHREADS(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -242,14 +247,16 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+           GGML_CUDA_CC_IS_MTHREADS(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+           GGML_CUDA_CC_IS_MTHREADS(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
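For reference, the Moore Threads generations are encoded as plain integers above a vendor offset, so the same cc scalar that carries NVIDIA and AMD capabilities also classifies MUSA devices. Below is a minimal standalone C++ sketch of how the range checks behave; the two offset constants are assumptions for illustration, since their definitions sit outside this diff:

    #include <cstdio>

    // Assumed vendor offsets (not shown in this diff) -- for illustration only.
    #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
    #define GGML_CUDA_CC_OFFSET_AMD      0x1000000

    #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
    #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
    #define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD

    #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
    #define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
    #define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)

    int main() {
        const int cc = GGML_CUDA_CC_OFFSET_MTHREADS + 0x220; // e.g. an MTT S4000
        printf("MTHREADS=%d QY1=%d QY2=%d\n",
               GGML_CUDA_CC_IS_MTHREADS(cc), GGML_CUDA_CC_IS_QY1(cc), GGML_CUDA_CC_IS_QY2(cc));
        return 0; // prints: MTHREADS=1 QY1=0 QY2=1
    }

Because the ranges are half-open, a QY2 card is never also classified as QY1, which lets the predicates above include or exclude whole generations with a single comparison.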

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions
@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
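The alias keeps every WMMA call site backend-agnostic: the flash-attention kernel names only wmma::, so retargeting it to MUSA is the one-line alias above. A standalone sketch of the pattern; the stand-in namespaces and the backend() helper are hypothetical, only the alias lines mirror the hunk:

    #include <cstdio>

    namespace nvcuda { namespace wmma { inline const char * backend() { return "nvcuda::wmma"; } } }
    namespace mtmusa { namespace wmma { inline const char * backend() { return "mtmusa::wmma"; } } }

    #ifdef GGML_USE_MUSA
    namespace wmma = mtmusa::wmma;   // MUSA's wmma namespace, as in the hunk above
    #else
    namespace wmma = nvcuda::wmma;   // CUDA's wmma namespace from <mma.h>
    #endif

    int main() {
        // everything below the alias compiles unchanged against either backend
        printf("using %s\n", wmma::backend());
        return 0;
    }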

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 61 additions & 23 deletions
@@ -1200,7 +1200,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) &&
+        src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1230,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+                (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2) ||
+                GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
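Both hunks splice in the same device gate: the muBLAS BF16 and FP16 GEMM paths are taken on Moore Threads hardware only from QY2 (MTT S4000) onward, while QY1 (MTT S80/S3000) keeps falling through to the FP32 fallback. Factored out as a standalone predicate, a hypothetical helper that is not part of the commit, reusing the illustrative constants from the common.cuh sketch above:

    // Hypothetical helper for illustration: the condition both hunks add to
    // ggml_cuda_op_mul_mat_cublas(). QY1 fails the comparison and keeps the
    // FP32 path; QY2 and newer qualify for the muBLAS GEMM paths.
    static bool mthreads_mublas_gemm_available(const int cc) {
        return GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2;
    }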
@@ -1872,13 +1876,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     // use cublasGemmBatchedEx
     const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+    const void ** ptrs_src;
+    void ** ptrs_dst;
+    CUDA_CHECK(cudaMalloc((void **) &ptrs_src, sizeof(void *)*2*ne23));
+    CUDA_CHECK(cudaMalloc((void **) &ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
     ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
     ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
     dim3 block_dims(ne13, ne12);
     k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
         src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+        ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
         ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
         ne12, ne13,
         ne23,
         nb02, nb03,
@@ -1888,15 +1903,31 @@
         r2, r3);
     CUDA_CHECK(cudaGetLastError());
 
-    CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+    CUDA_CHECK(cudaDeviceSynchronize());
+    const void ** Aarray = (const void **) (ptrs_src + 0*ne23);
+    const void ** Barray = (const void **) (ptrs_src + 1*ne23);
+          void ** Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+    const void ** Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+    const void ** Barray = (const void **) (ptrs_src.get() + 1*ne23);
+          void ** Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+    CUBLAS_CHECK(
     cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
             ne01, ne11, ne10,
-            alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
-                   (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
-            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+            alpha, Aarray, CUDA_R_16F, nb01/nb00,
+                   Barray, CUDA_R_16F, s11,
+            beta,  Carray, cu_data_type, ne0,
             ne23,
             cu_compute_type,
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+    CUDA_CHECK(cudaFree(ptrs_src));
+    CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
 }
 #endif
 
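In the MUSA branch the per-batch pointer tables come from plain cudaMalloc/cudaFree instead of the ggml pool allocator, and a cudaDeviceSynchronize() runs before the muBLAS call, presumably so the tables are fully written before the library consumes them. A self-contained CUDA sketch of the overall pattern; names, shapes and strides are illustrative, error checking is omitted, and only the structure mirrors the code above:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Fill device-side pointer tables: entry i points at batch i of each matrix.
    __global__ void compute_batched_ptrs(const half * A, const half * B, half * C,
                                         const void ** ptrs_src, void ** ptrs_dst,
                                         int64_t strideA, int64_t strideB, int64_t strideC,
                                         int64_t ne23) {
        const int64_t i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= ne23) {
            return;
        }
        ptrs_src[0*ne23 + i] = A + i*strideA;
        ptrs_src[1*ne23 + i] = B + i*strideB;
        ptrs_dst[0*ne23 + i] = C + i*strideC;
    }

    // C_i = A_i^T * B_i for ne23 batches, all FP16.
    void gemm_batched_f16(cublasHandle_t handle, cudaStream_t stream,
                          const half * A, const half * B, half * C,
                          int m, int n, int k, int ne23) {
        const void ** ptrs_src;
        void ** ptrs_dst;
        cudaMalloc((void **) &ptrs_src, sizeof(void *)*2*ne23);
        cudaMalloc((void **) &ptrs_dst, sizeof(void *)*1*ne23);

        const int64_t sA = (int64_t) m*k, sB = (int64_t) n*k, sC = (int64_t) m*n;
        compute_batched_ptrs<<<(ne23 + 255)/256, 256, 0, stream>>>(
            A, B, C, ptrs_src, ptrs_dst, sA, sB, sC, ne23);
        cudaDeviceSynchronize(); // as in the MUSA branch: tables must be complete first

        const half alpha = __float2half(1.0f);
        const half beta  = __float2half(0.0f);
        cublasGemmBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                            &alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, k,
                                    (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, k,
                            &beta,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, m,
                            ne23, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

        cudaFree(ptrs_src);
        cudaFree(ptrs_dst);
    }

On the non-MUSA side the pool allocation and stream ordering make the explicit synchronize and frees unnecessary, which is why the #ifdef only changes the MUSA branch.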
@@ -1926,6 +1957,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool any_gpus_with_slow_fp16 = false;
     bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_without_cublas_gemm = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +1968,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            const int cc = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
         }
     } else {
-        const int cc = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        const int cc = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        any_gpus_without_cublas_gemm = any_gpus_without_cublas_gemm || !(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     }
 
     // debug helpers
@@ -1964,8 +1998,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-               !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && !any_gpus_without_cublas_gemm && src0->type == GGML_TYPE_F16 &&
+               (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+               !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
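The dispatch now folds a third per-device flag, and the batched-cuBLAS branch is taken only when no participating device lacks a usable GEMM backend. The accumulation is a simple fold over device compute capabilities; a generic standalone sketch, where the supported predicate stands in for the concrete MTHREADS/QY2 test above:

    #include <functional>
    #include <vector>

    // Returns true if at least one device fails the capability test -- the same
    // "any_gpus_without_*" shape used above. device_ccs stands in for the
    // per-device cc values from ggml_cuda_info().
    static bool any_gpu_without(const std::vector<int> & device_ccs,
                                const std::function<bool(int)> & supported) {
        bool any_without = false;
        for (const int cc : device_ccs) {
            any_without = any_without || !supported(cc);
        }
        return any_without;
    }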
@@ -3005,9 +3040,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
 #ifdef GGML_USE_MUSA
-            if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+            if (GGML_CUDA_CC_IS_MTHREADS(cc) && b->ne[2]*b->ne[3] > 1 &&
                 !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                return false;
+                if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+                    a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+                    return false;
+                }
+                if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+                    a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+                    return false;
+                }
             }
 #endif // GGML_USE_MUSA
             switch (a->type) {
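Rather than rejecting every batched F16 multiplication on MUSA, supports_op now blacklists only two known-bad combinations, and only under the batch and transpose preconditions tested above. A standalone sketch of that blacklist in table form; the enums are hypothetical stand-ins for the ggml types and ops:

    #include <cstdio>

    enum Gen { QY1, QY2 };            // Moore Threads generations, as classified above
    enum Op  { MUL_MAT, MUL_MAT_ID }; // stand-ins for GGML_OP_*
    enum Ty  { F16, F32, Q2_K };      // stand-ins for GGML_TYPE_*

    struct BadCase { Gen gen; Op op; Ty a; Ty b; };

    // Table form of the two `return false` branches in the hunk above.
    static const BadCase musa_unsupported[] = {
        { QY1, MUL_MAT,    F16,  F16 }, // batched F16 x F16 mul_mat still fails on QY1
        { QY2, MUL_MAT_ID, Q2_K, F32 }, // Q2_K mul_mat_id fails on QY2
    };

    static bool musa_supports(Gen gen, Op op, Ty a, Ty b) {
        for (const BadCase & c : musa_unsupported) {
            if (c.gen == gen && c.op == op && c.a == a && c.b == b) {
                return false;
            }
        }
        return true;
    }

    int main() {
        printf("%d %d\n", musa_supports(QY1, MUL_MAT, F16, F16),  // 0
                          musa_supports(QY2, MUL_MAT, F16, F16)); // 1
        return 0;
    }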
@@ -3034,11 +3077,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             case GGML_TYPE_IQ4_NL:
             case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                if (a->type == GGML_TYPE_Q3_K) {
-                    return false;
-                }
-#endif // GGML_USE_MUSA
                 return true;
             default:
                 return false;
