@@ -6158,9 +6158,9 @@ static __global__ void flash_attn_f32(
 }

 #if __CUDA_ARCH__ >= CC_VOLTA
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_bT;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>                          half16x16_acc;

 // based on metal version
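Note on the layout swap above: with the query tile stored row-major and the key rows also stored row-major (stride nb11), declaring the key fragment as a col_major matrix_b makes load_matrix_sync read the tile as Kᵀ, so the mma computes Q·Kᵀ without an explicit transpose. A minimal single-warp sketch of that combination (the kernel name and the assumption that both tiles share the same leading dimension ld are mine, not part of the commit):

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// One warp computes a single 16x16 tile of Q*K^T using the fragment layouts above.
__global__ void qk_tile_sketch(const half * Q, const half * K, half * out, int ld) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> q;   // Q tile, row-major
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> kT;  // K tile read as K^T
    wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  acc;

    wmma::fill_fragment(acc, __float2half(0.0f));
    wmma::load_matrix_sync(q,  Q, ld);   // ld = elements per row of Q
    wmma::load_matrix_sync(kT, K, ld);   // same row-major storage; the col_major fragment realizes the transpose
    wmma::mma_sync(acc, q, kT, acc);     // acc = Q * K^T
    wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
}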
@@ -6204,15 +6204,15 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + D); // shared memory per simdgroup in (half)
+    const int SH  = (C + Q); // shared memory per simdgroup in (half)

     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)

     extern __shared__ half __flash_attn_f16_shmem[];
     // pq
-    half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D); // holds the query data
-    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
+    half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D); // holds the query data
+    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
     half16x16_acc lo[Q16][D16];

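The SH change above shrinks the per-warp scratch from C + D to C + Q half elements, which is all the score and diagonal bookkeeping needs. For orientation, a layout sketch of one row of the shared buffer under assumed sizes (D = 128, Q = 16, C = 32, num_warps = 1); the column ranges follow directly from the pointer arithmetic above:

// One row j of the shared buffer holds T = D + num_warps*SH half elements:
//
//   columns [0, D)                        sq : query data, later reused for the output tiles
//   columns [D + w*SH, D + w*SH + C)      ss : attention scores of warp w
//   columns [D + w*SH + C, D + (w+1)*SH)       QxQ diagonal (the per-row rescale factors ms)
//
// With D = 128, Q = 16, C = 32, num_warps = 1 (assumed): SH = 48 and T = 176,
// so ss[j*T + p] lands at row j, column 128 + p, and the diagonal sits at columns 160..175.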
@@ -6249,7 +6249,7 @@ static __global__ void flash_attn_ext_f16(
     float S[Q];
     float M[Q];

-    for (int i = 0; i < Q;i++) {
+    for (int i = 0; i < Q; i++) {
         S[i] = 0.0f;
         M[i] = -INFINITY;
     }
@@ -6288,7 +6288,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

     // pointer to the mask
-    const float * mp = (const float *) (mask + (ir%ne31)*nb31);
+    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;

     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -6305,7 +6305,7 @@ static __global__ void flash_attn_ext_f16(

                 for (int64_t i = 0; i < D16; ++i) {
                     half16x16_bT mk; // transposed key
-                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
+                    nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));

                     for (int64_t j = 0; j < Q16; ++j) {
                         nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
@@ -6314,14 +6314,14 @@ static __global__ void flash_attn_ext_f16(

                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    const float * msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                    int64_t msk_ne_row = nb31/sizeof(float);
+                    // const float * msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
+                    // int64_t msk_ne_row = nb31/sizeof(float);
                     for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        int msk_col = i % 16;
-                        int msk_row = i / 16;
-                        mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
+                        // int msk_col = i % 16;
+                        // int msk_row = i / 16;
+                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float( ) + msk_p[msk_col + msk_row*msk_ne_row]);
                     }
-                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
+                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
         }
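In the loop above the commit switches to a pure-half scale and disables the mask add. A sketch of the same step written as a standalone helper, with the float/half conversion hoisted out of the loop and the mask path kept behind a null check; the helper name and the i -> (row, col) mapping are assumptions, since WMMA does not specify how fragment elements map to matrix coordinates:

#include <cuda_fp16.h>
#include <mma.h>

using acc16x16 = nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>;

// Apply 'scale' (and optionally a float mask tile with row stride 'mask_stride') to one
// 16x16 block of attention scores held in a wmma accumulator fragment.
__device__ void scale_and_mask_sketch(acc16x16 & mqk, float scale, const float * mp, int64_t mask_stride) {
    const half hscale = __float2half(scale);               // convert once, not per element
    for (uint32_t i = 0; i < mqk.num_elements; i++) {
        if (mp == nullptr) {
            mqk.x[i] = hscale * mqk.x[i];                   // what the kernel does after this change
        } else {
            const int row = i / 16, col = i % 16;           // assumed element layout, see note above
            mqk.x[i] = __float2half(scale * __half2float(mqk.x[i]) + mp[row*mask_stride + col]);
        }
    }
}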
@@ -6370,11 +6370,11 @@ static __global__ void flash_attn_ext_f16(

                 // create a QxQ diagonal matrix for rescaling the output
                 if (lane_id == j) {
-                    ss[j*T + C + j] = ms;
+                    ss[j*T + C + j] = __float2half(ms);
                 }

                 for (int64_t p = lane_id; p < C; p += NW) {
-                    const float s = ss[j*T + p];
+                    const float s = __half2float(ss[j*T + p]);

                     const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);

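The half/float conversions above live inside the streaming softmax: each block of C scores can raise the running maximum M[j], and ms = exp(M_old - M_new) is the factor the existing accumulator and running sum must be rescaled by, which is why it is written onto the diagonal of the scratch matrix. A scalar sketch of that bookkeeping for one query row (function and argument names are assumptions):

#include <math.h>
#include <algorithm>

// Online softmax update for one query row: fold C new (scaled, masked) scores into the
// running max M and running sum S, and return the factor ms to rescale the old output by.
float online_softmax_step(const float * scores, int C, float & M, float & S, float * vs) {
    float m_new = M;
    for (int p = 0; p < C; ++p) m_new = std::max(m_new, scores[p]);

    const float ms = expf(M - m_new);                      // rescale factor -> diag(ms)
    float sum = 0.0f;
    for (int p = 0; p < C; ++p) {
        vs[p] = scores[p] == -INFINITY ? 0.0f : expf(scores[p] - m_new);
        sum  += vs[p];
    }

    S = S*ms + sum;                                        // running denominator
    M = m_new;
    return ms;                                             // caller scales the output accumulator by diag(ms)
}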
@@ -6393,14 +6393,18 @@ static __global__ void flash_attn_ext_f16(

             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                half16x16_a mm;
-                half16x16_b zro;
+                // half16x16_a mm;
+                // half16x16_b zro;

-                nvcuda::wmma::fill_fragment(zro, 0.0);
-                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                // nvcuda::wmma::fill_fragment(zro, 0.0);
+                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                 for (int64_t i = 0; i < D16; ++i) {
-                    nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                    // nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                    for (uint32_t k = 0; k < 16*16; k++) {
+                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
+                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
+                    }
                 }
             }

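The elementwise rescale above indexes the diagonal with k%16 as the output row of fragment element k; as with the mask code, that mapping is not something WMMA guarantees. A layout-agnostic alternative (a sketch, not what the commit does) is to round-trip the tile through shared memory, where row indices are explicit:

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

using acc16x16 = wmma::fragment<wmma::accumulator, 16, 16, 16, half>;

// Scale row r of one 16x16 output tile by diag[r] (the per-row factors ms).
// 'tile' is a 16x16 half staging area in shared memory private to the calling warp.
__device__ void rescale_rows_sketch(acc16x16 & lo, const half * diag, half * tile) {
    wmma::store_matrix_sync(tile, lo, 16, wmma::mem_row_major);
    __syncwarp();
    for (int e = threadIdx.x % 32; e < 16*16; e += 32) {
        tile[e] = diag[e / 16] * tile[e];   // e / 16 is the row index in row-major storage
    }
    __syncwarp();
    wmma::load_matrix_sync(lo, tile, 16, wmma::mem_row_major);
}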
@@ -6444,7 +6448,7 @@ static __global__ void flash_attn_ext_f16(
         if (warp_id == sg) {
             for (int64_t j = 0; j < Q16; ++j) {
                 for (int64_t i = 0; i < D16; ++i) {
-                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
             }
         }
@@ -6487,13 +6491,13 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

                 for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-
-                    // t <- lo
-                    for (uint32_t k = 0; k < t.num_elements; k++) {
-                        t.x[k] = lo[j][i].x[k];
-                    }
+                    // temporarily store the 'lo' data
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    // load the 'lo' data into t
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -6504,22 +6508,20 @@ static __global__ void flash_attn_ext_f16(
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q16; ++j) {
             for (int64_t i = 0; i < D16; ++i) {
-                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
            }
        }
    }

-    float2 * dst2 = (float2 *) dst;
+    // float2 * dst2 = (float2 *) dst;

     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
             const float S = __half2float(ss[j*T + 0]);

-            for (int64_t i = lane_id; i < D2; i += NW) {
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
+            for (int64_t i = lane_id; i < D; i += NW) {
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
             }
         }
     }
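The final store above drops the float2/half2 vectorized path in favor of a plain scalar divide per element. If the vectorized variant is ever brought back, the same 1/S normalization can be done per half2 pair before the store; a sketch with assumed helper and parameter names (row2 would be sq2 + j*T2, out2 the destination row viewed as float2):

#include <cuda_fp16.h>

// Normalize one query row of D2 = D/2 half2 pairs by 1/S and store it as float2.
__device__ void store_row_scaled_sketch(const half2 * row2, float2 * out2, int D2, float S, int lane_id, int NW) {
    const float inv_S = 1.0f / S;
    for (int i = lane_id; i < D2; i += NW) {
        const float2 v = __half22float2(row2[i]);
        out2[i] = make_float2(v.x * inv_S, v.y * inv_S);
    }
}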
@@ -10526,13 +10528,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F32);
+    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);

     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    if (mask) {
+        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
+    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);

     ggml_cuda_set_device(g_main_device);
@@ -10541,21 +10547,22 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
     ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
-    ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
+    ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
     ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;

     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
+    const int nwarps = 1;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
+    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0])
     {
         case 16:
@@ -10564,12 +10571,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10581,12 +10588,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10598,12 +10605,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
@@ -10615,12 +10622,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 (const char *) src0_extra->data_device[g_main_device], // Query
                 (const char *) src1_extra->data_device[g_main_device], // Key
                 (const char *) src2_extra->data_device[g_main_device], // Value
-                (const char *) src3_extra->data_device[g_main_device], // Mask
+                mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
                 (float *) dst_extra->data_device[g_main_device], // dst
                 scale,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask->ne[1], mask->nb[1],
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                 Q->nb[1], Q->nb[2], Q->nb[3],
                 K->nb[1], K->nb[2], K->nb[3],
                 KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
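Across the launch sites above, the mask arguments are now null-guarded and the shared-memory formula (set where nwarps is pinned to 1) mirrors the kernel's per-row layout of D + num_warps*(C + Q) half elements. A quick worked check under an assumed head size of 128:

#include <cstdio>

int main() {
    // Assumed sizes: D = Q->ne[0] = 128; nqpb, ncpw, nwarps as set in the host code above.
    const int D = 128, nqpb = 16, ncpw = 32, nwarps = 1;
    const int shmem = nqpb*(D + nwarps*(ncpw + nqpb))*(sizeof(float)/2);   // half elements -> bytes
    printf("shared memory per block: %d bytes\n", shmem);                  // 16*(128 + 48)*2 = 5632
    return 0;
}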