8000 musa: Upgrade MUSA SDK version to rc4.0.1 and use mudnn::Unary::IDENT… · ochafik/llama.cpp@3398305 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 3398305

Browse files
musa: Upgrade MUSA SDK version to rc4.0.1 and use mudnn::Unary::IDENTITY op to accelerate D2D memory copy (ggml-org#13647)
* musa: fix build warning (unused parameter) Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * musa: upgrade MUSA SDK version to rc4.0.1 Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * musa: use mudnn::Unary::IDENTITY op to accelerate D2D memory copy Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Update ggml/src/ggml-cuda/cpy.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * musa: remove MUDNN_CHECK_GEN and use CUDA_CHECK_GEN instead in MUDNN_CHECK Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent fb1cab2 commit 3398305

File tree

10 files changed

+153
-20
lines changed

10 files changed

+153
-20
lines changed

.devops/musa.Dockerfile

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
ARG UBUNTU_VERSION=22.04
22
# This needs to generally match the container host's environment.
3-
ARG MUSA_VERSION=rc3.1.1
3+
ARG MUSA_VERSION=rc4.0.1
44
# Target the MUSA build image
5-
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
5+
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
66

7-
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
7+
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
88

99
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
1010

@@ -21,21 +21,14 @@ RUN apt-get update && \
2121
libcurl4-openssl-dev \
2222
libgomp1
2323

24-
COPY requirements.txt requirements.txt
25-
COPY requirements requirements
26-
27-
RUN pip install --upgrade pip setuptools wheel \
28-
&& pip install -r requirements.txt
29-
3024
WORKDIR /app
3125

3226
COPY . .
3327

34-
# Use the default MUSA archs if not specified
3528
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
3629
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
3730
fi && \
38-
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
31+
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
3932
cmake --build build --config Release -j$(nproc)
4033

4134
RUN mkdir -p /app/lib && \

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ jobs:
351351
352352
ubuntu-22-cmake-musa:
353353
runs-on: ubuntu-22.04
354-
container: mthreads/musa:rc3.1.1-devel-ubuntu22.04
354+
container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
355355

356356
steps:
357357
- name: Clone

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ range of hardware - locally and in the cloud.
3737
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
3838
- AVX, AVX2, AVX512 and AMX support for x86 architectures
3939
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
40-
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA)
40+
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
4141
- Vulkan and SYCL backend support
4242
- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity
4343

@@ -237,7 +237,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
237237
| [BLAS](docs/build.md#blas-build) | All |
238238
| [BLIS](docs/backend/BLIS.md) | All |
239239
| [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU |
240-
| [MUSA](docs/build.md#musa) | Moore Threads MTT GPU |
240+
| [MUSA](docs/build.md#musa) | Moore Threads GPU |
241241
| [CUDA](docs/build.md#cuda) | Nvidia GPU |
242242
| [HIP](docs/build.md#hip) | AMD GPU |
243243
| [Vulkan](docs/build.md#vulkan) | GPU |

ci/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ docker run --privileged -it \
5454
-v $HOME/llama.cpp/ci-cache:/ci-cache \
5555
-v $HOME/llama.cpp/ci-results:/ci-results \
5656
-v $PWD:/ws -w /ws \
57-
mthreads/musa:rc3.1.1-devel-ubuntu22.04
57+
mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04
5858
```
5959

6060
Inside the container, execute the following commands:

docs/docker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment
107107

108108
The defaults are:
109109

110-
- `MUSA_VERSION` set to `rc3.1.1`
110+
- `MUSA_VERSION` set to `rc4.0.1`
111111

112112
The resulting images, are essentially the same as the non-MUSA images:
113113

ggml/src/ggml-cuda/cpy.cu

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "cpy.cuh"
22
#include "dequantize.cuh"
3+
#ifdef GGML_USE_MUSA
4+
#include "ggml-musa/mudnn.cuh"
5+
#endif // GGML_USE_MUSA
36

47
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
58

@@ -597,7 +600,14 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
597600
#endif
598601
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
599602
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
600-
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
603+
#ifdef GGML_USE_MUSA
604+
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
605+
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
606+
} else
607+
#endif // GGML_USE_MUSA
608+
{
609+
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
610+
}
601611
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
602612
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
603613
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
772772
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
773773
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
774774
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
775-
GGML_UNUSED(kb0);
775+
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
776776
NO_DEVICE_CODE;
777777
#endif // NEW_MMA_AVAILABLE
778778
}

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
2727

2828
file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
2929
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
30+
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
3031

3132
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
3233
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
3334
list(APPEND GGML_SOURCES_MUSA ${SRCS})
3435
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
3536
list(APPEND GGML_SOURCES_MUSA ${SRCS})
37+
file(GLOB SRCS "../ggml-musa/*.cu")
38+
list(APPEND GGML_SOURCES_MUSA ${SRCS})
3639

3740
if (GGML_CUDA_FA_ALL_QUANTS)
3841
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND)
6265
)
6366

6467
# TODO: do not use CUDA definitions for MUSA
65-
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
68+
if (NOT GGML_BACKEND_DL)
69+
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
70+
endif()
6671

6772
add_compile_definitions(GGML_USE_MUSA)
6873
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
@@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND)
9297
endif()
9398

9499
if (GGML_STATIC)
100+
# TODO: mudnn has not provided static libraries yet
95101
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
96102
else()
97-
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
103+
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
98104
endif() F987
99105

100106
if (GGML_CUDA_NO_VMM)

ggml/src/ggml-musa/mudnn.cu

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#include <mutex>
2+
#include <mudnn.h>
3+
4+
#include "mudnn.cuh"
5+
6+
namespace mudnn = musa::dnn;
7+
8+
// Returns a human-readable error string for mudnn::Status
9+
const char* mudnnGetErrorString(mudnn::Status err) {
10+
switch (err) {
11+
case mudnn::Status::SUCCESS:
12+
return "Success";
13+
case mudnn::Status::INVALID_PARAMETER:
14+
return "Invalid parameter";
15+
case mudnn::Status::NOT_INITIALIZED:
16+
return "Not initialized";
17+
case mudnn::Status::ALLOC_FAILED:
18+
return "Allocation failed";
19+
case mudnn::Status::NOT_SUPPORTED:
20+
return "Not supported";
21+
case mudnn::Status::INTERNAL_ERROR:
22+
return "Internal error";
23+
case mudnn::Status::ARCH_MISMATCH:
24+
return "Architecture mismatch";
25+
case mudnn::Status::EXECUTION_FAILED:
26+
return "Execution failed";
27+
default:
28+
return "Unknown mudnn status";
29+
}
30+
}
31+
32+
// Error checking macro for MUDNN calls
33+
#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
34+
35+
namespace {
36+
// Thread-safe cache for mudnn::Handle objects per device
37+
std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
38+
std::mutex handle_cache_mutex;
39+
40+
mudnn::Handle* get_cached_handle(int device_id) {
41+
std::lock_guard<std::mutex> lock(handle_cache_mutex);
42+
auto it = handle_cache.find(device_id);
43+
if (it != handle_cache.end()) {
44+
return it->second.get();
45+
}
46+
auto handle = std::make_unique<mudnn::Handle>(device_id);
47+
mudnn::Handle* handle_ptr = handle.get();
48+
handle_cache[device_id] = std::move(handle);
49+
return handle_ptr;
50+
}
51+
}
52+
53+
// Extracts dimensions and strides from a ggml_tensor
54+
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
55+
std::vector<int64_t>& dims,
56+
std::vector<int64_t>& strides) {
57+
const int ndims = ggml_n_dims(tensor);
58+
const size_t element_size = ggml_element_size(tensor);
59+
60+
dims.resize(ndims);
61+
strides.resize(ndims);
62+
63+
for (int i = 0; i < ndims; ++i) {
64+
dims[i] = tensor->ne[i];
65+
strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
66+
}
67+
return ndims;
68+
}
69+
70+
// Converts ggml_type to mudnn::Tensor::Type
71+
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
72+
switch (type) {
73+
case GGML_TYPE_F32:
74+
return mudnn::Tensor::Type::FLOAT;
75+
case GGML_TYPE_F16:
76+
return mudnn::Tensor::Type::HALF;
77+
78+
// TODO: Add support for other types
79+
80+
default:
81+
MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
82+
}
83+
84+
return mudnn::Tensor::Type::FLOAT; // Default fallback
85+
}
86+
87+
// Asynchronous memory copy using mudnn::Unary::IDENTITY
88+
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
89+
mudnn::Tensor tensor_dst, tensor_src;
90+
91+
MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
92+
MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
93+
94+
std::vector<int64_t> dims, strides;
95+
const int ndims = get_ggml_dims_and_strides(src, dims, strides);
96+
97+
MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
98+
MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
99+
MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
100+
MUDNN_CHECK(tensor_src.SetAddr(src->data));
101+
102+
mudnn::Unary op;
103+
MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
104+
MUDNN_CHECK(op.SetAlpha(0.0f));
105+
MUDNN_CHECK(op.SetBeta(0.0f));
106+
107+
mudnn::Handle* handle = get_cached_handle(ctx.device);
108+
MUDNN_CHECK(handle->SetStream(ctx.stream()));
109+
MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
110+
111+
return musaSuccess;
112+
}

ggml/src/ggml-musa/mudnn.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include "../include/ggml.h"
4+
#include "../ggml-cuda/common.cuh"
5+
6+
// Asynchronously copies data from src tensor to dst tensor using the provided context.
7+
// Returns a musaError_t indicating success or failure.
8+
musaError_t mudnnMemcpyAsync(
9+
ggml_backend_cuda_context &ctx,
10+
const ggml_tensor *dst,
11+
const ggml_tensor *src
12+
);

0 commit comments

Comments
 (0)
0