metal: implement flash attention kernel for quantized KV cache by FanShupei · Pull Request #9735 · ggml-org/llama.cpp · GitHub
Closed · wants to merge 2 commits · Changes from 1 commit
[metal-kernel] add flash_attn_ext_scalar_f16 implementation
FanShupei committed Oct 3, 2024
commit 9e62e7e10e493db1d55cc0b531f4f44a44ea4a8c
288 changes: 288 additions & 0 deletions ggml/src/ggml-metal.metal
@@ -2799,6 +2799,294 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
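
// per-element KV cache loaders: given the start of a K/V row and an element index il within that row,
// return element il converted to half. For f16 the value is loaded directly; for Q8_0 each block_q8_0
// packs QK8_0 int8 values qs[] with a shared half scale d, so the element is recovered as d * qs[il % QK8_0].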

half dequantize_load_f16(device const half *xb, short il) {
return xb[il];
}

half dequantize_load_q8_0(device const block_q8_0 *xb, short il) {
device const block_q8_0 *xb_ = &xb[il / QK8_0];
return xb_->d * xb_->qs[il % QK8_0];
}
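
// scalar flash attention for an f16 or quantized KV cache: block_q is the K/V storage block type and
// dequantize_load the matching per-element loader. No simdgroup matrices are used: each simdgroup walks
// its share of the KV cache C items at a time, and the per-simdgroup partial results are merged below.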

template<typename block_q, half (*dequantize_load)(device const block_q* xb, short il), int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_scalar_f16(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne1,
constant int64_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups

const short iq3 = tgpig[2];
const short iq2 = tgpig[1];
const short iq1 = tgpig[0];

const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)

const short T = D + 2*nsg*SH; // shared memory size per query in (half)

float slope = 1.0f;

// ALiBi
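// slope = m0^(h+1) for heads below n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1)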
if (max_bias > 0.0f) {
const uint32_t h = iq2;

const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slope = pow(base, exp);
}

threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup half * sr = (threadgroup half *) (shared + sgitg*D + 1*T); // scratch buffer for the results

// per-thread accumulator for the output row (the O matrix from the paper), spread across the simdgroup
half lo[D/NW];

// load heads from Q to shared memory
device const float * q_ = (device const float *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));

for (short i = tiisg; i < D; i += NW) {
if (iq1 < ne01) {
sq[i] = (half) q_[i];
} else {
sq[i] = 0.0h;
}
}

// zero out lo
for (short i = tiisg; i < D; i += NW) {
lo[i/NW] = 0.0h;
}

// zero out shared memory SH
for (short i = tiisg; i < SH; i += NW) {
ss[i] = 0.0h;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

{
float S = { 0.0h };
float M = { -FLT_MAX/2 };

// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;

// broadcast
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;

const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;

// k indices
const short ik2 = iq2 / rk2;
const short ik3 = iq3 / rk3;

// v indices
const short iv2 = iq2 / rv2;
const short iv3 = iq3 / rv3;

// load the queries from shared memory into local memory
half mq[D];

for (short ii = 0; ii < D; ii += NW) {
short i = ii + tiisg;
mq[i] = sq[i];
}

// pointer to the mask
device const half * mp = (device const half *) (mask + iq1*nb31);

// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= ne11) {
break;
}

// Q*K^T
{
// #pragma unroll
for (short cc = 0; cc < C; ++cc) {
float mqk = 0.0;

device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));

#pragma unroll
for (short ii = 0; ii < D; ii += NW) {
const short i = ii + tiisg;
mqk += mq[i] * dequantize_load(pk, i);
}

// reduce the results from the threads in the simdgroup
mqk += simd_shuffle_down(mqk, 16);
mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);

// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
mqk *= scale;

if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
}

if (mask != q) {
mqk += (mp[ic + cc])*slope;
}

ss[cc] = mqk;
}
}
}

// online softmax
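// M is the running maximum of the attention scores and S the running sum of exp(score - M);
// when M grows, the previous sum and the accumulated output are rescaled by ms = exp(m_old - M)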
{
const short p = tiisg;

const float m = M;
const float s = ss[p];

M = simd_max(max(M, s));

const float ms = exp(m - M);
const float vs = exp(s - M);

S = S*ms + simd_sum(vs);

// the P matrix from the paper (Q rows, C columns)
ss[p] = vs;

// O = diag(ms)*O
#pragma unroll
for (short ii = 0; ii < D; ii += NW) {
const short i = ii + tiisg;
lo[i/NW] *= ms;
}
}

// O = O + (Q*K^T)*V
{
// #pragma unroll
for (short cc = 0; cc < C; ++cc) {
device const block_q * pv = (device const block_q *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));

#pragma unroll
for (short ii = 0; ii < D; ii += NW) {
const short i = ii + tiisg;

lo[i/NW] += dequantize_load(pv, i) * ss[cc];
}
}
}

}

// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = S;
ss[1] = M;
}
}

// store results to shared memory
for (short ii = 0; ii < D; ii += NW) {
short i = ii + tiisg;
sr[i] = lo[ii/NW];
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// parallel reduce
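// tree reduction over simdgroups: partials (S0, M0, O_0) and (S1, M1, O_1) are merged by rescaling
// with exp(M0 - M) and exp(M1 - M), where M = max(M0, M1)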
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
const float S0 = ss[ 0];
const float S1 = ss[r*SH + 0];

const float M0 = ss[ 1];
const float M1 = ss[r*SH + 1];

const float M = max(M0, M1);

const float ms0 = exp(M0 - M);
const float ms1 = exp(M1 - M);

const float S = S0*ms0 + S1*ms1;

if (tiisg == 0) {
ss[0] = S;
ss[1] = M;
}

// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short ii = 0; ii < D; ii += NW) {
short i = ii + tiisg;
sr[i] = sr[i]*ms0 + sr[i + r*D]*ms1;
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);
}

// final rescale with 1/S and store to global memory
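// after the reduction above, simdgroup 0's sr and ss hold the fully combined output row and softmax sum S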
if (sgitg == 0) {
const float S = ss[0];

for (short ii = 0; ii < D; ii += NW) {
short i = ii + tiisg;
dst[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D + i] = sr[i]/S;
}
}
}

template [[host_name("kernel_flash_attn_ext_scalar_f16_h32")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 32>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h64")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 64>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h96")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 96>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 128>;

template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h32")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 32>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h64")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h96")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 128>;

template<typename T0, typename T1>
kernel void kernel_cpy(
device const void * src0,