musa: use mudnn::Unary::IDENTITY op to accelerate D2D memory copy · makllama/llama.cpp@108017c · GitHub
Commit 108017c

musa: use mudnn::Unary::IDENTITY op to accelerate D2D memory copy
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent 4a62030 commit 108017c

File tree: 4 files changed, +146 −1 lines changed

ggml/src/ggml-cuda/cpy.cu
ggml/src/ggml-musa/CMakeLists.txt
ggml/src/ggml-musa/mudnn.cu
ggml/src/ggml-musa/mudnn.cuh
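For context (an editorial note, not part of the commit): on the MUSA backend this change routes contiguous F32/F16 device-to-device tensor copies through muDNN's Unary::IDENTITY operator instead of a plain musaMemcpyAsync. An identity unary op is conceptually just an elementwise copy; a rough CUDA-style sketch of the idea, with a hypothetical kernel name:

    // Conceptual sketch only (hypothetical, not from this commit): an identity
    // unary op is an elementwise kernel that writes y[i] = x[i]. muDNN's tuned
    // implementation of this pattern is what the commit substitutes for a
    // plain device-to-device memcpy on F32/F16 tensors.
    __global__ void identity_copy_f32(const float * x, float * y, int64_t n) {
        const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = x[i];
        }
    }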

ggml/src/ggml-cuda/cpy.cu (11 additions, 0 deletions)

@@ -1,5 +1,8 @@
 #include "cpy.cuh"
 #include "dequantize.cuh"
+#ifdef GGML_USE_MUSA
+#include "ggml-musa/mudnn.cuh"
+#endif // GGML_USE_MUSA
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -597,7 +600,15 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 #endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+#ifdef GGML_USE_MUSA
+        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
+            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
+        } else {
+            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+        }
+#else // GGML_USE_MUSA
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+#endif // GGML_USE_MUSA
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
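The MUSA-only branch above defers to muDNN only when the tensor type is one that mudnn.cu (added below) can describe; every other type keeps the existing cudaMemcpyAsync path, which the MUSA build maps onto MUSA's runtime. Expressed as a hypothetical helper (not in the commit), the dispatch condition is simply:

    // Hypothetical helper, for illustration only: muDNN handles the copy when
    // ggml_type_to_mudnn_type (see mudnn.cu below) has a mapping for the type.
    static bool mudnn_copy_supported(ggml_type type) {
        return type == GGML_TYPE_F32 || type == GGML_TYPE_F16;
    }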

ggml/src/ggml-musa/CMakeLists.txt (5 additions, 1 deletion)

@@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND)
 
     file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
+    list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
 
     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
+    file(GLOB   SRCS "../ggml-musa/*.cu")
+    list(APPEND GGML_SOURCES_MUSA ${SRCS})
 
     if (GGML_CUDA_FA_ALL_QUANTS)
         file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -92,9 +95,10 @@ if (MUSAToolkit_FOUND)
     endif()
 
     if (GGML_STATIC)
+        # TODO: mudnn has not provided static libraries yet
         target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
     else()
-        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
     endif()
 
     if (GGML_CUDA_NO_VMM)
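Worth noting from this hunk: the new file(GLOB "../ggml-musa/*.cu") is what pulls mudnn.cu (below) into the build, and only the shared-library branch links the mudnn library; the static branch is left unchanged because, per the TODO, mudnn does not yet ship static libraries.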

ggml/src/ggml-musa/mudnn.cu (new file, 118 additions)

#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

#include <mudnn.h>

#include "mudnn.cuh"

namespace mudnn = musa::dnn;

// Returns a human-readable error string for mudnn::Status
const char* mudnnGetErrorString(mudnn::Status err) {
    switch (err) {
        case mudnn::Status::SUCCESS:
            return "Success";
        case mudnn::Status::INVALID_PARAMETER:
            return "Invalid parameter";
        case mudnn::Status::NOT_INITIALIZED:
            return "Not initialized";
        case mudnn::Status::ALLOC_FAILED:
            return "Allocation failed";
        case mudnn::Status::NOT_SUPPORTED:
            return "Not supported";
        case mudnn::Status::INTERNAL_ERROR:
            return "Internal error";
        case mudnn::Status::ARCH_MISMATCH:
            return "Architecture mismatch";
        case mudnn::Status::EXECUTION_FAILED:
            return "Execution failed";
        default:
            return "Unknown mudnn status";
    }
}

// Error checking macro for MUDNN calls
#define MUDNN_CHECK_GEN(err, success, error_fn)                                      \
    do {                                                                             \
        auto err_ = (err);                                                           \
        if (err_ != (success)) {                                                     \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));     \
        }                                                                            \
    } while (0)

#define MUDNN_CHECK(err) MUDNN_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)

namespace {
    // Thread-safe cache of one mudnn::Handle per device
    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
    std::mutex handle_cache_mutex;

    mudnn::Handle* get_cached_handle(int device_id) {
        std::lock_guard<std::mutex> lock(handle_cache_mutex);
        auto it = handle_cache.find(device_id);
        if (it != handle_cache.end()) {
            return it->second.get();
        }
        auto handle = std::make_unique<mudnn::Handle>(device_id);
        mudnn::Handle* handle_ptr = handle.get();
        handle_cache[device_id] = std::move(handle);
        return handle_ptr;
    }
}

// Extracts dimensions and element strides from a ggml_tensor
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
                              std::vector<int64_t>& dims,
                              std::vector<int64_t>& strides) {
    const int ndims = ggml_n_dims(tensor);
    const size_t element_size = ggml_element_size(tensor);

    dims.resize(ndims);
    strides.resize(ndims);

    for (int i = 0; i < ndims; ++i) {
        dims[i]    = tensor->ne[i];
        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
    }
    return ndims;
}

// Converts ggml_type to mudnn::Tensor::Type
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return mudnn::Tensor::Type::FLOAT;
        case GGML_TYPE_F16:
            return mudnn::Tensor::Type::HALF;

        // TODO: Add support for other types
        default:
            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
    }
    return mudnn::Tensor::Type::FLOAT; // unreachable: MUDNN_CHECK aborts above
}

// Asynchronous memory copy using mudnn::Unary::IDENTITY
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
    mudnn::Tensor tensor_dst, tensor_src;

    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));

    std::vector<int64_t> dims, strides;
    const int ndims = get_ggml_dims_and_strides(src, dims, strides);

    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
    MUDNN_CHECK(tensor_src.SetAddr(src->data));

    mudnn::Unary op;
    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
    MUDNN_CHECK(op.SetAlpha(0.0f));
    MUDNN_CHECK(op.SetBeta(0.0f));

    mudnn::Handle* handle = get_cached_handle(ctx.device);
    MUDNN_CHECK(handle->SetStream(ctx.stream()));
    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));

    return musaSuccess;
}
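One detail in get_ggml_dims_and_strides deserves a callout: ggml_tensor stores nb[] as byte strides, while the SetNdInfo calls are fed element strides, hence the division by the element size. A small self-contained illustration (hypothetical values, not from the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // A contiguous 4x3 F32 tensor as ggml would describe it:
        const int64_t ne[2] = {4, 3};    // elements per dimension
        const int64_t nb[2] = {4, 16};   // byte strides (sizeof(float) == 4)
        const int64_t element_size = 4;  // ggml_element_size() for F32

        // The conversion performed above: byte stride -> element stride.
        for (int i = 0; i < 2; ++i) {
            std::printf("dim %d: ne = %lld, stride = %lld elements\n",
                        i, (long long) ne[i], (long long) (nb[i] / element_size));
        }
        return 0;  // prints element strides 1 and 4
    }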

ggml/src/ggml-musa/mudnn.cuh (new file, 12 additions)

#pragma once

#include "../include/ggml.h"
#include "../ggml-cuda/common.cuh"

// Asynchronously copies data from the src tensor to the dst tensor using the provided context.
// Returns a musaError_t indicating success or failure.
musaError_t mudnnMemcpyAsync(
    ggml_backend_cuda_context &ctx,
    const ggml_tensor *dst,
    const ggml_tensor *src
);
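Note the memcpy-style argument order: destination before source. Condensed from the cpy.cu call site above, usage looks like this (ctx, src0, src1 come from the surrounding backend code):

    // In ggml_cuda_cpy on a MUSA build, for contiguous F32/F16 tensors:
    CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));  // src1 is dst, src0 is src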

0 commit comments