CUDA: FA support for Deepseek (Ampere or newer) by JohannesGaessler · Pull Request #13306 · ggml-org/llama.cpp · GitHub
CUDA: FA support for Deepseek (Ampere or newer) #13306

Merged
Changes from 1 commit
wrap __cvta_generic_to_shared for HIP
JohannesGaessler committed May 6, 2025
commit 187054a70c97f1444170ebceb2d8f6804cfd0ead
11 changes: 11 additions & 0 deletions ggml/src/ggml-cuda/cp-async.cuh
@@ -2,6 +2,17 @@

#include "common.cuh"


static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
#ifdef CP_ASYNC_AVAILABLE
return __cvta_generic_to_shared(generic_ptr);
#else
GGML_UNUSED(generic_ptr);
NO_DEVICE_CODE;
return 0;
#endif // CP_ASYNC_AVAILABLE
}
Comment on lines +6 to +14
Member


Since there is no fallback, why not avoid compiling the kernels that need this intrinsic in the first place?

Collaborator Author


In terms of development, it's more convenient for me if potential breakage is encapsulated in an API such as this. That way, if I need to git bisect my WIP commits later on, there is less risk of having to deal with code that doesn't compile on specific hardware.
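
For illustration, a minimal sketch of the alternative discussed above, i.e. compiling the kernel only where the intrinsic exists instead of stubbing the intrinsic. The kernel name and body are hypothetical; CP_ASYNC_AVAILABLE is the existing feature macro from common.cuh:

#ifdef CP_ASYNC_AVAILABLE
// The kernel only exists on targets with cp.async support (Ampere or newer),
// so every launch site needs a matching guard.
static __global__ void fattn_cp_async_kernel_example(const char * __restrict__ src) {
    extern __shared__ char tile[];
    // Direct use of the intrinsic is safe here because the whole kernel
    // disappears on targets where it is unavailable.
    const unsigned int tile_32 = __cvta_generic_to_shared(tile);
    // Each thread copies its own 16-byte chunk; src must be 16-byte aligned.
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
        :: "r"(tile_32 + 16*threadIdx.x), "l"(src + 16*threadIdx.x));
    asm volatile("cp.async.wait_all;");
    __syncthreads();
}
#endif // CP_ASYNC_AVAILABLE

With the wrapper from this commit the kernels instead compile on every target and only hit NO_DEVICE_CODE at run time, which is what keeps intermediate WIP commits buildable (and therefore bisectable) on hardware without cp.async.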


// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bytes.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
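
To make the comment above concrete: the 32-bit address returned by the wrapper is exactly what a cp.async instruction consumes as its shared-memory destination. A rough sketch of such a 16-byte cache-global copy, written here as a standalone example and not necessarily the exact helper defined further down in this file:

// Asynchronously copy 16 contiguous bytes from global memory (64-bit generic
// pointer) into shared memory (32-bit shared-window address), cached in L2
// only (cg). Both pointers must be aligned to 16 bytes.
static __device__ __forceinline__ void cp_async_cg_16_example(const unsigned int dst, const void * src) {
#ifdef CP_ASYNC_AVAILABLE
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(dst), "l"(src));
#else
    GGML_UNUSED(dst);
    GGML_UNUSED(src);
    NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}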
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -112,7 +112,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.

if (use_cp_async) {
- const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
+ const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);

constexpr int preload = 64;
constexpr int h2_per_chunk = 16/sizeof(half2);
@@ -186,7 +186,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;

- const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+ const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);

#pragma unroll
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
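
For context, the call sites changed above follow roughly this shape. This is a simplified sketch using the example helper from the previous block; the tile layout, loop bounds, and parameter names are illustrative rather than the actual loader code in fattn-mma-f16.cuh:

// Illustrative tile loader: convert the shared-memory pointer once, then let
// each thread stream 16-byte chunks of K data until D2 half2 elements are
// covered. Going through ggml_cuda_cvta_generic_to_shared keeps this compiling
// on HIP builds even though the cp.async path only runs on Ampere or newer.
static __device__ __forceinline__ void load_tile_example(
        const half2 * __restrict__ K_h2, half2 * __restrict__ tile_KV, const int D2) {
    const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
    constexpr int h2_per_chunk = 16/sizeof(half2); // 8 half2 values per 16-byte transfer

    for (int i = threadIdx.x*h2_per_chunk; i < D2; i += blockDim.x*h2_per_chunk) {
        cp_async_cg_16_example(tile_KV_32 + i*sizeof(half2), K_h2 + i);
    }
    // ...followed by a cp.async wait and __syncthreads() before the tile is read.
}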