fix kernel · Pints-AI/llama.cpp@b1479df · GitHub


Commit b1479df

fix kernel
1 parent 3b0f74b commit b1479df

File tree

2 files changed: +56 -49 lines changed

ggml-cuda.cu

Lines changed: 55 additions & 48 deletions
@@ -6158,9 +6158,9 @@ static __global__ void flash_attn_f32(
 }
 
 #if __CUDA_ARCH__ >= CC_VOLTA
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_bT;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
 
 // based on metal version
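
Note: the layout tag in a WMMA fragment typedef tells load_matrix_sync how to interpret the source memory, so flipping half16x16_a/half16x16_b to row_major and half16x16_bT to col_major changes how the Q and K tiles are read without touching the data itself. A minimal, self-contained sketch of that pattern (hypothetical pointers and leading dimensions, not code from this commit; sm_70+ only):

    #include <cuda_fp16.h>
    #include <mma.h>

    // One warp computes a single 16x16 tile C = A * B^T, where both A and B are
    // stored row-major. Reading row-major B through a col_major matrix_b fragment
    // is what makes the product use B^T -- the same trick the half16x16_bT
    // typedef relies on for the key matrix. Launch with one warp (32 threads).
    __global__ void tile_a_times_bT(const half * A, const half * B, half * C,
                                    int lda, int ldb, int ldc) {
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> a;
        nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> bT;
        nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> acc;

        nvcuda::wmma::fill_fragment(acc, __float2half(0.0f));
        nvcuda::wmma::load_matrix_sync(a,  A, lda); // 16x16 tile of A, row-major
        nvcuda::wmma::load_matrix_sync(bT, B, ldb); // 16x16 tile of B, read as B^T
        nvcuda::wmma::mma_sync(acc, a, bT, acc);    // acc += A * B^T
        nvcuda::wmma::store_matrix_sync(C, acc, ldc, nvcuda::wmma::mem_row_major);
    }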
@@ -6204,15 +6204,15 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW = WARP_SIZE;
-    const int SH = (C + D); // shared memory per simdgroup in (half)
+    const int SH = (C + Q); // shared memory per simdgroup in (half)
 
     const int T = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2; // shared memory size per query in (half2)
 
     extern __shared__ half __flash_attn_f16_shmem[];
     // pq
-    half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
-    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
+    half  * sq  = (half  *) (__flash_attn_f16_shmem +  0*D); // holds the query data
+    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +  0*D); // same as above but in half2
     half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
     half16x16_acc lo[Q16][D16];
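
Note: SH shrinks from C + D to C + Q because the per-warp scratch region holds C attention-score columns plus a QxQ diagonal rescaling block (written by the "create a QxQ diagonal matrix" hunk further down), not a QxD block. A sketch of the per-row layout under assumed sizes, counted in halves (illustration only; D = 128, C = 32, Q = 16, num_warps = 2):

    // Per query row of __flash_attn_f16_shmem (assumed sizes, illustration only)
    enum { D = 128, C = 32, Q = 16, num_warps = 2 };
    enum {
        SH = C + Q,              // 48: scratch per warp (C score columns + QxQ diagonal slots)
        T  = D + num_warps * SH, // 224: halves per query row
    };

    const int sq_off    = 0;          // [0, D)        query data          (sq)
    const int ss_off_w0 = D + 0 * SH; // [D, D+SH)     scratch for warp 0  (ss, warp_id == 0)
    const int ss_off_w1 = D + 1 * SH; // [D+SH, D+2SH) scratch for warp 1  (ss, warp_id == 1)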

@@ -6249,7 +6249,7 @@ static __global__ void flash_attn_ext_f16(
     float S[Q];
     float M[Q];
 
-    for(int i = 0; i < Q;i ++) {
+    for(int i = 0; i < Q; i++) {
         S[i] = 0.0f;
         M[i] = -INFINITY;
     }
@@ -6288,7 +6288,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
 
     // pointer to the mask
-    const float * mp = (const float *) (mask + (ir%ne31)*nb31);
+    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -6305,7 +6305,7 @@ static __global__ void flash_attn_ext_f16(
 
                 for (int64_t i = 0; i < D16; ++i) {
                     half16x16_bT mk; // transposed key
-                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
+                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
 
                     for (int64_t j = 0; j < Q16; ++j) {
                         nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
@@ -6314,14 +6314,14 @@ static __global__ void flash_attn_ext_f16(
 
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                    int64_t msk_ne_row = nb31/sizeof(float);
+                    // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
+                    // int64_t msk_ne_row = nb31/sizeof(float);
                     for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        int msk_col = i % 16;
-                        int msk_row = i / 16;
-                        mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
+                        // int msk_col = i % 16;
+                        // int msk_row = i / 16;
+                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
                     }
-                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
+                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
         }
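
Note: iterating over fragment.num_elements and fragment.x[] is the documented WMMA way to apply a uniform per-element operation (such as the scale above) while a tile is still in registers; the mapping from element index to matrix row/column is unspecified by CUDA, which is why the position-dependent mask addition using i % 16 and i / 16 cannot work reliably and is commented out here. A minimal sketch of a uniform, position-independent fragment scale (hypothetical helper, not part of the commit; requires sm_53+ half arithmetic):

    #include <cuda_fp16.h>
    #include <mma.h>

    // Multiply every element of a WMMA fragment by a scalar while it lives in
    // registers. Safe because the result does not depend on which matrix
    // position each register element maps to.
    template <typename Fragment>
    __device__ void scale_fragment(Fragment & frag, half s) {
        for (int i = 0; i < frag.num_elements; ++i) {
            frag.x[i] = s * frag.x[i];
        }
    }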
@@ -6370,11 +6370,11 @@ static __global__ void flash_attn_ext_f16(
 
                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = ms;
+                        ss[j*T + C + j] = __float2half(ms);
                     }
 
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = ss[j*T + p];
+                        const float s = __half2float(ss[j*T + p]);
 
                         const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);

@@ -6393,14 +6393,18 @@ static __global__ void flash_attn_ext_f16(
 
             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                half16x16_a mm;
-                half16x16_b zro;
+                // half16x16_a mm;
+                // half16x16_b zro;
 
-                nvcuda::wmma::fill_fragment(zro, 0.0);
-                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                // nvcuda::wmma::fill_fragment(zro, 0.0);
+                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
                 for (int64_t i = 0; i < D16; ++i) {
-                    nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                    //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                    for (uint32_t k = 0; k < 16*16; k++) {
+                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
+                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
+                    }
                 }
             }

@@ -6444,7 +6448,7 @@ static __global__ void flash_attn_ext_f16(
             if (warp_id == sg) {
                 for (int64_t j = 0; j < Q16; ++j) {
                     for (int64_t i = 0; i < D16; ++i) {
-                        nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                        nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                     }
                 }
             }
@@ -6487,13 +6491,13 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
 
                 for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-
-                    // t <- lo
-                    for (uint32_t k = 0; k < t.num_elements; k++) {
-                        t.x[k] = lo[j][i].x[k];
-                    }
+                    // store temporally 'lo' data
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    // load 'lo' data into t
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -6504,22 +6508,20 @@ static __global__ void flash_attn_ext_f16(
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q16; ++j) {
             for (int64_t i = 0; i < D16; ++i) {
-                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
     }
 
-    float2 * dst2 = (float2 *) dst;
+    // float2 * dst2 = (float2 *) dst;
 
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
             const float S = __half2float(ss[j*T + 0]);
 
-            for (int64_t i = lane_id; i < D2; i += NW) {
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
+            for (int64_t i = lane_id; i < D; i += NW) {
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
             }
         }
     }
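
Note: the final loop divides each accumulated output element by S, the per-query softmax denominator read back from ss[j*T + 0], turning the running sum of exp-weighted V rows into the normalized attention output; with this change it writes one float per lane instead of a float2 pair. A plain-C reference of the same normalization (hypothetical helper, for clarity only):

    // out[d] = acc[d] / S, where acc[d] = sum_k exp(s_k - M) * V[k][d]
    // and S = sum_k exp(s_k - M) for this query row.
    void rescale_row(float * out, const float * acc, float S, int D) {
        for (int d = 0; d < D; ++d) {
            out[d] = acc[d] / S;
        }
    }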
@@ -10526,13 +10528,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    if(mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    if(mask) {
+        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
 
     ggml_cuda_set_device(g_main_device);
@@ -10541,21 +10547,22 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
     ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
-    ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
+    ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
 
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    const int nwarps = 1;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
+    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0])
     {
         case 16:
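
Note: the corrected host formula now matches the kernel's per-row layout: the kernel needs nqpb rows of T = D + nwarps*(C + Q) halves, and sizeof(float)/2 == sizeof(half) == 2 bytes. A worked check using the values set above (nqpb = 16, ncpw = 32, nwarps = 1) and an assumed head size of 128:

    // assumed head size for illustration; nqpb/ncpw/nwarps as configured above
    const int D = 128, nqpb = 16, ncpw = 32, nwarps = 1;

    const int T     = D + nwarps*(ncpw + nqpb); // 176 halves per query row (matches the kernel's T)
    const int shmem = nqpb * T * 2;             // 16 * 176 * 2 = 5632 bytes of dynamic shared memory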
@@ -10564,12 +10571,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10581,12 +10588,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10598,12 +10605,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10615,12 +10622,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]

tests/test-flash-attention.cpp

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     if(!model.naive_attn) {
-        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
+        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0]));
         ggml_build_forward_expand(gf, result);
     } else {
         struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
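
Note: with this change the test exercises the maskless path of ggml_flash_attn_ext; the host-side guards added above only require a mask to be GGML_TYPE_F32 and resident on the GPU backend when one is supplied. A sketch of both call forms inside the same build_graph context (the masked variant is hypothetical here, reusing the model.msk tensor the test previously passed):

    // maskless call, as committed in this change
    struct ggml_tensor * out_nomask = ggml_flash_attn_ext(
        ctx0, model.q, model.k, model.v, /*mask=*/nullptr, 1.0f / sqrtf(model.q->ne[0]));

    // masked call (hypothetical): the mask must satisfy the guards in
    // ggml_cuda_flash_attn_ext, i.e. GGML_TYPE_F32 and GGML_BACKEND_GPU
    struct ggml_tensor * out_masked = ggml_flash_attn_ext(
        ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));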
