cuda : fix flash_attn kernel to produce same results as CPU by ggerganov · Pull Request #3 · Pints-AI/llama.cpp
Merged
merged 4 commits on Feb 1, 2024
Changes from 1 commit
cuda : fix flash_attn kernel to produce same results as CPU
ggerganov committed Feb 1, 2024
commit 71b69aa7fd0aee89c4750d230bee7a4601d8fc1f
ggml-cuda.cu: 66 changes (41 additions, 25 deletions)
@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
const int D16 = D/16;
const int Q16 = Q/16;
const int NW = WARP_SIZE;
const int SH = (C + Q); // shared memory per simdgroup in (half)
const int SH = (C + 2*Q); // shared memory per simdgroup in (half)

const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2)
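
Note: the only functional change in this first hunk is the extra Q columns per warp. SH grows from C + Q to C + 2*Q because the kernel now needs a spare QxQ tile of shared memory per warp to round-trip WMMA accumulator fragments into matrix_a/matrix_b operands (see the conversions in the hunks below). A rough host-side sketch of what this costs for the launch parameters used later in this patch (D = 128, Q = nqpb = 16, C = ncpw = 32, num_warps = 4); the standalone program and its hard-coded values are illustrative assumptions, not part of the commit:

    // Shared-memory budget implied by SH = C + 2*Q (illustrative sketch).
    #include <cstdio>

    int main() {
        const int D = 128, Q = 16, C = 32, num_warps = 4; // assumed defaults

        const int SH = C + 2*Q;          // per-warp columns: C score columns + 2*Q of scratch
        const int T  = D + num_warps*SH; // halves per shared-memory row

        // Q rows of T halves, 2 bytes per half
        std::printf("T = %d halves, total = %d bytes\n", T, Q*T*2); // T = 384 -> 12288 bytes
        return 0;
    }
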
@@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16(
}
}

const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

// pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

// prepare diagonal scale matrix
half16x16_b mscale;
for (int i = 0; i < 16; ++i) {
ss[i*T + i] = __float2half(scale);
}
nvcuda::wmma::load_matrix_sync(mscale, ss, T);

// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16(

// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q16; ++j) {
for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
// TODO: process mask
mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
}
half16x16_a mqka;
half16x16_acc mm;

// convert accumulator to matrix_a
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);

nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
}
}
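
Note: the replaced loop scaled the scores element by element and never applied the mask (the old TODO), which is one reason the CUDA kernel disagreed with the CPU reference. The new code fuses both steps into a single tensor-core operation: the QK^T accumulator is stored to shared memory and re-loaded as a matrix_a operand, the mask tile is loaded as the accumulator term, and one mma_sync produces S*diag(scale) + mask. Below is a standalone one-warp sketch of the same pattern on a single 16x16 tile (requires sm_70+); the buffer names, the one-warp launch, and the explicit zero-fill of the diagonal matrix are illustrative assumptions rather than the kernel's actual layout:

    // One-warp sketch: S = S*diag(scale) + mask on a single 16x16 half tile.
    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    __global__ void scale_add_mask(half * S, const half * mask, float scale) {
        __shared__ __align__(32) half scratch[2*16*16];
        half * tile = scratch;          // staging area for the accumulator round-trip
        half * diag = scratch + 16*16;  // will hold diag(scale)

        // build diag(scale); the kernel above writes only the diagonal entries,
        // here the whole tile is filled to keep the example self-contained
        for (int i = threadIdx.x; i < 16*16; i += 32) {
            diag[i] = (i/16 == i%16) ? __float2half(scale) : __float2half(0.0f);
        }
        __syncwarp();

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> mm;

        // stand-in for the Q*K^T accumulator: load S, then turn it into a
        // matrix_a operand by bouncing it through shared memory
        wmma::load_matrix_sync (acc, S, 16, wmma::mem_row_major);
        wmma::store_matrix_sync(tile, acc, 16, wmma::mem_row_major);
        wmma::load_matrix_sync (a, tile, 16);

        wmma::load_matrix_sync(b, diag, 16);                       // diag(scale) as matrix_b
        wmma::load_matrix_sync(mm, mask, 16, wmma::mem_row_major); // mask as the addend

        wmma::mma_sync(acc, a, b, mm);                             // acc = S*diag(scale) + mask
        wmma::store_matrix_sync(S, acc, 16, wmma::mem_row_major);
    }

In the kernel itself the diagonal matrix is prepared once before the KV loop (the mscale fragment in the previous hunk) and reused for every tile.
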
@@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16(

// O = diag(ms)*O
for (int64_t j = 0; j < Q16; ++j) {
// half16x16_a mm;
// half16x16_b zro;
half16x16_a mm;
half16x16_b lob;

// nvcuda::wmma::fill_fragment(zro, 0.0);
// nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

for (int64_t i = 0; i < D16; ++i) {
//nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
for (uint32_t k = 0; k < 16*16; k++) {
half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
lo[j][i].x[k] = tmp * lo[j][i].x[k];
}
// convert accumulator to matrix_b
// TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);

nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
}
}
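
Note: the previously commented-out WMMA path for O = diag(ms)*O is now enabled. The element-wise version it replaces indexed lo[j][i].x[k] as if fragment element k always belonged to row k%16 of the tile, but CUDA leaves the mapping between fragment elements and matrix coordinates unspecified, so the per-row correction could land on the wrong elements. The new code instead stores the accumulator to the spare Q columns of shared memory (the reason SH grew to C + 2*Q) and re-loads it as a matrix_b operand, which is layout-agnostic. For reference, the intended semantics on a plain row-major tile (illustrative sketch, placeholder sizes):

    // Reference semantics of O = diag(ms)*O on a plain row-major tile
    // (illustrative; 16 and 128 stand in for Q and the head size D).
    static void diag_rescale_ref(const float ms[16], float O[16][128]) {
        for (int r = 0; r < 16; ++r) {      // row r of O is scaled by ms[r]
            for (int c = 0; c < 128; ++c) {
                O[r][c] *= ms[r];
            }
        }
    }
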

@@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(t2, 0.0);
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
// store temporally 'lo' data
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
// load 'lo' data into t
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);

// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);

nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
}
}
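
Note: this hunk is part of the final reduction that folds every warp's partial result into warp 0: t2 accumulates diag(ms1)*O from the other warp, and lo is rebuilt as diag(ms0)*lo + t2, again converting the accumulator into a matrix_b operand through shared memory. ms0 and ms1 are the usual online-softmax correction factors obtained when two running maxima are merged; a plain reference for one output row follows (illustrative names and a fixed head size of 128, not the kernel's data layout):

    // Merging two per-warp partial softmax states for one query row
    // (illustrative reference, not the kernel's data layout).
    #include <math.h>

    struct Partial { float m; float s; float o[128]; }; // running max, running sum, unnormalized output

    static void merge(Partial & dst, const Partial & src) {
        const float m   = fmaxf(dst.m, src.m);
        const float ms0 = expf(dst.m - m); // correction applied to dst
        const float ms1 = expf(src.m - m); // correction applied to src

        dst.s = ms0*dst.s + ms1*src.s;
        for (int i = 0; i < 128; ++i) {
            dst.o[i] = ms0*dst.o[i] + ms1*src.o[i]; // O = diag(ms0)*O0 + diag(ms1)*O1, one row
        }
        dst.m = m;
    }
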
@@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
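
Note: the mask-padding requirement tightens from 8 to 16 rows because the kernel now reads the mask through 16x16 WMMA tiles (the load_matrix_sync on mp above), so even a partially filled block of queries consumes a full 16-row tile. A quick check of what the assertion demands for a few batch sizes, using the GGML_PAD macro as it is defined in ggml.h (the standalone program itself is illustrative):

    // Quick check of the new mask padding requirement (illustrative).
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1)/(n)*(n)) // matches the macro in ggml.h

    int main() {
        const int qs[] = {1, 7, 15, 16, 17};
        for (int n_queries : qs) {
            std::printf("n_queries=%2d  old: pad to %2d  new: pad to %2d\n",
                        n_queries, GGML_PAD(n_queries, 8), GGML_PAD(n_queries, 16));
        }
        return 0;
    }

This is also why the test matrix in tests/test-backend-ops.cpp below gains batch sizes such as 7 and 15.
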
@@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

const int nqpb = 16; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
// const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
const int nwarps = 1;

const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;

dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);

int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
// TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
// try to avoid this
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);

switch (Q->ne[0])
{
case 16:
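
Note: the launcher previously pinned nwarps to 1; it now scales the warp count with the KV length when the query batch fits in a single block (the decode case) and falls back to 4 warps for larger batches, capped at nwarps_max = 8 pending the TODO above. What the heuristic picks for a few representative shapes, as a small host-side sketch (illustrative values only):

    // What the nwarps heuristic selects for a few shapes (illustrative).
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nqpb = 16, ncpw = 32, nwarps_max = 8;

        const int shapes[][2] = { {1, 128}, {1, 4096}, {512, 4096} }; // {n_queries, n_kv}

        for (const auto & s : shapes) {
            const int n_q = s[0], n_kv = s[1];
            const int nwarps = n_q <= nqpb ? std::max(4, std::min(n_kv/ncpw, nwarps_max)) : 4;
            std::printf("n_q=%3d n_kv=%4d -> nwarps=%d\n", n_q, n_kv, nwarps);
        }
        return 0;
    }

With head size 128 the shared-memory formula gives 12 KiB at 4 warps and 20 KiB at 8 warps, so the cap also keeps the kernel comfortably below the default 48 KiB dynamic shared-memory limit per block.
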
tests/test-backend-ops.cpp: 2 changes (1 addition, 1 deletion)
@@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int hs : { 128, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, 512 }) {
for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
}