musa: add support for muBLAS and MMA by yeahdongcn · Pull Request #13149 · ggml-org/llama.cpp

musa: add support for muBLAS and MMA #13149


Closed · wants to merge 2 commits from the xd/mma branch

Conversation

@yeahdongcn (Collaborator) commented Apr 28, 2025:

Make sure to read the contributing guidelines before submitting a PR

This PR adds support for muBLAS and MMA on Moore Threads GPUs.

Two important notes:

  1. For MTT S80 (QY1) and earlier, muBLAS does not implement GEMM due to hardware limitations.
  2. muBLAS behaves differently from cuBLAS: the arrays of pointers passed to mublasGemmBatchedEx must be explicitly allocated in GPU memory and hold valid device addresses (a sketch of what this means follows below).
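
To make the second note concrete, here is a minimal sketch of what "pointer arrays in GPU memory" implies for a batched GEMM call. It is written with the CUDA-style names used throughout ggml's CUDA code; with the MUSA backend these are expected to map to their musa*/mublas* counterparts via the compatibility headers. The function name, shapes, and types are illustrative and error checking is omitted; this is not the code in this PR.

    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cublas_v2.h>
    #include <vector>

    // Sketch: the tables of A/B/C pointers handed to *GemmBatchedEx must live in
    // device memory and contain valid device addresses; host-resident tables are
    // not sufficient for muBLAS.
    void gemm_batched_f16_sketch(cublasHandle_t handle,
                                 const half * A, const half * B, float * C,
                                 int m, int n, int k, int batch) {
        std::vector<const void *> hA(batch), hB(batch);
        std::vector<void *>       hC(batch);
        for (int i = 0; i < batch; ++i) {  // build the pointer tables on the host
            hA[i] = A + (size_t) i*m*k;
            hB[i] = B + (size_t) i*k*n;
            hC[i] = C + (size_t) i*m*n;
        }

        // explicitly allocate the tables in device memory and copy them over
        const void ** dA = nullptr; const void ** dB = nullptr; void ** dC = nullptr;
        cudaMalloc((void **) &dA, batch*sizeof(void *));
        cudaMalloc((void **) &dB, batch*sizeof(void *));
        cudaMalloc((void **) &dC, batch*sizeof(void *));
        cudaMemcpy(dA, hA.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
        cudaMemcpy(dB, hB.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
        cudaMemcpy(dC, hC.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);

        const float alpha = 1.0f, beta = 0.0f;
        cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                            &alpha, dA, CUDA_R_16F, m,
                                    dB, CUDA_R_16F, k,
                            &beta,  dC, CUDA_R_32F, m,
                            batch, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

        // make sure the asynchronous GEMM has finished before freeing the tables
        cudaDeviceSynchronize();
        cudaFree(dA); cudaFree(dB); cudaFree(dC);
    }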

Testing Done

  • Build completed successfully
  • ./build/bin/test-backend-ops passed on MTT S80 and MTT S4000
    root@991d9f0da970:/ws# ./build/bin/test-backend-ops 
    ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
    ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
    ggml_cuda_init: found 1 MUSA devices:
      Device 0: MTT S80, compute capability 2.1, VMM: yes
    Testing 2 devices
    
    Backend 1/2: MUSA0
      Device description: MTT S80
      Device memory: 16297 MB (15752 MB free)
    
      ABS(type=f16,ne_a=[128,2,2,2],v=0): OK
      ...
      CROSS_ENTROPY_LOSS(type=f32,ne=[10,5,4,3]): OK
      CROSS_ENTROPY_LOSS(type=f32,ne=[30000,1,1,1]): OK
      CROSS_ENTROPY_LOSS_BACK(type=f32,ne=[10,5,4,3]): OK
      CROSS_ENTROPY_LOSS_BACK(type=f32,ne=[30000,1,1,1]): OK
      OPT_STEP_ADAMW(type=f32,ne=[10,5,4,3]): OK
      5519/5519 tests passed
      Backend MUSA0: OK
    
    Backend 2/2: CPU
      Skipping CPU backend
    2/2 backends passed
    OK
    
    root@08e1fcc1c6e7:/ws# ./build/bin/test-backend-ops 
    ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
    ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
    ggml_cuda_init: found 1 MUSA devices:
      Device 0: MTT S4000, compute capability 2.2, VMM: yes
    Testing 2 devices
    
    Backend 1/2: MUSA0
      Device description: MTT S4000
      Device memory: 49061 MB (48400 MB free)
    
      ABS(type=f16,ne_a=[128,2,2,2],v=0): OK
      ...
      CROSS_ENTROPY_LOSS(type=f32,ne=[10,5,4,3]): OK
      CROSS_ENTROPY_LOSS(type=f32,ne=[30000,1,1,1]): OK
      CROSS_ENTROPY_LOSS_BACK(type=f32,ne=[10,5,4,3]): OK
      CROSS_ENTROPY_LOSS_BACK(type=f32,ne=[30000,1,1,1]): OK
      OPT_STEP_ADAMW(type=f32,ne=[10,5,4,3]): OK
      5519/5519 tests passed
      Backend MUSA0: OK
    
    Backend 2/2: CPU
      Skipping CPU backend
    2/2 backends passed
    OK
  • ./build/bin/llama-cli -m ~/models/qwen3_8b_q4_k_m.gguf -ngl 999 runs as expected on both MTT S80 and MTT S4000, with or without the -fa flag

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Apr 28, 2025
@yeahdongcn (Collaborator Author) commented Apr 28, 2025:

So far, I can only get it working with -fa. Without -fa, I encounter either garbled characters in the LLM replies or repeated GGGGGGG....

@JohannesGaessler @slaren Could you please share some tips on how to debug this kind of issue? I'd appreciate it!

Collaborator:

Why are you manually allocating and deallocating memory instead of using ggml_cuda_pool_alloc? Batched FP16 GEMM is used for attention without FlashAttention so most likely this is where the bug is. I don't remember what the synchronization behavior of cudaFree is but if it's done asynchronously from the kernel executions that would explain why you get incorrect results.
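
(For context, a rough sketch of the pool-based pattern being referred to, assuming the ggml_cuda_pool_alloc RAII helper from ggml/src/ggml-cuda/common.cuh; the function, element counts, and buffer names are placeholders, not the actual backend code:)

    // Hedged sketch: temporary device buffers, including the pointer tables for a
    // batched GEMM, come from the backend's per-device memory pool instead of
    // explicit cudaMalloc/cudaFree calls. Assumes this sits inside the ggml CUDA
    // backend, where ggml_backend_cuda_context and ggml_cuda_pool_alloc are defined.
    static void mul_mat_batched_sketch(ggml_backend_cuda_context & ctx, int64_t ne23) {
        // scratch buffer for an f16 copy of a source tensor (size is illustrative)
        ggml_cuda_pool_alloc<half> src1_f16(ctx.pool(), 4096);

        // pointer tables for *GemmBatchedEx, also taken from the pool; they live in
        // device memory, so a small kernel or a cudaMemcpy has to fill them with
        // device addresses before the batched GEMM is launched
        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);

        // ... fill ptrs_src.get()/ptrs_dst.get(), run the batched GEMM on ctx.stream() ...

    }   // buffers are returned to the pool here by the RAII destructors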

Collaborator Author:

Sorry for the late reply — I've been working closely with the MUSA SDK team to refine this.

Why are you manually allocating and deallocating memory instead of using ggml_cuda_pool_alloc?

This is because muBLAS behaves differently from cuBLAS: it requires that the arrays of pointers be explicitly allocated in GPU memory and hold valid device addresses. Using ggml_cuda_pool_alloc does not meet this requirement in this context.

Collaborator:

Please take a look at ggml_cuda_pool_leg. Before the current VMM implementation became the default that is how temporary memory buffers were assigned. It is still used if a GPU returns false for CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED. If muBLAS does not work correctly for VMM the GPU should return false for this property; or the MUSA backend should be compiled with GGML_CUDA_NO_VMM.
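
(For reference, a simplified sketch of the selection logic described above, assuming the CUDA driver-API naming and the pool class names used in ggml-cuda.cu; the exact code in the backend may differ:)

    #include <cuda.h>
    #include <memory>

    // Hedged sketch: use the VMM pool only when the compile-time option allows it
    // and the device reports support for virtual memory management; otherwise fall
    // back to the legacy cudaMalloc-based pool (ggml_cuda_pool_leg).
    static bool device_supports_vmm(int id) {
    #if defined(GGML_CUDA_NO_VMM)
        return false;                                   // VMM forced off at build time
    #else
        int supported = 0;
        CUdevice dev;
        cuDeviceGet(&dev, id);
        cuDeviceGetAttribute(&supported,
            CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, dev);
        return supported != 0;                          // 0 => the legacy pool is used
    #endif
    }

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device_sketch(int id) {
        if (device_supports_vmm(id)) {
            return std::make_unique<ggml_cuda_pool_vmm>(id);  // default on VMM-capable GPUs
        }
        return std::make_unique<ggml_cuda_pool_leg>(id);      // pre-VMM allocation scheme
    }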

Collaborator Author:

Sure, I'll double-check this. I don't recall if I’ve tested that option. Thanks for pointing it out!

Collaborator Author:

After trying -DGGML_CUDA_NO_VMM=ON again, I remembered why I avoided it: compiling with this option results in garbled output during inference. I also attempted to modify the legacy pool to bypass pooling and use musaMalloc directly, but that didn’t work either.

Collaborator:

Alright. For me the bottom line is this: I consider MUSA support to be comparatively low-priority. I am willing to review PRs and help with general development, but this is under the condition that it doesn't interfere with the rest of the code. I would not consider the current code in this PR acceptable in that regard. So I would ask you to either debug and fix the issues with the pool or to create a completely separate function like ggml_musa_mul_mat_mublas (with the understanding that maintenance of this function will be 100% your own responsibility). Relevant context: I'm currently working on vendor-agnostic multi GPU support, once that is available I intend to remove ggml_cuda_op_mul_mat and refactor the cuBLAS code.

Collaborator Author:

Thank you for the clarification — that makes perfect sense. I agree the current state of the code isn't ideal. I'll go ahead and close this PR for now, take the time to sort out the necessary changes cleanly, and revisit it later with a more self-contained approach. Appreciate your feedback and support!

Collaborator:

So you intend to take the approach with a separate function? If yes, you should write it in a way analogous to e.g. ggml_cuda_mul_mat_vec (not to be confused with ggml_cuda_op_mul_mat_vec) where it's being called directly from ggml_cuda_mul_mat. As long as you implement support for non-contiguous tensors this is going to work correctly for all use cases except --split-mode row on master. Soon it will work correctly for all use cases once I've moved the logic for tensor parallelism out of the CUDA backend.
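
(Illustrative only: a rough sketch of the suggested structure, with ggml_musa_mul_mat_mublas and the dispatch predicate as hypothetical placeholders rather than existing APIs:)

    // Sketch of a self-contained muBLAS path dispatched directly from
    // ggml_cuda_mul_mat, analogous to ggml_cuda_mul_mat_vec; the helper names and
    // the condition are placeholders.
    static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx,
                                  const ggml_tensor * src0, const ggml_tensor * src1,
                                  ggml_tensor * dst) {
    #ifdef GGML_USE_MUSA
        if (musa_should_use_mublas(src0, src1, dst)) {   // placeholder predicate
            // must handle non-contiguous tensors itself so that it covers every
            // shape that reaches this point
            ggml_musa_mul_mat_mublas(ctx, src0, src1, dst);
            return;
        }
    #endif
        // ... existing MMQ / MMV / cuBLAS dispatch ...
    }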

@yeahdongcn (Collaborator Author) commented May 26, 2025:

Sorry for the late reply. I was testing with VMM disabled, and it turns out the issue I encountered was related to the kvcache (please see: #13788). Once I disabled it, everything worked as expected. So it seems I can stick with the legacy VMM implementation by default.

@JohannesGaessler (Collaborator):

Could you please share some tips on how to debug this kind of issue?

Run test-backend-ops -o MUL_MAT.

@yeahdongcn force-pushed the xd/mma branch 2 times, most recently from 1428e18 to d91bdb3 on May 5, 2025 02:08
@github-actions bot added the testing (Everything test related) label May 5, 2025
@yeahdongcn force-pushed the xd/mma branch 8 times, most recently from 24a7737 to 16aa155 on May 7, 2025 02:15
@yeahdongcn changed the title from "musa: enable MMA" to "musa: add support for muBLAS and MMA" May 7, 2025
@yeahdongcn force-pushed the xd/mma branch 2 times, most recently from 45b5578 to 4b4ba0f on May 8, 2025 06:33
@yeahdongcn (Collaborator Author):

This PR depends on the MUSA SDK version bump: #13647

@yeahdongcn marked this pull request as ready for review on May 20, 2025 11:16
@yeahdongcn (Collaborator Author):

Hi @JohannesGaessler, could you please also review this one? Thanks.

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
@yeahdongcn (Collaborator Author):

Rebased onto upstream/master and re-ran all the tests in the container; the results remain unchanged.


Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Labels: ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs), testing (Everything test related)
Projects: None yet
2 participants