8000 unroll 2 loops, int64_t -> int, 309 µs · Pints-AI/llama.cpp@674d5ac · GitHub
[go: up one dir, main page]

Skip to content

Commit 674d5ac

Browse files
unroll 2 loops, int64_t -> int, 309 µs
1 parent 53621e3 commit 674d5ac

File tree

1 file changed

+47
-35
lines changed

1 file changed

+47
-35
lines changed

ggml-cuda.cu

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6467,10 +6467,22 @@ static __global__ void flash_attn_ext_f16(
64676467
half16x16_acc lo[Q16][D16];
64686468

64696469
// load heads from Q to shared memory
6470-
for (int64_t j = warp_id; j < Q; j += num_warps) {
6470+
#pragma unroll
6471+
for (int j0 = 0; j0 < Q; j0 += num_warps) {
6472+
const int j = j0 + warp_id;
6473+
if (j >= Q) {
6474+
break;
6475+
}
6476+
64716477
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
64726478

6473-
for (int64_t i = lane_id; i < D2; i += NW) {
6479+
#pragma unroll
6480+
for (int i0 = 0; i0 < D2; i0 += NW) {
6481+
const int i = i0 + lane_id;
6482+
if (i >= D2) {
6483+
break;
6484+
}
6485+
64746486
if (iq1 + j < ne01) {
64756487
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
64766488
} else {
@@ -6482,15 +6494,15 @@ static __global__ void flash_attn_ext_f16(
64826494
nvcuda::wmma::fill_fragment(zr, 0.0);
64836495

64846496
// zero out lo
6485-
for (int64_t j = 0; j < Q16; ++j) {
6486-
for (int64_t i = 0; i < D16; ++i) {
6497+
for (int j = 0; j < Q16; ++j) {
6498+
for (int i = 0; i < D16; ++i) {
64876499
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
64886500
}
64896501
}
64906502

64916503
// zero out shared memory SH
6492-
for (int64_t j = 0; j < Q; ++j) {
6493-
for (int64_t i = lane_id; i < SH; i += NW) {
6504+
for (int j = 0; j < Q; ++j) {
6505+
for (int i = lane_id; i < SH; i += NW) {
64946506
ss[j*T + i] = 0.0;
64956507
}
64966508
}
@@ -6531,8 +6543,8 @@ static __global__ void flash_attn_ext_f16(
65316543

65326544
// load the queries from shared memory into local memory
65336545
half16x16_a mq[Q16][D16];
6534-
for (int64_t j = 0; j < Q16; ++j) {
6535-
for (int64_t i = 0; i < D16; ++i) {
6546+
for (int j = 0; j < Q16; ++j) {
6547+
for (int i = 0; i < D16; ++i) {
65366548
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
65376549
}
65386550
}
@@ -6549,28 +6561,28 @@ static __global__ void flash_attn_ext_f16(
65496561

65506562
// loop over the KV cache
65516563
// each simdgroup handles blocks of Q rows and C columns
6552-
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
6564+
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
65536565
// Q*K^T
65546566
{
65556567
for (int cc = 0; cc < C/16; ++cc) {
65566568
half16x16_acc mqk[Q16];
6557-
for (int64_t j = 0; j < Q16; ++j) {
6569+
for (int j = 0; j < Q16; ++j) {
65586570
nvcuda::wmma::fill_fragment(mqk[j], 0);
65596571
}
65606572

65616573
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
65626574

6563-
for (int64_t i = 0; i < D16; ++i) {
6575+
for (int i = 0; i < D16; ++i) {
65646576
half16x16_bT mk; // transposed key
65656577
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
65666578

6567-
for (int64_t j = 0; j < Q16; ++j) {
6579+
for (int j = 0; j < Q16; ++j) {
65686580
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
65696581
}
65706582
}
65716583

65726584
// mqk = mqk*scale + mask
6573-
for (int64_t j = 0; j < Q16; ++j) {
6585+
for (int j = 0; j < Q16; ++j) {
65746586
half16x16_a mqka;
65756587
half16x16_acc mm;
65766588
if(mp) {
@@ -6592,8 +6604,8 @@ static __global__ void flash_attn_ext_f16(
65926604

65936605
// online softmax
65946606
if (C == 32) {
6595-
for (int64_t j = 0; j < Q; ++j) {
6596-
const int64_t p = lane_id;
6607+
for (int j = 0; j < Q; ++j) {
6608+
const int p = lane_id;
65976609

65986610
const half m = M[j];
65996611
const half s = ss[j*T + p];
@@ -6615,10 +6627,10 @@ static __global__ void flash_attn_ext_f16(
66156627
ss[j*T + p] = vs;
66166628
}
66176629
} else {
6618-
for (int64_t j = 0; j < Q; ++j) {
6630+
for (int j = 0; j < Q; ++j) {
66196631
const half m = M[j];
66206632

6621-
for (int64_t p = lane_id; p < C; p += NW) {
6633+
for (int p = lane_id; p < C; p += NW) {
66226634
const half s = ss[j*T + p];
66236635

66246636
smax = __hmax(smax, s);
@@ -6638,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
66386650
// local sum
66396651
half ls = 0.0f;
66406652

6641-
for (int64_t p = lane_id; p < C; p += NW) {
6653+
for (int p = lane_id; p < C; p += NW) {
66426654
const half s = ss[j*T + p];
66436655

66446656
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@@ -6659,13 +6671,13 @@ static __global__ void flash_attn_ext_f16(
66596671
}
66606672

66616673
// O = diag(ms)*O
6662-
for (int64_t j = 0; j < Q16; ++j) {
6674+
for (int j = 0; j < Q16; ++j) {
66636675
half16x16_a mm;
66646676
half16x16_b lob;
66656677

66666678
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
66676679

6668-
for (int64_t i = 0; i < D16; ++i) {
6680+
for (int i = 0; i < D16; ++i) {
66696681
// convert accumulator to matrix_b
66706682
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
66716683
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
@@ -6684,17 +6696,17 @@ static __global__ void flash_attn_ext_f16(
66846696
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
66856697

66866698
half16x16_b mk[D16];
6687-
for (int64_t i = 0; i < D16; ++i) {
6699+
for (int i = 0; i < D16; ++i) {
66886700
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
66896701
}
66906702

66916703
half16x16_a mv[Q16];
6692-
for (int64_t j = 0; j < Q16; ++j) {
6704+
for (int j = 0; j < Q16; ++j) {
66936705
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
66946706
}
66956707

6696-
for (int64_t j = 0; j < Q16; ++j) {
6697-
for (int64_t i = 0; i < D16; ++i) {
6708+
for (int j = 0; j < Q16; ++j) {
6709+
for (int i = 0; i < D16; ++i) {
66986710
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
66996711
}
67006712
}
@@ -6703,7 +6715,7 @@ static __global__ void flash_attn_ext_f16(
67036715
}
67046716

67056717
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
6706-
for (int64_t j = 0; j < Q; ++j) {
6718+
for (int j = 0; j < Q; ++j) {
67076719
if (lane_id == 0) {
67086720
ss[j*T + 0] = S[j];
67096721
ss[j*T + 1] = M[j];
@@ -6712,16 +6724,16 @@ static __global__ void flash_attn_ext_f16(
67126724
}
67136725

67146726
// reduce the warps sequentially
6715-
for (int64_t sg = 1; sg < num_warps; ++sg) {
6727+
for (int sg = 1; sg < num_warps; ++sg) {
67166728
half S = __float2half(0.0f);
67176729
half M = __float2half(-INFINITY);
67186730

67196731
__syncthreads();
67206732

67216733
// each simdgroup stores its output to shared memory, reusing sq
67226734
if (warp_id == sg) {
6723-
for (int64_t j = 0; j < Q16; ++j) {
6724-
for (int64_t i = 0; i < D16; ++i) {
6735+
for (int j = 0; j < Q16; ++j) {
6736+
for (int i = 0; i < D16; ++i) {
67256737
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
67266738
}
67276739
}
@@ -6731,7 +6743,7 @@ static __global__ void flash_attn_ext_f16(
67316743

67326744
// the first simdgroup accumulates the results from the other simdgroups
67336745
if (warp_id == 0) {
6734-
for (int64_t j = 0; j < Q; ++j) {
6746+
for (int j = 0; j < Q; ++j) {
67356747
const half S0 = ss[j*T + 0];
67366748
const half S1 = ss[j*T + sg*SH + 0];
67376749

@@ -6755,7 +6767,7 @@ static __global__ void flash_attn_ext_f16(
67556767
}
67566768

67576769
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
6758-
for (int64_t j = 0; j < Q16; ++j) {
6770+
for (int j = 0; j < Q16; ++j) {
67596771
half16x16_a ms0;
67606772
half16x16_a ms1;
67616773
half16x16_b t;
@@ -6764,7 +6776,7 @@ static __global__ void flash_attn_ext_f16(
67646776
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
67656777
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
67666778

6767-
for (int64_t i = 0; i < D16; ++i) {
6779+
for (int i = 0; i < D16; ++i) {
67686780
nvcuda::wmma::fill_fragment(t2, 0.0);
67696781
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
67706782
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
@@ -6781,19 +6793,19 @@ static __global__ void flash_attn_ext_f16(
67816793

67826794
// store result to shared memory (reuse sq)
67836795
if (warp_id == 0) {
6784-
for (int64_t j = 0; j < Q16; ++j) {
6785-
for (int64_t i = 0; i < D16; ++i) {
6796+
for (int j = 0; j < Q16; ++j) {
6797+
for (int i = 0; i < D16; ++i) {
67866798
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
67876799
}
67886800
}
67896801
}
67906802

67916803
// final rescale with 1/S and store to global memory
67926804
if (warp_id == 0) {
6793-
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
6805+
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
67946806
const half S = ss[j*T + 0];
67956807

6796-
for (int64_t i = lane_id; i < D; i += NW) {
6808+
for (int i = lane_id; i < D; i += NW) {
67976809
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
67986810
}
67996811
}

0 commit comments

Comments
 (0)
0