forked from ggml-org/llama.cpp
WIP: Flash Attention implementation (forward + backward) #1
Closed
Changes from 1 commit
Commits (72)
f7bcfb0  cuda: add flash attention + test  (FSSRepo)
e53de28  fix compilation  (FSSRepo)
a1c004e  ggml : add ggml_flash_attn_ext API  (ggerganov)
fa7ebcc  ggml : fix GQA support in ggml_flash_attn_ext  (ggerganov)
09db1a7  Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…  (FSSRepo)
fded2e6  apply suggestions  (FSSRepo)
c3cdfff  Merge branch 'master' into gg/flash-attn  (ggerganov)
a9681fe  ggml : online attention (CPU)  (ggerganov)
1173f49  metal : initial implementation  (ggerganov)
528da75  metal : f16 precision  (ggerganov)
52ae085  metal : reduce branches  (ggerganov)
b973258  metal : specialize for head size  (ggerganov)
8cde449  wip : 8 rows per simd group  (ggerganov)
f31955f  wip : 4 rows per simd group  (ggerganov)
a4b6341  wip : template for rows per warp  (ggerganov)
77d08f3  metal : parallelize across KV size  (ggerganov)
17720fa  metal : parallel reduce across heads  (ggerganov)
a689b02  Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…  (FSSRepo)
6374bc5  cuda: port metal version flash_attn_ext  (FSSRepo)
6416821  fix equivalent fp16 math functions, compiler error 'undefined'  (FSSRepo)
972c2ad  use half2 instead half4  (FSSRepo)
0fc36d8  match to metal impl  (FSSRepo)
1446a12  metal : efficient flash_attn_f16 implementation  (ggerganov)
d917746  metal : avoid redundant loads of the attention  (ggerganov)
432ad04  metal : scale and mask in matrix form  (ggerganov)
40ea8cd  metal : fix comment  (ggerganov)
78da338  Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…  (FSSRepo)
f9ca5dc  llama : avoid ggml_cast, use F32 query  (ggerganov)
6e7cb0e  update implementation  (FSSRepo)
6fea843  metal : add parallel reduce version (disabled)  (ggerganov)
0a481fe  integrate tensor cores  (FSSRepo)
7cea973  Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…  (FSSRepo)
2455a8d  update impl  (FSSRepo)
b3dd7d9  Merge branch 'master' into gg/flash-attn  (ggerganov)
77f6976  metal : move output into local memory + optimize  (ggerganov)
ecc466a  metal : add tests, fix scaling, support C > 32  (ggerganov)
3a428a1  metal : improve precision  (ggerganov)
8612864  ggml : fix f16 mad  (ggerganov)
0ad44ba  Merge branch 'master' into gg/flash-attn  (ggerganov)
134c81c  metal : minor  (ggerganov)
1db22d7  metal : support Q > 8  (ggerganov)
4794821  tests : add ATTN tests  (ggerganov)
abeaf0d  metal : disable buffer allocation logs  (ggerganov)
c6c1132  tests : more  (ggerganov)
5fcb9c1  metal : faster inner loop for C == 32  (ggerganov)
a1d5a12  fix compiler error  (FSSRepo)
7980178  Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…  (FSSRepo)
d073e4f  metal : fix array initialization  (ggerganov)
78df552  tests : ifdef  (ggerganov)
3d03bcb  Merge branch 'master' into gg/flash-attn  (ggerganov)
3b0f74b  latest kernel update, wrong values  (FSSRepo)
2ddc9bb  Merge branch 'master' into gg/flash-attn  (ggerganov)
b1479df  fix kernel  (FSSRepo)
8ad92dc  ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext  (ggerganov)
0afe47f  fix naive implementation  (FSSRepo)
3df0b8d  Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…  (FSSRepo)
fd878f7  cuda: mask as fp16  (FSSRepo)
71b69aa  cuda : fix flash_attn kernel to produce same results as CPU  (ggerganov)
2c04bee  cuda : avoid extra QxQ matrix in shared memory  (ggerganov)
9a5c2a1  cuda : switch to F16 scalars + tune warps for RTX 2060  (ggerganov)
ac26f27  cuda : increase C to 128 for better performance  (ggerganov)
43f7156  Merge pull request #3 from ggerganov/flash-attn-cuda  (FSSRepo)
9240a84  fix mask nullptr  (FSSRepo)
8d7a606  don't require LLAMA_CUDA_F16 to compile  (FSSRepo)
19e0b8e  #ifdef -> #if + fix check -inf  (FSSRepo)
cae985c  cmake: remove unused changes  (FSSRepo)
53621e3  refactor flash_attn function + improve tests  (FSSRepo)
674d5ac  unroll 2 loops, int64_t -> int, 309 µs  (JohannesGaessler)
8b51ab4  Merge pull request #4 from Pints-App/jg/flash-attn-cuda  (FSSRepo)
a1f9ffe  bring optimizations from gg/flash-attn  (FSSRepo)
ba7699d  Merge branch 'flash-attn-cuda' of https://github.com/Pints-App/llama.…  (FSSRepo)
f659f57  fix merge conflicts  (FSSRepo)
cuda: port metal version flash_attn_ext
commit 6374bc5779784de48fd79351942f8b53589eff7e
@@ -937,6 +937,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
    if (lane_id == 0) {
        s_sum[warp_id] = tmp;
    }

    __syncthreads();
    tmp = s_sum[lane_id];
    tmp = warp_reduce_sum(tmp);
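The hunk above sits in the middle of a standard two-stage block reduction: each warp first reduces its own values with shuffle instructions, lane 0 of every warp publishes the partial result to shared memory (s_sum), and the block needs a barrier between that store and the second-stage read. A minimal, self-contained sketch of the pattern (hypothetical kernel and names for illustration, not code taken from this patch):

// two-stage block sum reduction; assumes blockDim.x is a multiple of 32,
// at most 1024 threads per block, and *out zero-initialized by the caller
#include <cuda_runtime.h>

#define WARP_SIZE 32

static __device__ float warp_reduce_sum(float x) {
    // butterfly reduction within one warp using shuffles
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}

static __global__ void block_sum(const float * x, float * out, int n) {
    __shared__ float s_sum[32];                   // one partial per warp

    const int tid     = blockIdx.x*blockDim.x + threadIdx.x;
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    float tmp = tid < n ? x[tid] : 0.0f;
    tmp = warp_reduce_sum(tmp);                   // stage 1: reduce within each warp

    if (lane_id == 0) {
        s_sum[warp_id] = tmp;                     // publish the per-warp partial
    }
    __syncthreads();                              // all partials must be visible
                                                  // before stage 2 reads them back

    tmp = lane_id < blockDim.x/WARP_SIZE ? s_sum[lane_id] : 0.0f;
    tmp = warp_reduce_sum(tmp);                   // stage 2: reduce the partials

    if (threadIdx.x == 0) {
        atomicAdd(out, tmp);
    }
}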
@@ -6106,6 +6107,211 @@ static __global__ void flash_attn_f32(
    }
}

struct __align__(8) half4 {
    half x;
    half y;
    half z;
    half w;
};
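The kernel below multiplies and accumulates half4 values directly (for example s4 += q4*k4 and p = p*ms + v*vs), which presupposes element-wise arithmetic for the struct above; plain CUDA does not define operators for a user-defined half4. A hedged sketch of the helpers such expressions would need (hypothetical, not part of this commit; a later commit in this PR switches to the built-in half2 type instead). Requires a target with native fp16 arithmetic (sm_53 or newer):

#include <cuda_fp16.h>

static __device__ __forceinline__ half4 operator*(const half4 a, const half4 b) {
    half4 r;
    r.x = __hmul(a.x, b.x); r.y = __hmul(a.y, b.y);
    r.z = __hmul(a.z, b.z); r.w = __hmul(a.w, b.w);
    return r;
}

static __device__ __forceinline__ half4 operator*(const half4 a, const half s) {
    half4 r;
    r.x = __hmul(a.x, s); r.y = __hmul(a.y, s);
    r.z = __hmul(a.z, s); r.w = __hmul(a.w, s);
    return r;
}

static __device__ __forceinline__ half4 operator+(const half4 a, const half4 b) {
    half4 r;
    r.x = __hadd(a.x, b.x); r.y = __hadd(a.y, b.y);
    r.z = __hadd(a.z, b.z); r.w = __hadd(a.w, b.w);
    return r;
}

static __device__ __forceinline__ half4 & operator+=(half4 & a, const half4 b) {
    a = a + b;
    return a;
}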
// based on metal version
template<int D, int R> // head size, rows per block
static __global__ void flash_attn_ext_f16(
        const char * __restrict__ q,
        const char * __restrict__ k,
        const char * __restrict__ v,
        const char * __restrict__ mask,
        float      * __restrict__ kqv,
        float scale,
        int ne00, int ne01, int ne02, int ne03,
        int ne10, int ne11, int ne12, int ne13,
        int ne31, int nb31,
        int nb01, int nb02, int nb03,
        int nb11, int nb12, int nb13,
        int ne0,  int ne1,  int ne2,  int ne3) {
    int warp_id = threadIdx.x / WARP_SIZE;
    int lane_id = threadIdx.x % WARP_SIZE;

    const int nwraps = blockDim.y;     // number of warps
    const int tph    = WARP_SIZE / R;  // threads per head
    const int iq3    = blockIdx.z;
    const int iq2    = blockIdx.y * R + lane_id / tph;
    const int iq1    = blockIdx.x;

    if (iq2 >= ne02) {
        return;
    }

    // broadcast
    const int rk2 = ne02 / ne12;
    const int rk3 = ne03 / ne13;
    // assume the same K and V shape
    const int rv2 = ne02 / ne12;
    const int rv3 = ne03 / ne13;

    // kv indices
    const int ik2 = iq2 / rk2;
    const int ik3 = iq3 / rk3;
    const int iv2 = iq2 / rv2;
    const int iv3 = iq3 / rv3;

    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

    const float * mp = mask ? (const float *) (mask + (ir % ne31)*nb31) : nullptr;

    extern __shared__ char shmem__[];

    half4 * pq4 = (half4 *)  shmem__;
    half4 * ps4 = (half4 *) (shmem__ + warp_id * (R * D + 32) + 1*R*D);
    half  * ss  = (half  *) (shmem__ + warp_id * (R * D + 32) + 2*R*D);

    const int tiih  = lane_id % tph; // thread index in head
    const int hiisg = lane_id / tph; // head index in warp

    const int D4 = D/4;

    // load R heads from Q to shared memory
    for (int64_t i = 0; i < D4/tph; ++i) {
        if (warp_id == 0) {
            pq4[hiisg*D4 + tph*i + tiih] = ((const half4 *) ((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
        }

        ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
    }
    __syncthreads();

    half S = 0.0h;
    half M = -INFINITY;

    for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
        const half mv = mp ? mp[ic] : 0.0h;
        if (mv == -INFINITY) {
            continue;
        }

        const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
        const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K

        half4 s4 = 0.0h;

#pragma unroll
        for (int i = 0; i < D4/tph; ++i) {
            s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
        }

        ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);

        __syncthreads();

        if (tiih == 0) {
            // [review comment attached to this line: "This is bad in terms of performance. Not only is …"]
            half s = 0.0h;

#pragma unroll
            for (int i = 0; i < tph; ++i) {
                s += ss[hiisg*tph + i];
            }

            s = s*scale + mv;

            const half m = M;

            M = max(M, s);

            const half ms = exp(m - M);
            const half vs = exp(s - M);

            S = S*ms + vs;

            ss[2*hiisg + 0] = ms;
            ss[2*hiisg + 1] = vs;
        }

        __syncthreads();

        const half ms = ss[2*hiisg + 0];
        const half vs = ss[2*hiisg + 1];

#pragma unroll
        for (int i = 0; i < D4/tph; ++i) {
            ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
        }
    }

    if (tiih == 0) {
        ss[2*hiisg + 0] = S;
        ss[2*hiisg + 1] = M;
    }

    __syncthreads();

    // reduce the warps
    if (warp_id == 0) {
        for (int sg = 1; sg < nwraps; ++sg) {
            const half S0 = ss[                2*hiisg + 0];
            const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];

            const half M0 = ss[                2*hiisg + 1];
            const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];

            M = max(M0, M1);

            const half ms0 = exp(M0 - M);
            const half ms1 = exp(M1 - M);

            S = S0*ms0 + S1*ms1;

            if (tiih == 0) {
                ss[2*hiisg + 0] = S;
                ss[2*hiisg + 1] = M;
            }

            for (int i = 0; i < D4/tph; ++i) {
                ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
            }
        }

        for (int i = 0; i < D4/tph; ++i) {
            ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
        }
    }

    __syncthreads();

    // dst indices
    const int i1 = iq1;
    const int i2 = iq2;
    const int i3 = iq3;
    float4 * dst4 = (float4 *) kqv;

    if (warp_id == 0) {
        for (int i = 0; i < D4/tph; ++i) {
            float4 dst_ = dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih];
            half4  src_ = ps4 [hiisg*D4 + tph*i + tiih];

            dst_.x = __half2float(src_.x);
            dst_.y = __half2float(src_.y);
            dst_.z = __half2float(src_.z);
            dst_.w = __half2float(src_.w);

            // store the converted output back to global memory
            dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = dst_;
        }
    }
}

template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                          const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
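Stepping back from the indexing, the KV loop in flash_attn_ext_f16 above maintains a running maximum M, a running normalizer S, and an unnormalized output accumulator (ps4), in the style of the online-softmax / flash-attention recurrence. In formulas, when a new masked and scaled score s with value row v arrives, the state (M, S, o) is updated as:

\begin{aligned}
M' &= \max(M, s), \\
S' &= S\, e^{M - M'} + e^{s - M'}, \\
o' &= o\, e^{M - M'} + e^{s - M'}\, v.
\end{aligned}

After the last KV position the head's output is o'/S', which equals the usual softmax-weighted sum of the value rows. The same rescaling is reused when warp 0 merges the per-warp (S, M, o) partials at the end of the kernel.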
@@ -10071,6 +10277,98 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
                    KQ_scale, d_head, sequence_length, num_heads, main_stream);
}

inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
    GGML_ASSERT(Q->type    == GGML_TYPE_F16);
    GGML_ASSERT(K->type    == GGML_TYPE_F16);
    GGML_ASSERT(V->type    == GGML_TYPE_F16);
    GGML_ASSERT(mask->type == GGML_TYPE_F32);
    GGML_ASSERT(KQV->type  == GGML_TYPE_F32);

    GGML_ASSERT(Q->backend    == GGML_BACKEND_GPU);
    GGML_ASSERT(K->backend    == GGML_BACKEND_GPU);
    GGML_ASSERT(V->backend    == GGML_BACKEND_GPU);
    GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
    GGML_ASSERT(KQV->backend  == GGML_BACKEND_GPU);

    ggml_cuda_set_device(g_main_device);
    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
    ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;

    float scale;
    memcpy(&scale, KQV->op_params, sizeof(float));

    const int nwarps = 32;
    const int nhpw   = 2; // heads per warp

    dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]);
    dim3 block_dim(32, nwarps, 1);

    int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2);

    switch (Q->ne[0]) {
        case 64:
            flash_attn_ext_f16<64, 2>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    (const char *) src3_extra->data_device[g_main_device], // Mask
                    (float *) dst_extra->data_device[g_main_device],       // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask->ne[1], mask->nb[1],
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]);
            break;
        case 80:
            flash_attn_ext_f16<80, 2>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    (const char *) src3_extra->data_device[g_main_device], // Mask
                    (float *) dst_extra->data_device[g_main_device],       // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask->ne[1], mask->nb[1],
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]);
            break;
        case 128:
            flash_attn_ext_f16<128, 2>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                    (const char *) src0_extra->data_device[g_main_device], // Query
                    (const char *) src1_extra->data_device[g_main_device], // Key
                    (const char *) src2_extra->data_device[g_main_device], // Value
                    (const char *) src3_extra->data_device[g_main_device], // Mask
                    (float *) dst_extra->data_device[g_main_device],       // dst
                    scale,
                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                    mask->ne[1], mask->nb[1],
                    Q->nb[1], Q->nb[2], Q->nb[3],
                    K->nb[1], K->nb[2], K->nb[3],
                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]);
            break;
        default:
            break;
    }
}

static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
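For a rough sense of the launch configuration above: each block runs 32 warps of 32 threads and receives one dynamic shared-memory arena holding the R = nhpw = 2 query heads plus, per warp, an accumulator of the same size and 32 halves of scratch. A small hedged check of the size formula for the three supported head sizes (standalone helper written for illustration, not code from the patch):

#include <cstdio>

// mirrors shmem = (R*D + nwarps*(R*D + 32)) * sizeof(half) from the launcher above
static size_t flash_attn_ext_shmem(int D, int nwarps = 32, int R = 2) {
    const size_t half_size = 2; // sizeof(half) == sizeof(float)/2
    return (R*D + nwarps*(R*D + 32)) * half_size;
}

int main() {
    const int Ds[3] = {64, 80, 128};
    for (int i = 0; i < 3; ++i) {
        // e.g. D = 128: (256 + 32*(256 + 32)) * 2 = 18944 bytes, comfortably under
        // the default 48 KiB dynamic shared-memory limit per block
        printf("D = %3d -> %zu bytes of dynamic shared memory\n", Ds[i], flash_attn_ext_shmem(Ds[i]));
    }
    return 0;
}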
@@ -10341,6 +10639,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
            break;
        case GGML_OP_FLASH_ATTN:
            break;
        case GGML_OP_FLASH_ATTN_EXT:
            break;
        default:
            return false;
    }
@@ -10357,7 +10657,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
    }
    if (tensor->op == GGML_OP_FLASH_ATTN) {
        ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
    } else if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
        ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
    } else {
        func(tensor->src[0], tensor->src[1], tensor);
    }
    return true;
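For orientation, the dispatch above is reached when the compute graph contains a GGML_OP_FLASH_ATTN_EXT node. A hedged sketch of how such a node could be created with the ggml_flash_attn_ext() API added earlier in this PR; the (ctx, q, k, v, mask, scale) signature and the fact that scale is stored in op_params (read back by the launcher above) are assumptions to verify against ggml.h on this branch:

#include "ggml.h"
#include <math.h>

// builds one attention node; q/k/v/mask must already have the shapes, types and
// backend placement that ggml_cuda_flash_attn_ext() asserts at this commit
static struct ggml_tensor * build_flash_attn_node(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        int                   n_embd_head) {
    const float scale = 1.0f/sqrtf((float) n_embd_head);

    // creates a GGML_OP_FLASH_ATTN_EXT tensor; on the CUDA backend it is executed
    // by ggml_cuda_flash_attn_ext() through the branch shown above
    return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, scale);
}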
@@ -11175,6 +11477,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD:
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_FLASH_ATTN_EXT:
            return true;
        default:
            return false;
Review comment: Rewrite to loop with ic = ic0 + warp_id.
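The comment refers to the KV loop in flash_attn_ext_f16, which currently starts each warp at ic = warp_id, strides by the warp count, and uses `continue` to skip fully masked positions even though the loop body contains block-wide __syncthreads() calls. One possible reading of the suggestion, sketched as a fragment that would replace the loop header inside the kernel (an interpretation only, not the code that was eventually merged):

// every warp derives its KV index from a shared block offset ic0, so the whole
// block advances through the KV cache in lockstep and reaches the barriers in
// the body uniformly instead of diverging on `continue`
for (int ic0 = 0; ic0 < ne11; ic0 += nwraps) {
    const int ic = ic0 + warp_id;

    if (ic < ne11) {
        // load the mask value for position ic, compute the Q*K dot product and the
        // online-softmax update exactly as in the existing body; a position with
        // mask == -INFINITY simply contributes nothing instead of being skipped
    }

    // the block-wide synchronization points from the body go here, outside the
    // data-dependent branch
    __syncthreads();
}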