CUDA: skip fully masked-out KV in FA vec kernel #13584
Changes from all commits
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"

 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ float maskf_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();

     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    skip = skip && isinf(maskf_shared[j*D + i]);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
Comment on lines +219 to +221

I think we need one more sync here:

    if (__all_sync(0xFFFFFFFF, skip)) {
        __syncthreads();
        continue;
    }

This change fixes it for me.

Yeah, I think you're right. I initially only read your other comment above, but when I looked at the code again myself I came to the same conclusion.
+#endif // GGML_USE_HIP
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
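To make the skip logic and the extra barrier suggested in the review comment easier to reason about outside the full kernel, here is a minimal, self-contained sketch of the same pattern. Everything in it (the kernel name `count_unmasked_tiles`, the `atomicAdd` stand-in for the real attention work, the single-block launch) is made up for illustration; it is not the PR's kernel, which additionally handles `ncols` query columns, quantized K/V, and the HIP/MUSA exclusions. Like the PR code, it assumes 32-lane warps for the full `__all_sync` mask.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <vector>

#define WARP_SIZE 32

template <int D> // D == tile width; the block is launched with exactly D threads
static __global__ void count_unmasked_tiles(const half * mask, const int n_tiles, int * n_processed) {
    static_assert(D % WARP_SIZE == 0, "D must be a multiple of the warp size");
    __shared__ float maskf_shared[D];
    const int tid  = threadIdx.x;
    const int lane = threadIdx.x % WARP_SIZE;

    for (int tile = 0; tile < n_tiles; ++tile) {
        // One store per thread fills the whole tile (blockDim.x == D):
        maskf_shared[tid] = __half2float(mask[tile*D + tid]);
        __syncthreads();

        // Every warp scans the full tile, so all warps reach the same verdict:
        bool skip = true;
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
            skip = skip && isinf(maskf_shared[i0 + lane]);
        }
        if (__all_sync(0xFFFFFFFF, skip)) {
            // Barrier before continue (the fix discussed above): without it a warp
            // could start overwriting maskf_shared for the next tile while another
            // warp is still reading the current one.
            __syncthreads();
            continue;
        }

        if (tid == 0) {
            atomicAdd(n_processed, 1); // stand-in for the real KQ/VKQ work
        }
        __syncthreads(); // everyone is done reading before the next tile is loaded
    }
}

int main() {
    constexpr int D       = 128;
    constexpr int n_tiles = 4;

    // Tiles 1 and 3 are fully masked out (-inf everywhere), tiles 0 and 2 are not.
    std::vector<half> mask_host(n_tiles*D, __float2half(0.0f));
    for (int i = 0; i < D; ++i) {
        mask_host[1*D + i] = __float2half(-INFINITY);
        mask_host[3*D + i] = __float2half(-INFINITY);
    }

    half * mask_dev        = nullptr;
    int  * n_processed_dev = nullptr;
    cudaMalloc(&mask_dev, n_tiles*D*sizeof(half));
    cudaMalloc(&n_processed_dev, sizeof(int));
    cudaMemcpy(mask_dev, mask_host.data(), n_tiles*D*sizeof(half), cudaMemcpyHostToDevice);
    cudaMemset(n_processed_dev, 0, sizeof(int));

    count_unmasked_tiles<D><<<1, D>>>(mask_dev, n_tiles, n_processed_dev);

    int n_processed = 0;
    cudaMemcpy(&n_processed, n_processed_dev, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("processed %d of %d tiles\n", n_processed, n_tiles); // expected: 2 of 4

    cudaFree(mask_dev);
    cudaFree(n_processed_dev);
    return 0;
}
```

Dropping the `__syncthreads()` inside the skip branch is exactly the race the comment above points at: one warp can move on and refill the shared mask tile while another warp is still scanning the previous contents.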
@@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 sum = logit_softcap*tanhf(sum);
             }

-            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+            sum += maskf_shared[j*D + i_KQ];

             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);

@@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
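For readers skimming the host-side change: the new condition means the single-column (`cols_per_block == 1`) path is now taken for every NVIDIA device, not only when Q has a single column, which appears to line up with the device-side `if (ncols > 1) { NO_DEVICE_CODE; return; }` guard added earlier in this diff for non-HIP/non-MUSA builds. Below is a stripped-down, hypothetical mirror of just that condition; the helper name and enum are invented, while the real code queries `ggml_cuda_info()` and `GGML_CUDA_CC_IS_NVIDIA`.

```cpp
#include <cstdio>

enum class gpu_vendor { nvidia, amd, other };

// Mirrors the new dispatch condition: if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc))
static bool use_single_column_kernel(const int n_q_cols, const gpu_vendor vendor) {
    return n_q_cols == 1 || vendor == gpu_vendor::nvidia;
}

int main() {
    std::printf("%d\n", use_single_column_kernel(1, gpu_vendor::amd));    // 1: single query column
    std::printf("%d\n", use_single_column_kernel(4, gpu_vendor::nvidia)); // 1: NVIDIA now always uses one column here
    std::printf("%d\n", use_single_column_kernel(4, gpu_vendor::amd));    // 0: other vendors still take the multi-column paths
    return 0;
}
```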
@JohannesGaessler I think this is missing an inner loop over D?

No, because for this kernel the number of threads is equal to D.
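A tiny standalone check of that claim (a hypothetical demo, not llama.cpp code): when a block is launched with exactly D threads, a single per-thread store such as `maskf_shared[j*D + tid] = ...` already covers a full row of the ncols x D tile, so only the loop over j is needed.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

template <int D, int ncols>
static __global__ void fill_tile(int * out) {
    __shared__ int tile[ncols*D];
    const int tid = threadIdx.x; // the block is launched with exactly D threads

    // One store per (row, thread) pair covers a whole row of length D,
    // so no inner loop over D is needed:
    for (int j = 0; j < ncols; ++j) {
        tile[j*D + tid] = j*D + tid;
    }
    __syncthreads();

    for (int j = 0; j < ncols; ++j) {
        out[j*D + tid] = tile[j*D + tid];
    }
}

int main() {
    constexpr int D     = 128;
    constexpr int ncols = 2;

    int * out_dev = nullptr;
    cudaMalloc(&out_dev, ncols*D*sizeof(int));
    fill_tile<D, ncols><<<1, D>>>(out_dev);

    int out_host[ncols*D];
    cudaMemcpy(out_host, out_dev, sizeof(out_host), cudaMemcpyDeviceToHost);

    bool ok = true;
    for (int i = 0; i < ncols*D; ++i) {
        ok = ok && out_host[i] == i;
    }
    std::printf("all ncols*D elements covered: %s\n", ok ? "yes" : "no");

    cudaFree(out_dev);
    return 0;
}
```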
I'm looking into the bug report #13733 and it seems there is a problem with the change in this PR.

I can reproduce with this command:

This will output repetitive junk.

If I apply this patch to disable the skipping logic, it works:

Any ideas what could be wrong?