CUDA: skip fully masked-out KV in FA vec kernel by JohannesGaessler · Pull Request #13584 · ggml-org/llama.cpp · GitHub

CUDA: skip fully masked-out KV in FA vec kernel #13584


Merged · 2 commits · May 20, 2025
52 changes: 48 additions & 4 deletions ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -2,9 +2,9 @@
#include "fattn-common.cuh"

template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#ifndef GGML_USE_HIP
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
NO_DEVICE_CODE;
return;
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}

__shared__ half maskh_shared[ncols*D];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = 0.0f;
}

__syncthreads();

// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:

if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
}
Comment on lines +193 to +195

Member:
@JohannesGaessler I think this is missing an inner loop over D?

Collaborator Author (JohannesGaessler):
No, because for this kernel the number of threads is equal to D.
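
To restate the point as a sketch (assuming, as the kernel's __launch_bounds__(D, 1) and the maskh_shared indexing suggest, that tid is the flat thread index within a block of D threads):

    // One block runs exactly D threads, so tid already enumerates all D mask
    // positions of a row; the only loop needed is the one over the ncols rows.
    #pragma unroll
    for (int j = 0; j < ncols; ++j) {
        maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
    }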

Member:
I'm looking into the bug report #13733 and it seems there is a problem with the change in this PR.

I can reproduce with this command:

./bin/llama-parallel -hf ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF -np 2 -ns 32 --top-k 1 --junk 131 -c 16384 -fa

This will output repetitive junk.

If I apply this patch to disable the skipping logic, it works:

diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 49c592ea5..b141c233c 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -217,7 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 }
             }
             if (__all_sync(0xFFFFFFFF, skip)) {
-                continue;
+                //continue;
             }
 #endif // GGML_USE_HIP
         }

Any ideas what could be wrong?


__syncthreads();

// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
skip = skip && isinf(tmp.x) && isinf(tmp.y);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
continue;
}
#endif // GGML_USE_HIP
}

// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
sum = logit_softcap*tanhf(sum);
}

sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
sum += maskh_shared[j*D + i_KQ];

if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

if (Q->ne[1] == 1) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
51 changes: 47 additions & 4 deletions ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -2,9 +2,9 @@
#include "fattn-common.cuh"

template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#ifndef GGML_USE_HIP
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
NO_DEVICE_CODE;
return;
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}

__shared__ float maskf_shared[ncols*D];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = 0.0f;
}

__syncthreads();

// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32(
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:

if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
}

__syncthreads();

// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
// In such cases, skip the KV slice.
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
#ifndef GGML_USE_HIP
bool skip = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

skip = skip && isinf(maskf_shared[j*D + i]);
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
continue;
}
Comment on lines +219 to +221

Member:
I think we need one more sync here:

            if (__all_sync(0xFFFFFFFF, skip)) {
                __syncthreads();
                continue;
            }

This change fixes it for me.

Collaborator Author (JohannesGaessler):
Yeah, I think you're right. I initially only read your other comment above but when I looked at the code again myself I came to the same conclusion.
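
For reference, a minimal sketch of the f32 skip check with the extra barrier suggested above (the explanation in the comments is a plausible reading of the race, not taken verbatim from this thread):

    if (__all_sync(0xFFFFFFFF, skip)) {
        // __all_sync only synchronizes within a warp. Without a block-wide
        // barrier here, one warp could continue and start overwriting
        // maskf_shared for the next KV slice while another warp is still
        // reading the current values, so the warps could reach different
        // skip decisions and then hit mismatched __syncthreads() calls
        // later in the loop.
        __syncthreads();
        continue;
    }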

#endif // GGML_USE_HIP
}

float kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32(
sum = logit_softcap*tanhf(sum);
}

sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
sum += maskf_shared[j*D + i_KQ];

kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);

@@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

if (Q->ne[1] == 1) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;