@@ -6467,10 +6467,22 @@ static __global__ void flash_attn_ext_f16(
     half16x16_acc lo[Q16][D16];

     // load heads from Q to shared memory
-    for (int64_t j = warp_id; j < Q; j += num_warps) {
+#pragma unroll
+    for (int j0 = 0; j0 < Q; j0 += num_warps) {
+        const int j = j0 + warp_id;
+        if (j >= Q) {
+            break;
+        }
+
         const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));

-        for (int64_t i = lane_id; i < D2; i += NW) {
+#pragma unroll
+        for (int i0 = 0; i0 < D2; i0 += NW) {
+            const int i = i0 + lane_id;
+            if (i >= D2) {
+                break;
+            }
+
             if (iq1 + j < ne01) {
                 sq2[j*T2 + i] = __float22half2_rn(q2[i]);
             } else {
@@ -6482,15 +6494,15 @@ static __global__ void flash_attn_ext_f16(
     nvcuda::wmma::fill_fragment(zr, 0.0);

     // zero out lo
-    for (int64_t j = 0; j < Q16; ++j) {
-        for (int64_t i = 0; i < D16; ++i) {
+    for (int j = 0; j < Q16; ++j) {
+        for (int i = 0; i < D16; ++i) {
             nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
         }
     }

     // zero out shared memory SH
-    for (int64_t j = 0; j < Q; ++j) {
-        for (int64_t i = lane_id; i < SH; i += NW) {
+    for (int j = 0; j < Q; ++j) {
+        for (int i = lane_id; i < SH; i += NW) {
             ss[j*T + i] = 0.0;
         }
     }
@@ -6531,8 +6543,8 @@ static __global__ void flash_attn_ext_f16(

     // load the queries from shared memory into local memory
     half16x16_a mq[Q16][D16];
-    for (int64_t j = 0; j < Q16; ++j) {
-        for (int64_t i = 0; i < D16; ++i) {
+    for (int j = 0; j < Q16; ++j) {
+        for (int i = 0; i < D16; ++i) {
             nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
         }
     }
@@ -6549,28 +6561,28 @@ static __global__ void flash_attn_ext_f16(

     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
-    for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
+    for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
         // Q*K^T
         {
             for (int cc = 0; cc < C/16; ++cc) {
                 half16x16_acc mqk[Q16];
-                for (int64_t j = 0; j < Q16; ++j) {
+                for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::fill_fragment(mqk[j], 0);
                 }

                 const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));

-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     half16x16_bT mk; // transposed key
                     nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));

-                    for (int64_t j = 0; j < Q16; ++j) {
+                    for (int j = 0; j < Q16; ++j) {
                         nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
                     }
                 }

                 // mqk = mqk*scale + mask
-                for (int64_t j = 0; j < Q16; ++j) {
+                for (int j = 0; j < Q16; ++j) {
                     half16x16_a mqka;
                     half16x16_acc mm;
                     if (mp) {
@@ -6592,8 +6604,8 @@ static __global__ void flash_attn_ext_f16(

         // online softmax
         if (C == 32) {
-            for (int64_t j = 0; j < Q; ++j) {
-                const int64_t p = lane_id;
+            for (int j = 0; j < Q; ++j) {
+                const int p = lane_id;

                 const half m = M[j];
                 const half s = ss[j*T + p];
@@ -6615,10 +6627,10 @@ static __global__ void flash_attn_ext_f16(
                 ss[j*T + p] = vs;
             }
         } else {
-            for (int64_t j = 0; j < Q; ++j) {
+            for (int j = 0; j < Q; ++j) {
                 const half m = M[j];

-                for (int64_t p = lane_id; p < C; p += NW) {
+                for (int p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];

                     smax = __hmax(smax, s);
@@ -6638,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
                 // local sum
                 half ls = 0.0f;

-                for (int64_t p = lane_id; p < C; p += NW) {
+                for (int p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];

                     const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@@ -6659,13 +6671,13 @@ static __global__ void flash_attn_ext_f16(
         }

         // O = diag(ms)*O
-        for (int64_t j = 0; j < Q16; ++j) {
+        for (int j = 0; j < Q16; ++j) {
             half16x16_a mm;
             half16x16_b lob;

             nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

-            for (int64_t i = 0; i < D16; ++i) {
+            for (int i = 0; i < D16; ++i) {
                 // convert accumulator to matrix_b
                 nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 nvcuda::wmma::load_matrix_sync(lob, ss + 16*j*T + C + 16*j, T);
@@ -6684,17 +6696,17 @@ static __global__ void flash_attn_ext_f16(
                 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

                 half16x16_b mk[D16];
-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
                 }

                 half16x16_a mv[Q16];
-                for (int64_t j = 0; j < Q16; ++j) {
+                for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
                 }

-                for (int64_t j = 0; j < Q16; ++j) {
-                    for (int64_t i = 0; i < D16; ++i) {
+                for (int j = 0; j < Q16; ++j) {
+                    for (int i = 0; i < D16; ++i) {
                         nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                     }
                 }
@@ -6703,7 +6715,7 @@ static __global__ void flash_attn_ext_f16(
     }

     // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-    for (int64_t j = 0; j < Q; ++j) {
+    for (int j = 0; j < Q; ++j) {
         if (lane_id == 0) {
             ss[j*T + 0] = S[j];
             ss[j*T + 1] = M[j];
@@ -6712,16 +6724,16 @@ static __global__ void flash_attn_ext_f16(
     }

     // reduce the warps sequentially
-    for (int64_t sg = 1; sg < num_warps; ++sg) {
+    for (int sg = 1; sg < num_warps; ++sg) {
         half S = __float2half(0.0f);
         half M = __float2half(-INFINITY);

         __syncthreads();

         // each simdgroup stores its output to shared memory, reusing sq
         if (warp_id == sg) {
-            for (int64_t j = 0; j < Q16; ++j) {
-                for (int64_t i = 0; i < D16; ++i) {
+            for (int j = 0; j < Q16; ++j) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
             }
@@ -6731,7 +6743,7 @@ static __global__ void flash_attn_ext_f16(

         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
-            for (int64_t j = 0; j < Q; ++j) {
+            for (int j = 0; j < Q; ++j) {
                 const half S0 = ss[j*T + 0];
                 const half S1 = ss[j*T + sg*SH + 0];

@@ -6755,7 +6767,7 @@ static __global__ void flash_attn_ext_f16(
             }

             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (int64_t j = 0; j < Q16; ++j) {
+            for (int j = 0; j < Q16; ++j) {
                 half16x16_a ms0;
                 half16x16_a ms1;
                 half16x16_b t;
@@ -6764,7 +6776,7 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
@@ -6781,19 +6793,19 @@ static __global__ void flash_attn_ext_f16(

     // store result to shared memory (reuse sq)
     if (warp_id == 0) {
-        for (int64_t j = 0; j < Q16; ++j) {
-            for (int64_t i = 0; i < D16; ++i) {
+        for (int j = 0; j < Q16; ++j) {
+            for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
     }

     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
-        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
             const half S = ss[j*T + 0];

-            for (int64_t i = lane_id; i < D; i += NW) {
+            for (int i = lane_id; i < D; i += NW) {
                 dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }