8000 WIP: Flash Attention implementation (forward + backward) by FSSRepo · Pull Request #1 · Pints-AI/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content
8000

WIP: Flash Attention implementation (forward + backward) #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 72 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
f7bcfb0
cuda: add flash attention + test
FSSRepo Jan 17, 2024
e53de28
fix compilation
FSSRepo Jan 18, 2024
a1c004e
ggml : add ggml_flash_attn_ext API
ggerganov Jan 18, 2024
fa7ebcc
ggml : fix GQA support in ggml_flash_attn_ext
ggerganov Jan 19, 2024
09db1a7
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 19, 2024
fded2e6
apply suggestions
FSSRepo Jan 20, 2024
c3cdfff
Merge branch 'master' into gg/flash-attn
ggerganov Jan 20, 2024
a9681fe
ggml : online attention (CPU)
ggerganov Jan 20, 2024
1173f49
metal : initial implementation
ggerganov Jan 20, 2024
528da75
metal : f16 precision
ggerganov Jan 21, 2024
52ae085
metal : reduce branches
ggerganov Jan 21, 2024
b973258
metal : specialize for head size
ggerganov Jan 21, 2024
8cde449
wip : 8 rows per simd group
ggerganov Jan 21, 2024
f31955f
wip : 4 rows per simd group
ggerganov Jan 21, 2024
a4b6341
wip : template for rows per warp
ggerganov Jan 21, 2024
8000 77d08f3
metal : parallelize across KV size
ggerganov Jan 21, 2024
17720fa
metal : parallel reduce across heads
ggerganov Jan 21, 2024
a689b02
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 23, 2024
6374bc5
cuda: port metal version flash_attn_ext
FSSRepo Jan 23, 2024
6416821
fix equivalent fp16 math functions, compiler error 'undefined'
FSSRepo Jan 24, 2024
972c2ad
use half2 instead half4
FSSRepo Jan 24, 2024
0fc36d8
match to metal impl
FSSRepo Jan 24, 2024
1446a12
metal : efficient flash_attn_f16 implementation
ggerganov Jan 23, 2024
d917746
metal : avoid redundant loads of the attention
ggerganov Jan 25, 2024
432ad04
metal : scale and mask in matrix form
ggerganov Jan 25, 2024
40ea8cd
metal : fix comment
ggerganov Jan 25, 2024
78da338
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 25, 2024
f9ca5dc
llama : avoid ggml_cast, use F32 query
ggerganov Jan 25, 2024
6e7cb0e
update implementation
FSSRepo Jan 25, 2024
6fea843
metal : add parallel reduce version (disabled)
ggerganov Jan 25, 2024
0a481fe
integrate tensor cores
FSSRepo Jan 27, 2024
7cea973
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 27, 2024
2455a8d
update impl
FSSRepo Jan 27, 2024
b3dd7d9
Merge branch 'master' into gg/flash-attn
ggerganov Jan 28, 2024
77f6976
metal : move output into local memory + optimize
ggerganov Jan 28, 2024
ecc466a
metal : add tests, fix scaling, support C > 32
ggerganov Jan 28, 2024
3a428a1
metal : improve precision
ggerganov Jan 28, 2024
8612864
ggml : fix f16 mad
ggerganov Jan 28, 2024
0ad44ba
Merge branch 'master' into gg/flash-attn
ggerganov Jan 28, 2024
134c81c
metal : minor
ggerganov Jan 28, 2024
1db22d7
metal : support Q > 8
ggerganov Jan 28, 2024
4794821
tests : add ATTN tests
ggerganov Jan 29, 2024
abeaf0d
metal : disable buffer allocation logs
ggerganov Jan 29, 2024
c6c1132
tests : more
ggerganov Jan 29, 2024
5fcb9c1
metal : faster inner loop for C == 32
ggerganov Jan 29, 2024
a1d5a12
fix compiler error
FSSRepo Jan 29, 2024
7980178
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 29, 2024
d073e4f
metal : fix array initialization
ggerganov Jan 30, 2024
78df552
tests : ifdef
ggerganov Jan 30, 2024
3d03bcb
Merge branch 'master' into gg/flash-attn
ggerganov Jan 30, 2024
3b0f74b
latest kernel update, wrong values
FSSRepo Jan 30, 2024
2ddc9bb
Merge branch 'master' into gg/flash-attn
ggerganov Jan 31, 2024
b1479df
fix kernel
FSSRepo Jan 31, 2024
8ad92dc
ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
ggerganov Jan 31, 2024
0afe47f
fix naive implementation
FSSRepo Jan 31, 2024
3df0b8d
Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cp…
FSSRepo Jan 31, 2024
fd878f7
cuda: mask as fp16
FSSRepo Jan 31, 2024
71b69aa
cuda : fix flash_attn kernel to produce same results as CPU 8000
ggerganov Feb 1, 2024
2c04bee
cuda : avoid extra QxQ matrix in shared memory
ggerganov Feb 1, 2024
9a5c2a1
cuda : switch to F16 scalars + tune warps for RTX 2060
ggerganov Feb 1, 2024
ac26f27
cuda : increase C to 128 for better performance
ggerganov Feb 1, 2024
43f7156
Merge pull request #3 from ggerganov/flash-attn-cuda
FSSRepo Feb 1, 2024
9240a84
fix mask nullptr
FSSRepo Feb 1, 2024
8d7a606
don't require LLAMA_CUDA_F16 to compile
FSSRepo Feb 1, 2024
19e0b8e
#ifdef -> #if + fix check -inf
FSSRepo Feb 1, 2024
cae985c
cmake: remove unused changes
FSSRepo Feb 1, 2024
53621e3
refactor flash_attn function + improve tests
FSSRepo Feb 1, 2024
674d5ac
unroll 2 loops, int64_t -> int, 309 µs
JohannesGaessler Feb 3, 2024
8b51ab4
Merge pull request #4 from Pints-App/jg/flash-attn-cuda
FSSRepo Feb 3, 2024
a1f9ffe
bring optimizations from gg/flash-attn
FSSRepo Feb 3, 2024
ba7699d
Merge branch 'flash-attn-cuda' of https://github.com/Pints-App/llama.…
FSSRepo Feb 3, 2024
f659f57
fix merge conflicts
FSSRepo Feb 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
apply suggestions
  • Loading branch information
FSSRepo committed Jan 20, 2024
commit fded2e6a11bb600e04fec8714ab9165bda7724f8
91 changes: 63 additions & 28 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5989,38 +5989,55 @@ static __global__ void im2col_f32_f16(

#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256

template<int block_size>
static __global__ void flash_attn_f32(const float* q, const float* k,const float* v, float* dst, float kq_scale,
int d_head, int seq_len, int num_heads) {
template<int block_size, int k_seq_len>
static __global__ void flash_attn_f32(
const float* __restrict__ q,
const float* __restrict__ k,
const float* __restrict__ v,
float* __restrict__ kqv,
float kq_scale,
int head_dim, int seq_len, int num_heads) {
const int head = blockIdx.x / seq_len;
const int head_size = d_head * seq_len;
const int head_size = head_dim * seq_len;
const int s = blockIdx.x % seq_len;
const int tid = threadIdx.x;

extern __shared__ char work_data[];
float* S = (float*)work_data; // theorical sequent length: 12848, due memory per block limit
float* warp_data = (float*)(work_data + seq_len * sizeof(float));
extern __shared__ char shmem__[];
float* S = (float*)shmem__;
float* warp_data = (float*)(shmem__ + seq_len * sizeof(float));

// QK^T
for(int is = tid; is < seq_len; is += block_size) {
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}

const int key_offset = is * head_dim + head * head_size;
const int query_offset = s * head_dim + head * head_size;

S[is] = 0.0f;
int key_offset = is * d_head + head * head_size;
int query_offset = s * d_head + head * head_size;
for(int d = 0; d < d_head; d++) {
for(int d = 0; d < head_dim; d++) {
S[is] += k[key_offset + d] * q[query_offset + d];
}
S[is] *= kq_scale;
}

__syncthreads();

float max_val = -INFINITY;
// get the max
for(int is = tid; is < seq_len; is += block_size) {
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}

max_val = fmaxf(max_val , S[is]);
}

max_val = warp_reduce_max(max_val);

{ // get max from all threads
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
Expand All @@ -6034,14 +6051,20 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float

// softmax(QK^T)
float sum = 0.0f;
for(int is = tid; is < seq_len;is += block_size) {
const float val = expf(S[is] - max_val);
S[is] = val;
sum += val;
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}

S[is] = expf(S[is] - max_val);
sum += S[is];
}
__syncthreads();

sum = warp_reduce_sum(sum);
{ // sum partials
{ // softmax sum partials
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
Expand All @@ -6053,19 +6076,31 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float
}

float inv_sum = 1.0f / sum;
for(int is = tid; is < seq_len; is += block_size) {
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}

S[is] *= inv_sum;
}

__syncthreads();

// softmax(QK^T)V
for (int d = tid; d < d_head; d += block_size) {
int dst_index = d + s * d_head + head * head_size;
int value_offset = d * seq_len + head * head_size;
dst[dst_index] = 0.0f;
for(int ic = 0; ic < seq_len; ic++) {
dst[dst_index] += v[value_offset + ic] * S[ic];
for (int d = threadIdx.x; d < head_dim; d += block_size) {
const int dst_index = d + s * head_dim + head * head_size;
const int value_offset = d * seq_len + head * head_size;

float temp = 0.0f;
#pragma unroll
for(int ic = 0; ic < k_seq_len;ic++) {
if(ic >= seq_len) {
break;
}
temp += v[value_offset + ic] * S[ic];
}
kqv[dst_index] = temp;
}
}

Expand Down Expand Up @@ -7462,7 +7497,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
int num_blocks = num_heads * seq_len;
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
}

Expand Down
26 changes: 20 additions & 6 deletions tests/test-flash-attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ struct test_model {
struct ggml_tensor * k;
struct ggml_tensor * v;
ggml_backend_t backend = NULL;
ggml_backend_buffer_t buffer;
struct ggml_context * ctx;
ggml_backend_buffer_t buffer = NULL;
struct ggml_context * ctx = NULL;
bool naive_attn = false;
};

static std::vector<float> tensor_to_float(const ggml_tensor * t) {
Expand Down Expand Up @@ -216,8 +217,16 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a

struct ggml_cgraph * gf = ggml_new_graph(ctx0);

struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
ggml_build_forward_expand(gf, result);
if(!model.naive_attn) {
struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
ggml_build_forward_expand(gf, result);
} else {
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
kq = ggml_soft_max(ctx0, kq);
kq = ggml_mul_mat(ctx0, model.v, kq);
ggml_build_forward_expand(gf, kq);
}

// delete the temporally context used to build the graph
ggml_free(ctx0);
Expand Down Expand Up @@ -330,15 +339,18 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe
int main(int argc, char ** argv)
{
bool compare_backend = false;
test_model model;
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "comp") == 0) {
compare_backend = true;
} else if (strcmp(argv[i], "naive") == 0) {
model.naive_attn = true;
}
}

ggml_time_init();

test_model model;

load_model(model, true);

ggml_backend_buffer_t buf_compute; // for compute
Expand All @@ -359,9 +371,11 @@ int main(int argc, char ** argv)
}

ggml_backend_t backend_cpu = ggml_backend_cpu_init();

uint64_t compute_time_us__ = ggml_time_us();
struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend);
if(!compare_backend) {
ggml_backend_synchronize(model.backend);
printf("computing time: %.4f ms\n", (ggml_time_us() - compute_time_us__) / 1000.0);
float* data = new float[ggml_nelements(result)];

ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result));
Expand Down
0