From f7bcfb0566a9703f2cfdb5ef6825f4cfb996e6f3 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 17 Jan 2024 16:38:28 -0500 Subject: [PATCH 01/58] cuda: add flash attention + test --- ggml-cuda.cu | 142 +++++++++++- tests/CMakeLists.txt | 2 + tests/test-flash-attention.cpp | 383 +++++++++++++++++++++++++++++++++ 3 files changed, 526 insertions(+), 1 deletion(-) create mode 100644 tests/test-flash-attention.cpp diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 568c411afd3ee..bb65ca642f2e7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5987,6 +5987,88 @@ static __global__ void im2col_f32_f16( } } +#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 + +template +static __global__ void flash_attn_f32(const float* q, const float* k,const float* v, float* dst, float kq_scale, + int d_head, int seq_len, int num_heads) { + const int head = blockIdx.x / seq_len; + const int head_size = d_head * seq_len; + const int s = blockIdx.x % seq_len; + const int tid = threadIdx.x; + + extern __shared__ char work_data[]; + float* S = (float*)work_data; // theorical sequent length: 12848, due memory per block limit + float* warp_data = (float*)(work_data + seq_len * sizeof(float)); + + // QK^T + for(int is = tid; is < seq_len; is += block_size) { + S[is] = 0.0f; + int key_offset = is * d_head + head * head_size; + int query_offset = s * d_head + head * head_size; + for(int d = 0; d < d_head; d++) { + S[is] += k[key_offset + d] * q[query_offset + d]; + } + S[is] *= kq_scale; + } + + __syncthreads(); + + float max_val = -INFINITY; + // get the max + for(int is = tid; is < seq_len; is += block_size) { + max_val = fmaxf(max_val , S[is]); + } + + max_val = warp_reduce_max(max_val); + { // get max from all threads + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = max_val; + } + __syncthreads(); + max_val = warp_data[lane_id]; + max_val = warp_reduce_max(max_val); + } + + // softmax(QK^T) + float sum = 0.0f; + for(int is = tid; is < seq_len;is += block_size) { + const float val = expf(S[is] - max_val); + S[is] = val; + sum += val; + } + + sum = warp_reduce_sum(sum); + { // sum partials + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = sum; + } + __syncthreads(); + sum = warp_data[lane_id]; + sum = warp_reduce_sum(sum); + } + + float inv_sum = 1.0f / sum; + for(int is = tid; is < seq_len; is += block_size) { + S[is] *= inv_sum; + } + + __syncthreads(); + // softmax(QK^T)V + for (int d = tid; d < d_head; d += block_size) { + int dst_index = d + s * d_head + head * head_size; + int value_offset = d * seq_len + head * head_size; + dst[dst_index] = 0.0f; + for(int ic = 0; ic < seq_len; ic++) { + dst[dst_index] += v[value_offset + ic] * S[ic]; + } + } +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -7377,6 +7459,13 @@ static void im2col_f32_f16_cuda(const float* x, half* dst, im2col_f32_f16<<>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } +static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { + int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); + int num_blocks = num_heads * seq_len; + flash_attn_f32<<>>( + q, k, v, dst, 
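        // Back-of-the-envelope budget for the launch above, assuming the common 48 KB
        // per-block shared-memory limit (a sketch, not a measurement):
        //   grid : num_blocks = num_heads * seq_len          -> one block of
        //          CUDA_FLASH_ATTENTION_BLOCK_SIZE (256) threads per (head, query position)
        //   smem : seq_len  * sizeof(float)                  -> the score row S[seq_len]
        //        + WARP_SIZE * sizeof(float)                 -> scratch for the warp max/sum reductions
        // e.g. seq_len = 4096 gives 4096*4 + 32*4 = 16512 bytes; the 48 KB limit is what caps the
        // supported sequence length at roughly 12K positions, as the kernel comment above notes.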
kq_scale, d_head, seq_len, num_heads); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -9900,6 +9989,51 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F32); + GGML_ASSERT(V->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + const int64_t d_head = Q->ne[0]; + const int64_t sequence_length = Q->ne[1]; + const int64_t num_heads = Q->ne[2]; + + GGML_ASSERT(Q->ne[0] == d_head); + GGML_ASSERT(K->ne[0] == d_head); + GGML_ASSERT(V->ne[1] == d_head); + + GGML_ASSERT(Q->ne[1] == sequence_length); + GGML_ASSERT(K->ne[1] == sequence_length); + GGML_ASSERT(V->ne[0] == sequence_length); + + GGML_ASSERT(Q->ne[2] == num_heads); + GGML_ASSERT(K->ne[2] == num_heads); + GGML_ASSERT(V->ne[2] == num_heads); + + float KQ_scale = 1.0f / sqrtf((float)d_head); + + flash_attn_f32_cuda( + (float *) src0_extra->data_device[g_main_device], // Query + (float *) src1_extra->data_device[g_main_device], // Key + (float *) src2_extra->data_device[g_main_device], // Value + (float *) dst_extra->data_device[g_main_device], // dst + KQ_scale, d_head, sequence_length, num_heads, main_stream); +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -10168,6 +10302,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_ARGSORT: func = ggml_cuda_argsort; break; + case GGML_OP_FLASH_ATTN: + break; default: return false; } @@ -10182,7 +10318,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return true; } - func(tensor->src[0], tensor->src[1], tensor); + if(tensor->op == GGML_OP_FLASH_ATTN) { + ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } else { + func(tensor->src[0], tensor->src[1], tensor); + } return true; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7c932240de82d..bc5649989c4d9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -52,6 +52,8 @@ llama_build_and_test_executable(test-backend-ops.cpp) llama_build_and_test_executable(test-rope.cpp) +llama_build_executable(test-flash-attention.cpp) + # dummy executable - not installed get_filename_component(TEST_TARGET test-c.c NAME_WE) add_executable(${TEST_TARGET} test-c.c) diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp new file mode 100644 index 0000000000000..c99ad719d6ca1 --- /dev/null +++ b/tests/test-flash-attention.cpp @@ -0,0 +1,383 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#define GGML_USE_CUBLAS + +#ifdef 
GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +struct test_model { + struct ggml_tensor * q; + struct ggml_tensor * k; + struct ggml_tensor * v; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + +static std::vector tensor_to_float(const ggml_tensor * t) { + std::vector tv; + tv.reserve(ggml_nelements(t)); + + std::vector buf(ggml_nbytes(t)); + ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); + + ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type); + size_t bs = ggml_blck_size(t->type); + std::vector vq(ggml_blck_size(t->type)); + bool quantized = ggml_is_quantized(t->type); + + // access elements by index to avoid gaps in views + for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { + for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; + if (t->type == GGML_TYPE_F16) { + tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); + } else if (t->type == GGML_TYPE_F32) { + tv.push_back(*(float *) &buf[i]); + } else if (t->type == GGML_TYPE_I32) { + tv.push_back((float)*(int32_t *) &buf[i]); + } else if (t->type == GGML_TYPE_I16) { + tv.push_back((float)*(int16_t *) &buf[i]); + } else if (t->type == GGML_TYPE_I8) { + tv.push_back((float)*(int8_t *) &buf[i]); + } else if (quantized) { + std::vector vq(ggml_blck_size(t->type)); + tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type)); + tv.insert(tv.end(), vq.begin(), vq.end()); + } else { + GGML_ASSERT(false); + } + } + } + } + } + + return tv; +} + +// accept FLT_MAX as infinity +static bool isinf_or_max(float f) { + return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX; +} + +// normalized mean squared error = mse(a, b) / mse(a, 0) +static double nmse(const float * a, const float * b, size_t n) { + double mse_a_b = 0.0; + double mse_a_0 = 0.0; + + for (size_t i = 0; i < n; i++) { + float a_i = a[i]; + float b_i = b[i]; + + mse_a_b += (a_i - b_i) * (a_i - b_i); + mse_a_0 += a_i * a_i; + } + + return mse_a_b / mse_a_0; +} + +void ggml_tensor_set_f32(struct ggml_tensor* tensor, float value, int l, int k = 0, int j = 0, int i = 0) { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]) = value; +} + +float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); +} + +void load_model(test_model & model, bool use_gpu = false) { + float Query[30] = { // [3, 4, 2] + // z0 + 2, 4, 2, + 4, 2, 1, + 4, 1, 3, + 4, 2, 2, + + // z1 + 2, 1, 1, + 4, 2, 1, + 1, 1, 3, + 4, 2, 1 + }; + + float Key[24] = { // [3, 4, 2] + // z0 + 2, 4, 2, + 4, 2, 1, + 4, 2, 3, + 1, 2, 1, + + // z1 + 3, 1, 3, + 4, 2, 1, + 1, 1, 2, + 4, 3, 1 + }; + + float Value[24] = { // [4, 3, 2] + // z0 + 2, 4, 2, 1, + 2, 1, 4, 2, + 1, 4, 2, 3, + + // z1 + 1, 4, 2, 1, + 2, 1, 1, 2, + 1, 4, 3, 3, + }; + + size_t buffer_size = 0; + { + buffer_size += 30 * ggml_type_sizef(GGML_TYPE_F32); // tensor q + buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor k + buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor v + buffer_size += 1024; + } + + printf("%s: ggml tensor size = %d 
bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 3; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUBLAS + if (use_gpu) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2); + model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2); + model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 4, 3, 2); + + // create a allocator + ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); + + // alloc memory + ggml_allocr_alloc(alloc, model.q); + ggml_allocr_alloc(alloc, model.k); + ggml_allocr_alloc(alloc, model.v); + + ggml_backend_tensor_set(model.q, Query, 0, ggml_nbytes(model.q)); + ggml_backend_tensor_set(model.k, Key, 0, ggml_nbytes(model.k)); + ggml_backend_tensor_set(model.v, Value, 0, ggml_nbytes(model.v)); + + ggml_allocr_free(alloc); +} + +struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false); + ggml_build_forward_expand(gf, result); + + // delete the temporally context used to build the graph + ggml_free(ctx0); + return gf; +} + +struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backend_cpu, struct ggml_allocr * allocr, bool compare_backends) { + // reset the allocator to free all the memory allocated during the previous inference + ggml_allocr_reset(allocr); + + struct ggml_cgraph * gf = build_graph(model, allocr); + + // allocate tensors + ggml_allocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + + if(!compare_backends) { + ggml_backend_graph_compute(model.backend, gf); + + // in this case, the output tensor is the last one in the graph + return gf->nodes[gf->n_nodes - 1]; + } + + struct callback_userdata { + bool ok; + double max_err; + ggml_backend_t backend1; + ggml_backend_t backend2; + }; + + callback_userdata ud { + true, + 1e-7, + model.backend, + backend_cpu + }; + + auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { + callback_userdata * ud = (callback_userdata *) user_data; + const char * bn1 = ggml_backend_name(ud->backend1); + const char * bn2 = ggml_backend_name(ud->backend2); + + if (t1->op == GGML_OP_NONE) { + // sentinels must be unchanged + 
std::vector t1_data(ggml_nbytes(t1)); + std::vector t2_data(ggml_nbytes(t2)); + ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1)); + ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2)); + + if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) { + printf("sentinel mismatch: %s ", t1->name); + ud->ok = false; + return true; + } + } + + std::vector f1 = tensor_to_float(t1); + std::vector f2 = tensor_to_float(t2); + + for (size_t i = 0; i < f1.size(); i++) { + // check for nans + if (std::isnan(f1[i]) || std::isnan(f2[i])) { + printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]); + ud->ok = false; + return true; + } + // check for infs: both must be inf of the same sign, or both must be finite + if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { + if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { + if (std::signbit(f1[i]) != std::signbit(f2[i])) { + printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]); + ud->ok = false; + return true; + } + } else { + printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]); + ud->ok = false; + return true; + } + } + } + + double err = nmse(f1.data(), f2.data(), f1.size()); + if (err > ud->max_err) { + printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + ud->ok = false; + } + + return true; + + GGML_UNUSED(index); + }; + + printf("\nTesting Flash Attention - comparing backends: "); + + const bool cmp_ok = ggml_backend_compare_graph_backend(model.backend, backend_cpu, gf, callback, &ud); + if (ud.ok && cmp_ok) { + printf("\033[1;32mOK\033[0m\n"); + return NULL; + } + + printf("\033[1;31mFAIL\033[0m\n"); + return NULL; +} + +int main(int argc, char ** argv) +{ + bool compare_backend = false; + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "comp") == 0) { + compare_backend = true; + } + } + + ggml_time_init(); + + test_model model; + load_model(model, true); + + ggml_backend_buffer_t buf_compute; // for compute + struct ggml_allocr * allocr = NULL; + + { + allocr = ggml_allocr_new_measure_from_backend(model.backend); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph(model, allocr); + size_t mem_size = ggml_allocr_alloc_graph(allocr, gf); + ggml_allocr_free(allocr); + + // compute the required memory + buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size); + allocr = ggml_allocr_new_from_buffer(buf_compute); + fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + } + + ggml_backend_t backend_cpu = ggml_backend_cpu_init(); + + struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend); + if(!compare_backend) { + float* data = new float[ggml_nelements(result)]; + + ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result)); + printf("\nPerforming test:\n"); + + for(int i = 0; i < ggml_nelements(result); i ++) { + if(i > 0 && (i % result->ne[0] == 0)) { + printf("\n"); + } + printf("%2.6f ", data[i]); + } + } + + ggml_free(model.ctx); + + ggml_backend_buffer_free(model.buffer); + ggml_backend_buffer_free(buf_compute); + ggml_backend_free(model.backend); + return 0; +} From e53de2866ad973a12d5b60813eaf117791a20904 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 18 Jan 2024 11:27:07 -0500 Subject: [PATCH 02/58] fix compilation --- tests/test-flash-attention.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-flash-attention.cpp 
b/tests/test-flash-attention.cpp index c99ad719d6ca1..fb5e2a8bc6f98 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -16,6 +16,7 @@ #include #include #include +#include struct test_model { struct ggml_tensor * q; From a1c004ef2e056cdeffcd47aaac196883bb123a3a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 17:42:55 +0200 Subject: [PATCH 03/58] ggml : add ggml_flash_attn_ext API --- ggml-metal.m | 50 +++++++ ggml-metal.metal | 29 ++++ ggml.c | 298 ++++++++++++++++++++++++++++++++++++- ggml.h | 9 ++ llama.cpp | 80 +++++----- tests/test-backend-ops.cpp | 28 ++++ 6 files changed, 456 insertions(+), 38 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 912ddc83f7d9c..6d88d5c36a8ad 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,6 +147,7 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -511,6 +512,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_PAD: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + size_t offs_src2 = 0; + size_t offs_src3 = 0; + + id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + id id_src3 = src3 ? 
ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + + // TODO: extend if necessary + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml-metal.metal b/ggml-metal.metal index 029578dc54dbd..b79a1ba5634a7 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * slope; } +kernel void kernel_flash_attn_ext_f16( + device const half * q, + device const half * k, + device const half * v, + device const half * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & scale, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // TODO: implement +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml.c b/ggml.c index cbf2d4bddddb8..e01d938ceb681 100644 --- a/ggml.c +++ b/ggml.c @@ -1650,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1674,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1736,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1760,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5678,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? 
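    // What the new op computes, in the same spirit as ggml_flash_attn (a sketch; the exact
    // broadcasting rules follow from the CPU kernel added below):
    //   dst = softmax(scale * (Q K^T) + mask) * V      per head
    // i.e. the scale is applied to the raw QK^T scores and the mask is added before the softmax,
    // so callers pass the 1/sqrt(d_head) factor here instead of scaling q or kq themselves.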
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13212,6 +13254,251 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (mask) { + const float * mp = (float *)((char *) mask->data + 
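            // Mask handling (sketch): each q row ir reads mask row (ir % mask->ne[1]), i.e. the
            // [n_kv, n_batch] mask is broadcast across heads, and is simply added to the scaled
            // scores below -- a -INF entry drives exp(S - max) to 0, masking that position out.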
(ir%mask->ne[1])*mask->nb[1]); + ggml_vec_acc_f32(M, S, mp); + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -14717,6 +15004,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); @@ -15713,6 +16004,7 @@ static void ggml_compute_backward(struct 
ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -16438,6 +16730,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -16769,6 +17062,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); diff --git a/ggml.h b/ggml.h index de8162b8135f3..d76fe9d5c48c9 100644 --- a/ggml.h +++ b/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1619,6 +1620,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index d28382f7d47b7..cec23c23f1dce 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4205,38 +4205,6 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4246,8 +4214,49 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + // TODO: determine if we can use flash attention + const bool supports_flash_attn = true; + + struct ggml_tensor * kqv; + + if (supports_flash_attn) { + kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add 
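            // For reference, the non-flash path builds attention from separate ops:
            //   kq  = soft_max(kq_scale * (K^T Q) + mask)    (mul_mat + scale/add/soft_max below,
            //                                                 plus the ALiBi bias in this branch)
            //   kqv = V * kq                                 (second mul_mat further down)
            // which materializes the full [n_kv, n_tokens] score matrix per head; the
            // ggml_flash_attn_ext call above fuses the same computation into a single node.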
+ kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + } struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); @@ -9490,8 +9499,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d902c..5693c2197c7c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const ggml_type typeq; + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nt; // tokens + + std::string vars() override { + return VARS_TO_STR5(typeq, hs, nh, kv, nt); + } + + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); From fa7ebcca993ec0d47f6ed6a47a8d5ac4f7407262 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 20:06:26 +0200 Subject: [PATCH 04/58] ggml : fix GQA support in ggml_flash_attn_ext --- ggml-metal.metal | 8 ++++---- ggml.c | 23 +++++++++++++++-------- llama.cpp | 4 ++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b79a1ba5634a7..28847794cb5d8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const half * mask, + device 
const half * q, + device const half * k, + device const half * v, + device const float * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, diff --git a/ggml.c b/ggml.c index e01d938ceb681..9cf4784ce4759 100644 --- a/ggml.c +++ b/ggml.c @@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + if (params->type == GGML_TASK_INIT) { return; } @@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { for (int64_t ic = 0; ic < nek1; ++ic) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } else { for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16(nev0, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), @@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16_unroll(nev0, nbv1, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), diff --git a/llama.cpp b/llama.cpp index cec23c23f1dce..d4bebe5203e9a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv; if (supports_flash_attn) { + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); From fded2e6a11bb600e04fec8714ab9165bda7724f8 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 19 Jan 2024 20:18:18 -0500 Subject: [PATCH 05/58] apply suggestions --- ggml-cuda.cu | 91 +++++++++++++++++++++++----------- tests/test-flash-attention.cpp | 26 +++++++--- 2 files changed, 83 insertions(+), 34 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bafb2ff1c0ae1..aeb07c9641579 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5989,38 +5989,55 @@ static __global__ void im2col_f32_f16( #define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 -template -static __global__ void flash_attn_f32(const float* q, const float* k,const float* v, float* dst, float kq_scale, - int d_head, int seq_len, int num_heads) { +template +static __global__ void flash_attn_f32( + const float* __restrict__ q, + const float* __restrict__ k, + const float* __restrict__ v, + float* __restrict__ kqv, + float kq_scale, + int head_dim, int seq_len, int num_heads) { const int head = 
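    // Block-to-work mapping (same scheme as the first version): blockIdx.x = head * seq_len + s,
    // so e.g. with num_heads = 32 and seq_len = 128, blockIdx.x = 300 handles head 2, position s = 44.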
blockIdx.x / seq_len; - const int head_size = d_head * seq_len; + const int head_size = head_dim * seq_len; const int s = blockIdx.x % seq_len; - const int tid = threadIdx.x; - extern __shared__ char work_data[]; - float* S = (float*)work_data; // theorical sequent length: 12848, due memory per block limit - float* warp_data = (float*)(work_data + seq_len * sizeof(float)); + extern __shared__ char shmem__[]; + float* S = (float*)shmem__; + float* warp_data = (float*)(shmem__ + seq_len * sizeof(float)); // QK^T - for(int is = tid; is < seq_len; is += block_size) { + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + const int key_offset = is * head_dim + head * head_size; + const int query_offset = s * head_dim + head * head_size; + S[is] = 0.0f; - int key_offset = is * d_head + head * head_size; - int query_offset = s * d_head + head * head_size; - for(int d = 0; d < d_head; d++) { + for(int d = 0; d < head_dim; d++) { S[is] += k[key_offset + d] * q[query_offset + d]; } S[is] *= kq_scale; } - __syncthreads(); float max_val = -INFINITY; // get the max - for(int is = tid; is < seq_len; is += block_size) { + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + max_val = fmaxf(max_val , S[is]); } max_val = warp_reduce_max(max_val); + { // get max from all threads int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; @@ -6034,14 +6051,20 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float // softmax(QK^T) float sum = 0.0f; - for(int is = tid; is < seq_len;is += block_size) { - const float val = expf(S[is] - max_val); - S[is] = val; - sum += val; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + S[is] = expf(S[is] - max_val); + sum += S[is]; } + __syncthreads(); sum = warp_reduce_sum(sum); - { // sum partials + { // softmax sum partials int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { @@ -6053,19 +6076,31 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float } float inv_sum = 1.0f / sum; - for(int is = tid; is < seq_len; is += block_size) { + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + S[is] *= inv_sum; } - __syncthreads(); + // softmax(QK^T)V - for (int d = tid; d < d_head; d += block_size) { - int dst_index = d + s * d_head + head * head_size; - int value_offset = d * seq_len + head * head_size; - dst[dst_index] = 0.0f; - for(int ic = 0; ic < seq_len; ic++) { - dst[dst_index] += v[value_offset + ic] * S[ic]; + for (int d = threadIdx.x; d < head_dim; d += block_size) { + const int dst_index = d + s * head_dim + head * head_size; + const int value_offset = d * seq_len + head * head_size; + + float temp = 0.0f; + #pragma unroll + for(int ic = 0; ic < k_seq_len;ic++) { + if(ic >= seq_len) { + break; + } + temp += v[value_offset + ic] * S[ic]; } + kqv[dst_index] = temp; } } @@ -7462,7 +7497,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst, static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { int sram_memory_size = seq_len*sizeof(float) + 
WARP_SIZE * sizeof(float); int num_blocks = num_heads * seq_len; - flash_attn_f32<<>>( + flash_attn_f32<<>>( q, k, v, dst, kq_scale, d_head, seq_len, num_heads); } diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index fb5e2a8bc6f98..74167ed86fc84 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -23,8 +23,9 @@ struct test_model { struct ggml_tensor * k; struct ggml_tensor * v; ggml_backend_t backend = NULL; - ggml_backend_buffer_t buffer; - struct ggml_context * ctx; + ggml_backend_buffer_t buffer = NULL; + struct ggml_context * ctx = NULL; + bool naive_attn = false; }; static std::vector tensor_to_float(const ggml_tensor * t) { @@ -216,8 +217,16 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a struct ggml_cgraph * gf = ggml_new_graph(ctx0); - struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false); - ggml_build_forward_expand(gf, result); + if(!model.naive_attn) { + struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false); + ggml_build_forward_expand(gf, result); + } else { + struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); + kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0])); + kq = ggml_soft_max(ctx0, kq); + kq = ggml_mul_mat(ctx0, model.v, kq); + ggml_build_forward_expand(gf, kq); + } // delete the temporally context used to build the graph ggml_free(ctx0); @@ -330,15 +339,18 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe int main(int argc, char ** argv) { bool compare_backend = false; + test_model model; for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "comp") == 0) { compare_backend = true; + } else if (strcmp(argv[i], "naive") == 0) { + model.naive_attn = true; } } ggml_time_init(); - test_model model; + load_model(model, true); ggml_backend_buffer_t buf_compute; // for compute @@ -359,9 +371,11 @@ int main(int argc, char ** argv) } ggml_backend_t backend_cpu = ggml_backend_cpu_init(); - + uint64_t compute_time_us__ = ggml_time_us(); struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend); if(!compare_backend) { + ggml_backend_synchronize(model.backend); + printf("computing time: %.4f ms\n", (ggml_time_us() - compute_time_us__) / 1000.0); float* data = new float[ggml_nelements(result)]; ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result)); From a9681febd65cbd3f372badc5f4a4d8bc1336d2d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 12:26:49 +0200 Subject: [PATCH 06/58] ggml : online attention (CPU) --- ggml-metal.m | 8 +- ggml-metal.metal | 3 +- ggml.c | 249 ++++++++++++++++++------------------- ggml.h | 5 + llama.cpp | 124 ++++++++++-------- tests/test-backend-ops.cpp | 14 +-- 6 files changed, 218 insertions(+), 185 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d88d5c36a8ad..4d85dd3ddb319 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + const int nwarps = 4; + + // each warp needs n_embd_head elements + GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + 
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 28847794cb5d8..a1e1755a3a605 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1981,7 +1981,8 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & scale, + constant float & scale, + threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { diff --git a/ggml.c b/ggml.c index 9cf4784ce4759..e64a328fadb1f 100644 --- a/ggml.c +++ b/ggml.c @@ -817,7 +817,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ 
-5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); @@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); + GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); @@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + float S = 0.0f; + float M = -INFINITY; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + memset(V16, 0, D*sizeof(ggml_fp16_t)); - // S indices - const int i1 = ik1; + const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; - // S indices - const int i1 = ik1; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? 
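            // Masking in the online loop (sketch): a mask value of -INF would make
            // exp(s*scale + mv - M) == 0, so that key/value pair contributes nothing and can be
            // skipped outright; any finite mask value is just added to the scaled score below.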
mp[ic] : 0.0f; + if (mv == -INFINITY) { + continue; } - } - // scale - ggml_vec_scale_f32(nek1, S, scale); + float s; - if (mask) { - const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); - ggml_vec_acc_f32(M, S, mp); - } + ggml_vec_dot_f16(D, + &s, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + s = s*scale + mv; - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + const float Mold = M; - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; + float ms = 1.0f; + float vs = 1.0f; - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } + if (s > M) { + M = s; + ms = expf(Mold - M); - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); } - assert(sum > 0.0); + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif + S = S*ms + vs; } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; } - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). 
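        // The streaming ("online") softmax maintained above, per query row, is:
        //   M_new = max(M_old, s)
        //   V    <- V * exp(M_old - M_new) + v * exp(s - M_new)
        //   S    <- S * exp(M_old - M_new) +     exp(s - M_new)
        // so after the last key/value pair, V / S equals softmax(scale*K^T q + mask) * V without
        // the full score row ever being stored (ref: https://arxiv.org/pdf/2112.05682.pdf).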
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); } } @@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: - case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index d76fe9d5c48c9..7bca02f2a2c48 100644 --- a/ggml.h +++ b/ggml.h @@ -1620,6 +1620,11 @@ extern "C" { struct ggml_tensor * v, bool masked); + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch, 1, 1] + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index f0a63afef0087..4e6c9f9cc75ea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -95,6 +95,8 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN + // // logging // @@ -4167,23 +4169,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! 
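For orientation, the shape contract documented above for ggml_flash_attn_ext boils down to a call like the following. This is a hedged sketch rather than code from the patch: the dimension variables (hs, n_batch, n_kv, n_head, n_head_kv) are illustrative, and Q is assumed to already be F16, as in the llama.cpp path and the updated test_flash_attn_ext case further down.

// Hedged usage sketch of ggml_flash_attn_ext following the documented shapes.
// hs, n_batch, n_kv, n_head, n_head_kv are illustrative values.
struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs,   n_batch, n_head,    1);
struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs,   n_kv,    n_head_kv, 1);
struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs,   n_kv,    n_head_kv, 1); // !! not transposed !!
struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_batch, 1,         1); // 0.0f or -INFINITY per position

// result: [hs, n_head, n_batch, 1], F32, permuted as noted in the header comment
struct ggml_tensor * out  = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));

In llama.cpp the same call is made with K and V taken as views into the (now non-transposed) KV cache and Q cast to F16 via ggml_cast, as the hunks below show.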
+ struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -4343,68 +4356,77 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - // split cached v into n_head heads + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), 0); cb(v, "v", il); - // TODO: determine if we can use flash attention - const bool supports_flash_attn = true; + cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); - struct ggml_tensor * kqv; + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - if (supports_flash_attn) { - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (max_alibi_bias > 0.0f) { + // temporary branch until we 
figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } + // split cached v into n_head heads (transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - } + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif cur = ggml_mul_mat(ctx, wo, cur); if (wo_b) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5693c2197c7c5..a56c0d6c59a64 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case { const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size - const int64_t nt; // tokens + const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nt); + return VARS_TO_STR5(typeq, hs, nh, kv, nb); } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 17:32:28 +0200 
Subject: [PATCH 07/58] metal : initial implementation --- ggml-metal.m | 75 +++++++++++++------- ggml-metal.metal | 138 ++++++++++++++++++++++++++++++++++--- ggml.c | 2 +- tests/test-backend-ops.cpp | 4 ++ 4 files changed, 183 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d85dd3ddb319..556c53482a75e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -278,6 +278,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); @@ -316,13 +320,12 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } } // print MTL GPU family: @@ -396,6 +399,9 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + size_t offs_src2 = 0; size_t offs_src3 = 0; - id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + GGML_ASSERT(src2); + id id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2); + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? 
src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + float scale; memcpy(&scale, dst->op_params, sizeof(float)); @@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&scale length:sizeof( float) atIndex:21]; - - const int nwarps = 4; - - // each warp needs n_embd_head elements - GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + const int nwarps = 1; + + GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index a1e1755a3a605..5986bcb427f4b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const float * mask, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, constant int64_t & ne0, 
constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, constant float & scale, threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - // TODO: implement + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + + if (iq1 >= ne01) { + return; + } + + const int64_t D = ne00; + + // TODO: can we move this to the stack? + threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + + // initialize with zeros + for (int64_t d = 0; d < D; ++d) { + V16[d] = 0.0h; + } + + threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + + half S = 0.0h; + half M = -INFINITY; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load Q to shared memory + for (int64_t d = 0; d < D; ++d) { + pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + } + + for (int64_t ic = 0; ic < ne11; ++ic) { + const half mv = mp ? 
mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + half s = 0.0f; + + //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t d = 0; d < D; ++d) { + s += pk[d] * pq[d]; + } + + s = s*scale + mv; + + const half Mold = M; + + half ms = 1.0f; + half vs = 1.0f; + + if (s > M) { + M = s; + ms = exp(Mold - M); + + // V = V*exp(Mold - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] *= ms; + } + } else { + vs = exp(s - M); + } + + device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // V += v*exp(s - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] += pv[d] * vs; + } + + S = S*ms + vs; + } + + for (int64_t d = 0; d < D; ++d) { + V16[d] /= S; + } + + // dst indices + const int64_t i1 = iq1; + const int64_t i2 = iq2; + const int64_t i3 = iq3; + + for (int64_t d = 0; d < D; ++d) { + dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + } } kernel void kernel_cpy_f16_f16( diff --git a/ggml.c b/ggml.c index e64a328fadb1f..10df03c9c619b 100644 --- a/ggml.c +++ b/ggml.c @@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ik2 = iq2 / rk2; // v indices - const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; // online softmax / attention // loop over n_kv and n_head_kv diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a56c0d6c59a64..51a33c662da56 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case { return VARS_TO_STR5(typeq, hs, nh, kv, nb); } + double max_nmse_err() override { + return 5e-4; + } + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} From 528da7515ef874ab1188ab8f691c36d3e9e0cb20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:13:24 +0200 Subject: [PATCH 08/58] metal : f16 precision --- ggml-metal.m | 6 ++++-- ggml-metal.metal | 40 ++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 556c53482a75e..e67a7c4ef892b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2237,8 +2237,10 @@ static bool ggml_metal_graph_compute( const int nwarps = 1; - GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + const size_t shalf = sizeof(float)/2; + + GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5986bcb427f4b..e4e89b5b3f7bf 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1988,7 +1988,7 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne2, constant int64_t & ne3, constant float & scale, - threadgroup float * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], @@ -2003,16 +2003,17 @@ kernel void 
kernel_flash_attn_ext_f16( } const int64_t D = ne00; + const int64_t D4 = D/4; // TODO: can we move this to the stack? - threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); // initialize with zeros - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] = 0.0h; } - threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); half S = 0.0h; half M = -INFINITY; @@ -2045,8 +2046,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load Q to shared memory - for (int64_t d = 0; d < D; ++d) { - pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + for (int64_t d = 0; d < D4; ++d) { + pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; } for (int64_t ic = 0; ic < ne11; ++ic) { @@ -2055,15 +2056,16 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half s = 0.0f; + half4 s4 = 0.0f; - //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t d = 0; d < D; ++d) { - s += pk[d] * pq[d]; + for (int64_t d = 0; d < D4; ++d) { + s4 += pk4[d] * pq4[d]; } + half s = s4.x + s4.y + s4.z + s4.w; + s = s*scale + mv; const half Mold = M; @@ -2076,24 +2078,24 @@ kernel void kernel_flash_attn_ext_f16( ms = exp(Mold - M); // V = V*exp(Mold - M) - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] *= ms; } } else { vs = exp(s - M); } - device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); // V += v*exp(s - M) - for (int64_t d = 0; d < D; ++d) { - V16[d] += pv[d] * vs; + for (int64_t d = 0; d < D4; ++d) { + V16[d] += pv4[d] * vs; } S = S*ms + vs; } - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] /= S; } @@ -2102,8 +2104,10 @@ kernel void kernel_flash_attn_ext_f16( const int64_t i2 = iq2; const int64_t i3 = iq3; - for (int64_t d = 0; d < D; ++d) { - dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + device float4 * dst4 = (device float4 *) dst; + + for (int64_t d = 0; d < D4; ++d) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; } } From 52ae085750afd37affc4ed18fe092d92c9ccdc5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:38:17 +0200 Subject: [PATCH 09/58] metal : reduce branches --- ggml-metal.metal | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e4e89b5b3f7bf..f3a7efafa6613 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half4 s4 = 0.0f; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - device const half4 * pk4 = (device const half4 *) 
((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + half4 s4 = 0.0h; for (int64_t d = 0; d < D4; ++d) { s4 += pk4[d] * pq4[d]; } - half s = s4.x + s4.y + s4.z + s4.w; - - s = s*scale + mv; + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; const half Mold = M; - half ms = 1.0f; - half vs = 1.0f; - - if (s > M) { - M = s; - ms = exp(Mold - M); - - // V = V*exp(Mold - M) - for (int64_t d = 0; d < D4; ++d) { - V16[d] *= ms; - } - } else { - vs = exp(s - M); - } + M = max(M, s); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - // V += v*exp(s - M) for (int64_t d = 0; d < D4; ++d) { - V16[d] += pv4[d] * vs; + V16[d] = V16[d]*ms + pv4[d]*vs; } S = S*ms + vs; From b97325800a7727244e737715fa7b5e2bc41afb21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:01:55 +0200 Subject: [PATCH 10/58] metal : specialize for head size --- ggml-metal.m | 259 +++++++++++++++++++++++++---------------------- ggml-metal.metal | 42 +++++++- 2 files changed, 179 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e67a7c4ef892b..046643146b3f3 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,7 +147,9 @@ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -412,125 +414,127 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, 
mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, 
ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, 
flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -2172,6 +2176,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F16); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; @@ -2202,7 +2207,19 @@ static bool ggml_metal_graph_compute( float scale; memcpy(&scale, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; diff --git a/ggml-metal.metal b/ggml-metal.metal index f3a7efafa6613..d97952f2b0871 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template // head size kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16( return; } - const int64_t D = ne00; const int64_t D4 = D/4; // TODO: can we move this to the stack? @@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( } } +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 8cde449b8be4e481db2a8790d9320c743b3ed65e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:23:22 +0200 Subject: [PATCH 11/58] wip : 8 rows per simd group --- ggml-metal.m | 10 +-- ggml-metal.metal | 173 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 139 insertions(+), 44 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 046643146b3f3..0b1119c4eb467 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b0871..789b19bad6b93 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - 
const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - if (iq1 >= ne01) { - return; - } + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - const int64_t D4 = D/4; + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; + if (iq2 >= ne02) { + return; } - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? 
mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? 
mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; + + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - const half Mold = M; + const half Mold = M; - M = max(M, s); + M = max(M, s); - const half ms = exp(Mold - M); - const half vs = exp(s - M); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } } From f31955f5d12da67f35aa459996a171975fdf269b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:01:28 +0200 Subject: [PATCH 12/58] wip : 4 rows per simd group --- ggml-metal.m | 6 +++--- ggml-metal.metal | 39 +++++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b1119c4eb467..abb96d6ec6e44 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 2; + const int64_t nwarps = 4; - const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 789b19bad6b93..6fdd7fdad4326 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2038,7 +2038,7 @@ kernel void kernel_flash_attn_ext_f16( const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 
+2140,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - const uint tiih = tiisg%4; // thread index in head - const uint hiisg = tiisg/4; // head index in simdgroup + const uint tiih = tiisg%8; // thread index in head + const uint hiisg = tiisg/8; // head index in simdgroup // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/4; ++i) { - pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; - ps4[hiisg*D4 + 4*i + tiih] = 0.0h; + for (int64_t i = 0; i < D4/8; ++i) { + pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; + ps4[hiisg*D4 + 8*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,16 +2170,18 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t i = 0; i < D4/4; ++i) { - s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; } - ss4[hiisg*4 + tiih] = s4; + ss4[hiisg*8 + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + + ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2201,8 +2203,9 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; - for (int64_t i = 0; i < D4/4; ++i) { - ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; } } @@ -2223,8 +2226,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/4; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; + for (int64_t i = 0; i < D4/8; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; } } From a4b6341c7b2a1977c29e79b17a0e5de3e31a5420 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:24:13 +0200 Subject: [PATCH 13/58] wip : template for rows per warp --- ggml-metal.m | 7 ++++--- ggml-metal.metal | 54 +++++++++++++++++++++++++----------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index abb96d6ec6e44..d521df43ab302 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder 
setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 4; + const int64_t nwarps = 8; + const int64_t nhpw = 4; // heads per warp - const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fdd7fdad4326..c9876c1033f1f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size +template // head size, rows per warp kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2036,9 +2036,10 @@ kernel void kernel_flash_attn_ext_f16( //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; + const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2141,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - const uint tiih = tiisg%8; // thread index in head - const uint hiisg = tiisg/8; // head index in simdgroup + const uint tiih = tiisg%tph; // thread index in head + const uint hiisg = tiisg/tph; // head index in simdgroup - // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/8; ++i) { - pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; - ps4[hiisg*D4 + 8*i + tiih] = 0.0h; + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,18 +2171,20 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*8 + tiih] = s4; + ss4[hiisg*tph + tiih] = s4; 
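+        // each of the tph threads assigned to this head now holds a partial half4
+        // dot product over its strided slice of the head dimension; staging the
+        // partials in shared memory lets lane tiih == 0 reduce them into the full
+        // Q*K score for this KV column below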
simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + - ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; + s4 = 0.0h; + + for (int64_t i = 0; i < tph; ++i) { + s4 += ss4[hiisg*tph + i]; + } half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2203,9 +2206,8 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } } @@ -2226,14 +2228,14 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/8; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; kernel void kernel_cpy_f16_f16( device const half * src0, From 77d08f3272c62900b40d110bf0de7f4466675c71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 21:04:15 +0200 Subject: [PATCH 14/58] metal : parallelize across KV size --- ggml-metal.m | 8 +-- ggml-metal.metal | 137 +++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d521df43ab302..a60dd779a6f09 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 8; - const int64_t nhpw = 4; // heads per warp + const int64_t nwarps = 16; + const int64_t nhptg = 4; // heads per threadgroup - const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); + const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index c9876c1033f1f..539e26c91c34a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint 
sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per warp +template // head size, rows per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - //const int64_t iq3 = tgpig[2]; - //const int64_t iq2 = tgpig[1]; - //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; + const int64_t iq2 = tgpig[1]*R + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; -// const int64_t D4 = D/4; -// -// // TODO: can we move this to the stack? -// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); -// -// // initialize with zeros -// for (int64_t d = 0; d < D4; ++d) { -// -// } -// -// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); -// -// // load Q to shared memory -// for (int64_t d = 0; d < D4; ++d) { -// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; -// } -// -// half S = 0.0h; -// half M = -INFINITY; -// -// for (int64_t ic = 0; ic < ne11; ++ic) { -// const half mv = mp ? mp[ic] : 0.0h; -// if (mv == -INFINITY) { -// continue; -// } -// -// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); -// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); -// -// half4 s4 = 0.0h; -// -// for (int64_t d = 0; d < D4; ++d) { -// s4 += pk4[d] * pq4[d]; -// } -// -// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; -// -// const half Mold = M; -// -// M = max(M, s); -// -// const half ms = exp(Mold - M); -// const half vs = exp(s - M); -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] = V16[d]*ms + pv4[d]*vs; -// } -// -// S = S*ms + vs; -// } -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] /= S; -// } -// -// // dst indices -// const int64_t i1 = iq1; -// const int64_t i2 = iq2; -// const int64_t i3 = iq3; -// -// device float4 * dst4 = (device float4 *) dst; -// -// for (int64_t d = 0; d < D4; ++d) { -// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; -// } - const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - pq4[hiisg*D4 + tph*i + tiih] = ((device 
const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + if (sgitg == 0) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + } + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } - simdgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); half S = 0.0h; half M = -INFINITY; - for (int64_t ic = 0; ic < ne11; ++ic) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { continue; @@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16( s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*tph + tiih] = s4; + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = 0.0h; + half s = 0.0h; for (int64_t i = 0; i < tph; ++i) { - s4 += ss4[hiisg*tph + i]; + s += ss[hiisg*tph + i]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + s = s*scale + mv; const half Mold = M; @@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16( } } - simdgroup_barrier(mem_flags::mem_threadgroup); - if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps + if (sgitg == 0 && tiih == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { + const half S0 = S; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = M; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + } + } + for (int64_t i = 0; i < D4; ++i) { ps4[hiisg*D4 + i] /= S; } @@ -2228,8 +2185,10 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + if (sgitg == 0) { + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + } } } From 17720fad669eed6171ddf17184da5bab50adeb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 22:44:41 +0200 Subject: [PATCH 15/58] metal : parallel reduce across heads --- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 ++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a60dd779a6f09..fdfb50d3d03f4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 16; - const int64_t nhptg = 4; // heads per threadgroup + const int64_t nwarps = 32; + const int64_t nhptg = 2; // heads per threadgroup const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 539e26c91c34a..919119c8d55af 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } @@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16( if (tiih == 0) { half s = 0.0h; 
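+            // streaming (online) softmax update: reduce the tph partial scores, then
+            // rescale the running normalizer S by ms = exp(M_old - M_new) so the V
+            // contributions accumulated so far stay consistent with the new running
+            // maximum, while vs = exp(s - M_new) is the weight of this KV column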
+#pragma unroll for (int64_t i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } s = s*scale + mv; - const half Mold = M; + const half m = M; M = max(M, s); - const half ms = exp(Mold - M); + const half ms = exp(m - M); const half vs = exp(s - M); S = S*ms + vs; @@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } @@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0 && tiih == 0) { + if (sgitg == 0) { for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = S; + const half S0 = ss[ 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; - const half M0 = M; + const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; } } - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] /= S; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } @@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; kernel void kernel_cpy_f16_f16( device const half * src0, From 6374bc5779784de48fd79351942f8b53589eff7e Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Tue, 23 Jan 2024 16:42:53 -0500 Subject: [PATCH 16/58] cuda: port metal version flash_attn_ext --- ggml-cuda.cu | 305 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 304 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0eeee748415dd..940ffbfc8d87d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -937,6 +937,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr if (lane_id == 0) { s_sum[warp_id] = tmp; } + __syncthreads(); tmp = s_sum[lane_id]; tmp = warp_reduce_sum(tmp); @@ -6106,6 +6107,211 @@ static __global__ void flash_attn_f32( } } +struct __align__(8) half4 { + half x; + half y; + half z; + half w; +}; + +// based on metal version +template // head size, rows per block +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ kqv, + float scale, + int ne00, + int ne01, + int ne02, + int 
ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + + const int nwraps = blockDim.y; // number of warps + const int tph = WARP_SIZE / R; // threads per head + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y * R + lane_id / tph; + const int iq1 = blockIdx.x; + + if(iq2 >= ne02) { + return; + } + + // broadcast + const int rk2 = ne02 / ne12; + const int rk3 = ne03 / ne13; + // assume the same K and V shape + // const int rv2 = ne02 / ne12; + // const int rv3 = ne03 / ne13; + + // kv indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + const float * mp = mask ? mask + (ir % ne31)*nb31 : nullptr; + + extern __shared__ char shmem__[]; + + half4* pq4 = (half4*)shmem__; + half4* ps4 = (half4*)(shmem__ + warp_id * (R * D + 32) + 1*R*D); + half* ss = (half *)(shmem__ + warp_id * (R * D + 32) + 2*R*D); + + const int tiih = lane_id % tph; // thread index in head + const int hiisg = lane_id / tph; // head index in warp + + const int D4 = D/4; + + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + if (warp_id == 0) { + pq4[hiisg*D4 + tph*i + tiih] = (const half4*)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03))[tph*i + tiih]; + } + + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; + } + __syncthreads(); + + half S = 0.0h; + half M = -INFINITY; + + for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { + const half mv = mp ? mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K + + half4 s4 = 0.0h; + +#pragma unroll + for (int i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; + } + + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + + __syncthreads(); + + if (tiih == 0) { + half s = 0.0h; + +#pragma unroll + for (int i = 0; i < tph; ++i) { + s += ss[hiisg*tph + i]; + } + + s = s*scale + mv; + + const half m = M; + + M = max(M, s); + + const half ms = exp(m - M); + const half vs = exp(s - M); + + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; + } + + __syncthreads(); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + +#pragma unroll + for (int i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + } + } + + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + __syncthreads(); + + // reduce the warps + if (warp_id == 0) { + for (int sg = 1; sg < nwraps; ++sg) { + const half S0 = ss[ 2*hiisg + 0]; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = ss[ 2*hiisg + 1]; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + } + } + + for (int i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + 
tph*i + tiih]/S; + } + } + + __syncthreads(); + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + float4 * dst4 = (float4 *) kqv; + + if (warp_id == 0) { + for (int i = 0; i < D4/tph; ++i) { + float4 dst_ = + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih]; + half4 src_ = ps4[hiisg*D4 + tph*i + tiih]; + dst_.x = __half2float(src_.x); + dst_.y = __half2float(src_.y); + dst_.z = __half2float(src_.z); + dst_.w = __half2float(src_.w); + } + } +} + + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -10071,6 +10277,98 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c KQ_scale, d_head, sequence_length, num_heads, main_stream); } + +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F16); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(mask->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(mask->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + const int nwarps = 32; + const int nhpw = 2; // heads per warp + + dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]); + dim3 block_dim(32, nwarps, 1); + + int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2); + + switch (Q->ne[0]) + { + case 64: + flash_attn_ext_f16<64, 2> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + (const char *) src3_extra->data_device[g_main_device], // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask->ne[1], mask->nb[1], + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, 2> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + (const char *) src3_extra->data_device[g_main_device], // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask->ne[1], mask->nb[1], + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + 
flash_attn_ext_f16<128, 2> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + (const char *) src3_extra->data_device[g_main_device], // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask->ne[1], mask->nb[1], + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; + } +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -10341,6 +10639,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st break; case GGML_OP_FLASH_ATTN: break; + case GGML_OP_FLASH_ATTN_EXT: + break; default: return false; } @@ -10357,7 +10657,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st } if(tensor->op == GGML_OP_FLASH_ATTN) { ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor); - } else { + } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) { + ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } else { func(tensor->src[0], tensor->src[1], tensor); } return true; @@ -11175,6 +11477,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; From 641682149972e965c6b70525f8fa829496ce4c89 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 24 Jan 2024 10:57:05 -0500 Subject: [PATCH 17/58] fix equivalent fp16 math functions, compiler error 'undefined' --- ggml-cuda.cu | 86 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 25 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 940ffbfc8d87d..9d2b99ac993e8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6114,6 +6114,42 @@ struct __align__(8) half4 { half w; }; +__device__ half4 make_half4(half x) { + half4 t; + t.x = x; t.y = x; t.z = x; t.w = x; + return t; +} + +__device__ half4 __h4fma(half4 a, half b, half4 c) { + half4 t; + t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w); + return t; +} + +__device__ half4 __h4fma(half4 a, half4 b, half4 c) { + half4 t; + t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w); + return t; +} + +__device__ half4 __h4mul(half4 a, half b) { + half4 t; + t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b); + return t; +} + +__device__ half4 __h4mul(half4 a, half4 b) { + half4 t; + t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w); + return t; +} + +__device__ half4 __h4div(half4 a, half b) { + half4 t; + t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b); + return t; +} + // based on metal version template // head size, rows per block static __global__ void flash_attn_ext_f16( @@ -6166,12 +6202,12 @@ static __global__ void flash_attn_ext_f16( // kv indices const int ik2 = iq2 / rk2; const int ik3 = iq3 / rk3; - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // const int iv2 = iq2 / rv2; + // const int 
iv3 = iq3 / rv3; const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - const float * mp = mask ? mask + (ir % ne31)*nb31 : nullptr; + const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; extern __shared__ char shmem__[]; @@ -6187,30 +6223,30 @@ static __global__ void flash_attn_ext_f16( // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { if (warp_id == 0) { - pq4[hiisg*D4 + tph*i + tiih] = (const half4*)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03))[tph*i + tiih]; + pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; } - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; + ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0); } __syncthreads(); - half S = 0.0h; - half M = -INFINITY; + half S(0.0); + half M(-INFINITY); for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { - const half mv = mp ? mp[ic] : 0.0h; - if (mv == -INFINITY) { + const half mv = mp ? mp[ic] : 0.0; + if (__hisinf(mv) == -1) { // mv == -INFINITY continue; } const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K + const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K - half4 s4 = 0.0h; + half4 s4 = make_half4(0.0); #pragma unroll for (int i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; + s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4); } ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); @@ -6218,23 +6254,23 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); if (tiih == 0) { - half s = 0.0h; + half s = 0.0; #pragma unroll for (int i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } - s = s*scale + mv; + s = __hfma(s, __float2half(scale), mv); // s*scale + mv const half m = M; - M = max(M, s); + M = __hmax(M, s); - const half ms = exp(m - M); - const half vs = exp(s - M); + const half ms = hexp(m - M); + const half vs = hexp(s - M); - S = S*ms + vs; + S = __hfma(S, ms, vs); ss[2*hiisg + 0] = ms; ss[2*hiisg + 1] = vs; @@ -6247,7 +6283,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs)); } } @@ -6267,12 +6303,12 @@ static __global__ void flash_attn_ext_f16( const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; - M = max(M0, M1); + M = __hmax(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + S = __hfma(S0, ms0, __hmul(S1, ms1)); if (tiih == 0) { ss[2*hiisg + 0] = S; @@ -6280,12 +6316,12 @@ static __global__ void flash_attn_ext_f16( } for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1)); } } for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; + ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S); } } From 972c2adc15b5d61c2b3f267989a3185d2a99ce46 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 24 Jan 2024 
16:41:57 -0500 Subject: [PATCH 18/58] use half2 instead half4 --- ggml-cuda.cu | 197 ++++++++++++++++++++------------------------------- 1 file changed, 77 insertions(+), 120 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9d2b99ac993e8..e9657dd88f931 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5992,7 +5992,7 @@ static __global__ void im2col_f32_f16( #define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 -template +template static __global__ void flash_attn_f32( const float* __restrict__ q, const float* __restrict__ k, @@ -6004,9 +6004,9 @@ static __global__ void flash_attn_f32( const int head_size = head_dim * seq_len; const int s = blockIdx.x % seq_len; - extern __shared__ char shmem__[]; - float* S = (float*)shmem__; - float* warp_data = (float*)(shmem__ + seq_len * sizeof(float)); + extern __shared__ char flash_attn_shmem_f32[]; + float* S = (float*)flash_attn_shmem_f32; + float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float)); // QK^T #pragma unroll @@ -6019,11 +6019,11 @@ static __global__ void flash_attn_f32( const int key_offset = is * head_dim + head * head_size; const int query_offset = s * head_dim + head * head_size; - S[is] = 0.0f; + float tmp = 0.0f; for(int d = 0; d < head_dim; d++) { - S[is] += k[key_offset + d] * q[query_offset + d]; + tmp += k[key_offset + d] * q[query_offset + d]; } - S[is] *= kq_scale; + S[is] = tmp * kq_scale; } __syncthreads(); @@ -6060,9 +6060,9 @@ static __global__ void flash_attn_f32( if(is >= seq_len) { break; } - - S[is] = expf(S[is] - max_val); - sum += S[is]; + float tmp = expf(S[is] - max_val); + sum += tmp; + S[is] = tmp; } __syncthreads(); @@ -6091,7 +6091,12 @@ static __global__ void flash_attn_f32( __syncthreads(); // softmax(QK^T)V - for (int d = threadIdx.x; d < head_dim; d += block_size) { + #pragma unroll + for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) { + const int d = threadIdx.x + d0; + if(d >= head_dim) { + break; + } const int dst_index = d + s * head_dim + head * head_size; const int value_offset = d * seq_len + head * head_size; @@ -6107,51 +6112,8 @@ static __global__ void flash_attn_f32( } } -struct __align__(8) half4 { - half x; - half y; - half z; - half w; -}; - -__device__ half4 make_half4(half x) { - half4 t; - t.x = x; t.y = x; t.z = x; t.w = x; - return t; -} - -__device__ half4 __h4fma(half4 a, half b, half4 c) { - half4 t; - t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w); - return t; -} - -__device__ half4 __h4fma(half4 a, half4 b, half4 c) { - half4 t; - t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w); - return t; -} - -__device__ half4 __h4mul(half4 a, half b) { - half4 t; - t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b); - return t; -} - -__device__ half4 __h4mul(half4 a, half4 b) { - half4 t; - t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w); - return t; -} - -__device__ half4 __h4div(half4 a, half b) { - half4 t; - t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b); - return t; -} - // based on metal version -template // head size, rows per block +template // D head size, R rows per block static __global__ void flash_attn_ext_f16( const char* __restrict__ q, const char* __restrict__ k, @@ -6205,91 +6167,93 @@ static __global__ void flash_attn_ext_f16( // const int iv2 = iq2 / rv2; // const int iv3 = iq3 / rv3; + const half2 scale_h 
= __half2half2(__float2half(scale)); + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; - extern __shared__ char shmem__[]; + extern __shared__ char data_flash_attn_shmem[]; - half4* pq4 = (half4*)shmem__; - half4* ps4 = (half4*)(shmem__ + warp_id * (R * D + 32) + 1*R*D); - half* ss = (half *)(shmem__ + warp_id * (R * D + 32) + 2*R*D); + half2* pq2 = (half2*)data_flash_attn_shmem; + half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D); + half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D); - const int tiih = lane_id % tph; // thread index in head - const int hiisg = lane_id / tph; // head index in warp + const int tiih = lane_id % tph; // thread index in head + const int hiiw = lane_id / tph; // head index in warp - const int D4 = D/4; + const int D2 = D / 2; // number of half2 to store head_dim row // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { + for (int i = 0; i < D2/tph; ++i) { if (warp_id == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; } - ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0); + ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0); } __syncthreads(); - half S(0.0); - half M(-INFINITY); + half2 S = make_half2(0.0, 0.0); + half2 M = make_half2(-INFINITY, -INFINITY); for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { - const half mv = mp ? mp[ic] : 0.0; - if (__hisinf(mv) == -1) { // mv == -INFINITY + const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0); + if (__hisinf(mv.x) == -1) { // mv == -INFINITY continue; } - const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K + half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K - half4 s4 = make_half4(0.0); + half2 s2 = make_half2(0.0, 0.0); #pragma unroll - for (int i = 0; i < D4/tph; ++i) { - s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4); + for (int i = 0; i < D2/tph; ++i) { + s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2; } - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y); __syncthreads(); if (tiih == 0) { - half s = 0.0; + half2 s = make_half2(0.0, 0.0); #pragma unroll for (int i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + s += ss[hiiw*tph + i]; } - s = __hfma(s, __float2half(scale), mv); // s*scale + mv + s = s * scale_h + mv; // s*scale + mv - const half m = M; + half2 m = M; - M = __hmax(M, s); + M = __hmax2(M, s); - const half ms = hexp(m - M); - const half vs = hexp(s - M); + half2 ms = h2exp(m - M); + half2 vs = h2exp(s - M); - S = __hfma(S, ms, vs); + S = S * ms + vs; - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; + ss[2*hiiw + 0] = ms; + ss[2*hiiw + 1] = vs; } __syncthreads(); - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; + half2 ms = ss[2*hiiw + 0]; + half2 vs = ss[2*hiiw + 1]; #pragma unroll - for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs)); + for (int i = 0; i < D2/tph; ++i) { + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + 
pv2[tph*i + tiih] * vs; } } if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + ss[2*hiiw + 0] = S; + ss[2*hiiw + 1] = M; } __syncthreads(); @@ -6297,31 +6261,31 @@ static __global__ void flash_attn_ext_f16( // reduce the warps if (warp_id == 0) { for (int sg = 1; sg < nwraps; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + half2 S0 = ss[ 2*hiiw + 0]; + half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + half2 M0 = ss[ 2*hiiw + 1]; + half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1]; - M = __hmax(M0, M1); + M = __hmax2(M0, M1); - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); + half2 ms0 = h2exp(M0 - M); + half2 ms1 = h2exp(M1 - M); - S = __hfma(S0, ms0, __hmul(S1, ms1)); + S = S0 * ms0 + S1 * ms1; if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + ss[2*hiiw + 0] = S; + ss[2*hiiw + 1] = M; } - for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1)); + for (int i = 0; i < D2/tph; ++i) { + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1; } } - for (int i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S); + for (int i = 0; i < D2/tph; ++i) { + ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S); } } @@ -6332,17 +6296,10 @@ static __global__ void flash_attn_ext_f16( const int i2 = iq2; const int i3 = iq3; - float4 * dst4 = (float4 *) kqv; - + float2 * dst2 = (float2 *) kqv; if (warp_id == 0) { - for (int i = 0; i < D4/tph; ++i) { - float4 dst_ = - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih]; - half4 src_ = ps4[hiisg*D4 + tph*i + tiih]; - dst_.x = __half2float(src_.x); - dst_.y = __half2float(src_.y); - dst_.z = __half2float(src_.z); - dst_.w = __half2float(src_.w); + for (int i = 0; i < D2/tph; ++i) { + dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]); } } } @@ -7741,7 +7698,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst, static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); int num_blocks = num_heads * seq_len; - flash_attn_f32<<>>( + flash_attn_f32<<>>( q, k, v, dst, kq_scale, d_head, seq_len, num_heads); } @@ -10342,11 +10299,11 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps = 32; const int nhpw = 2; // heads per warp - dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]); - dim3 block_dim(32, nwarps, 1); - - int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2); + dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]); + dim3 block_dim(32 * nwarps, 1, 1); + int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2); + printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); switch (Q->ne[0]) { case 64: From 0fc36d872c4644cb685c6f539c781d0841becaf7 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 24 Jan 2024 16:45:30 -0500 Subject: [PATCH 19/58] match to metal impl --- ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/ggml-cuda.cu b/ggml-cuda.cu index e9657dd88f931..b7ebfcc57835f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6225,7 +6225,7 @@ static __global__ void flash_attn_ext_f16( s += ss[hiiw*tph + i]; } - s = s * scale_h + mv; // s*scale + mv + s = s*scale_h + mv; // s*scale + mv half2 m = M; @@ -6234,7 +6234,7 @@ static __global__ void flash_attn_ext_f16( half2 ms = h2exp(m - M); half2 vs = h2exp(s - M); - S = S * ms + vs; + S = S*ms + vs; ss[2*hiiw + 0] = ms; ss[2*hiiw + 1] = vs; @@ -6247,7 +6247,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs; + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs; } } @@ -6272,7 +6272,7 @@ static __global__ void flash_attn_ext_f16( half2 ms0 = h2exp(M0 - M); half2 ms1 = h2exp(M1 - M); - S = S0 * ms0 + S1 * ms1; + S = S0*ms0 + S1*ms1; if (tiih == 0) { ss[2*hiiw + 0] = S; @@ -6280,7 +6280,7 @@ static __global__ void flash_attn_ext_f16( } for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1; + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1; } } From 1446a12b29f422a0c0040e62c16715a3fb7ce1cb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Jan 2024 18:27:54 +0200 Subject: [PATCH 20/58] metal : efficient flash_attn_f16 implementation --- ggml-metal.m | 14 +- ggml-metal.metal | 279 +++++++++++++++++++++++-------------- tests/test-backend-ops.cpp | 6 +- 3 files changed, 188 insertions(+), 111 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index fdfb50d3d03f4..7b161c69d5801 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src3 = gf->nodes[i]->src[3]; GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); size_t offs_src2 = 0; size_t offs_src3 = 0; @@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 32; - const int64_t nhptg = 2; // heads per threadgroup + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) - const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! 
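+                    // the threadgroup memory requested below holds, per query row, ne00 half
+                    // values of Q (shared by all simdgroups) plus, for each simdgroup, ne00 half
+                    // values of partial output and ncpsg half values of KQ scores -- matching the
+                    // per-query stride T = D + nsg*(D + C) used inside the kernel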
+ const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 919119c8d55af..9b6ceec4e1066 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per threadgroup +template // head size, heads per threadgroup, queries per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,178 +2031,247 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*R + tiisg/tph; - const int64_t iq1 = tgpig[0]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*Q; if (iq2 >= ne02) { return; } - // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; + const int64_t D4 = D/4; + const int64_t N4 = N_SIMDWIDTH; + const int64_t L4 = (D4 + N4 - 1)/N4; + const int64_t D8 = D/8; + + const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 + + threadgroup half * pq = (threadgroup half *) (shared + 0*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); + threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + + for (int64_t i = 0; i < L4; ++i) { + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + if (iq1 + j < ne01) { + pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + } else { + pq4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; + // zero out shared memory + for (int64_t j = 0; j < Q; ++j) { + ps4[j*T4 + N4*i + tiisg] = 0.0h; + } + } - // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; + if (tiisg < C) { + for (int64_t j = 0; j < Q; ++j) { + ss[j*T + 0 + tiisg] = 0.0h; + } + } - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; + threadgroup_barrier(mem_flags::mem_threadgroup); - // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; + { + half S[Q] = { 0.0h }; + half M[Q] = { -INFINITY }; - // v indices - const int64_t iv2 = iq2 / rv2; - 
const int64_t iv3 = iq3 / rv3; + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; - const int64_t D4 = D/4; + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; - const uint tiih = tiisg%tph; // thread index in head - const uint hiisg = tiisg/tph; // head index in simdgroup + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; - // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { - if (sgitg == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; - } + simdgroup_half8x8 mq[D8]; - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; - } + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[i], pq + i*8, T); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: this can be improved + device const float * mp[Q]; - half S = 0.0h; - half M = -INFINITY; + { + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - for (int64_t ic = sgitg; ic < ne11; ic += nsg) { - const half mv = mp ? mp[ic] : 0.0h; - if (mv == -INFINITY) { - continue; + for (int64_t j = 0; j < Q; ++j) { + if (iq1 + j < ne01) { + mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); + } else { + mp[j] = nullptr; + } + } } - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; - half4 s4 = 0.0h; + for (int64_t j = 0; j < Q; ++j) { + const float mc = mp[j] ? 
mp[j][iic + tiisg] : -INFINITY; + smc = simd_max(max(smc, mc)); + } -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; - } + if (smc == -INFINITY) { + continue; + } + } + + // Q*K^T + { + simdgroup_half8x8 mk; - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - simdgroup_barrier(mem_flags::mem_threadgroup); + device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - if (tiih == 0) { - half s = 0.0h; + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mk, pk + i*8, nb11/2, 0, true); -#pragma unroll - for (int64_t i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, T, 0, false); + } } - s = s*scale + mv; + // online softmax + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - const half m = M; + const half s = ss[j*T + p]*scale + (mp[j][iic + p]); - M = max(M, s); + half m = M[j]; - const half ms = exp(m - M); - const half vs = exp(s - M); + M[j] = simd_max(max(M[j], s)); - S = S*ms + vs; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; - } + S[j] = S[j]*ms + simd_sum(vs); + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] *= ms; + } + + ss[j*T + p] = vs; + } + + // (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mp[C/8]; + simdgroup_half8x8 mqkv; - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mqkv, ps + i*8, T, 0, false); - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + + for (int cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + + simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + } + + simdgroup_store(mqkv, ps + i*8, T, 0, false); + } + } } - } - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } } threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps + // TODO: try parallel reduce if (sgitg == 0) { + half S = { 0.0h }; + half M = { -INFINITY }; + for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - M = max(M0, M1); + M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + S = S0*ms0 + S1*ms1; - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; - } + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } - for (int64_t i = 0; i < D4/tph; ++i) { - 
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } } } - - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; - } } simdgroup_barrier(mem_flags::mem_threadgroup); - // dst indices - const int64_t i1 = iq1; - const int64_t i2 = iq2; - const int64_t i3 = iq3; - device float4 * dst4 = (device float4 *) dst; if (sgitg == 0) { - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = 0; i < L4; ++i) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 51a33c662da56..41ddfcca5b687 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-4; + return 5e-5; } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, @@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From d917746ddb053b73e868fd6e1854ac17b62bd863 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:00:49 +0200 Subject: [PATCH 21/58] metal : avoid redundant loads of the attention --- ggml-metal.metal | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9b6ceec4e1066..785a60e50eba8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } + simdgroup_barrier(mem_flags::mem_none); + // (Q*K^T)*V { simdgroup_half8x8 mv; + simdgroup_half8x8 mp[C/8]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[C/8]; simdgroup_half8x8 mqkv; 
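                // mqkv carries the i-th 8-column tile of this simdgroup's running output:
                // it is seeded from the partial result kept in shared memory (ps), accumulates
                // P*V over the 8x8 tiles of the C cache columns handled in this pass, and is
                // written back to ps once the cc loop below finishes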
simdgroup_load(mqkv, ps + i*8, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } - for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); From 432ad04ffaa445a3837b92dce1c03513009ab4ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:47:52 +0200 Subject: [PATCH 22/58] metal : scale and mask in matrix form --- ggml-metal.metal | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 785a60e50eba8..ae8f5caeaa75f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16( } } + // prepare diagonal scale matrix + simdgroup_half8x8 mscale(scale); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { // skip -INF blocks // TODO: double-check this @@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16( device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + // mqk = mqk*scale + mask + simdgroup_float8x8 mm; + simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } @@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + const half s = ss[j*T + p]; half m = M[j]; @@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); } From 40ea8cd1aca61294e1987bcb1051317827f1b145 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 16:31:39 +0200 Subject: [PATCH 23/58] metal : fix comment --- ggml-metal.metal | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ae8f5caeaa75f..9ab9e16c3915a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; const int64_t iq2 = tgpig[1]; const int64_t iq1 = tgpig[0]*Q; - if (iq2 >= ne02) { - return; - } - const int64_t D4 = D/4; const int64_t N4 = N_SIMDWIDTH; const int64_t L4 = (D4 + N4 - 1)/N4; From f9ca5dcbe86a10cfa873814d5f754b7c9108f339 
Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:46:07 +0200 Subject: [PATCH 24/58] llama : avoid ggml_cast, use F32 query --- ggml-metal.m | 4 ++-- ggml-metal.metal | 3 ++- ggml.c | 31 +++++++++++++++++++++++++++---- ggml.h | 4 ++++ llama.cpp | 3 ++- tests/test-backend-ops.cpp | 16 +++++++--------- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b161c69d5801..7b6762e6d9158 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute( case GGML_OP_FLASH_ATTN_EXT: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F32); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9ab9e16c3915a..c9e4dcfe99cd4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; } else { pq4[j*T4 + N4*i + tiisg] = 0.0h; } diff --git a/ggml.c b/ggml.c index 10df03c9c619b..5e515c03fdb9d 100644 --- a/ggml.c +++ b/ggml.c @@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext( return result; } +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); @@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float M = -INFINITY; float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); memset(V16, 0, D*sizeof(ggml_fp16_t)); @@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; + // convert Q to F16 in V32 + { + const float * pq = (const 
float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + ggml_vec_dot_f16(D, &s, (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + Q16); s = s*scale + mv; @@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: { ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { + // TODO: implement F32 precision GGML_ASSERT(false); } break; } diff --git a/ggml.h b/ggml.h index 7bca02f2a2c48..e2f74412fde1e 100644 --- a/ggml.h +++ b/ggml.h @@ -1633,6 +1633,10 @@ extern "C" { struct ggml_tensor * mask, float scale); + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 4e6c9f9cc75ea..550caced4ae57 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 41ddfcca5b687..db1244876ce06 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case { // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { - const ggml_type typeq; const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nb); + return VARS_TO_STR4(hs, nh, kv, nb); } double max_nmse_err() override { return 5e-5; } - test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); @@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 
128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 6e7cb0eeafe3c3145e9c9e398099b3c5b7641b5c Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 25 Jan 2024 11:04:51 -0500 Subject: [PATCH 25/58] update implementation --- ggml-cuda.cu | 100 ++++++++++++++++++++++++++------------------------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b7ebfcc57835f..1b11a34bb16a2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6113,7 +6113,7 @@ static __global__ void flash_attn_f32( } // based on metal version -template // D head size, R rows per block +template // D head size, Q queries per block, C cache items per blocks static __global__ void flash_attn_ext_f16( const char* __restrict__ q, const char* __restrict__ k, @@ -6141,62 +6141,64 @@ static __global__ void flash_attn_ext_f16( int ne1, int ne2, int ne3) { - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; - const int nwraps = blockDim.y; // number of warps - const int tph = WARP_SIZE / R; // threads per head + const int n_warps = blockDim.y; // number of warps const int iq3 = blockIdx.z; - const int iq2 = blockIdx.y * R + lane_id / tph; - const int iq1 = blockIdx.x; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; - if(iq2 >= ne02) { - return; - } - - // broadcast - const int rk2 = ne02 / ne12; - const int rk3 = ne03 / ne13; - // assume the same K and V shape - // const int rv2 = ne02 / ne12; - // const int rv3 = ne03 / ne13; + const int D2 = D/2; + const int N4 = WARP_SIZE; + const int L2 = (D2 + N4 - 1)/N4; + const int D8 = D/8; - // kv indices - const int ik2 = iq2 / rk2; - const int ik3 = iq3 / rk3; - // const int iv2 = iq2 / rv2; - // const int iv3 = iq3 / rv3; + const int T = D + n_warps*(D + 1*C); // shared memory size per query in half + const int T2 = T/2; // shared memory size per query in half2 const half2 scale_h = __half2half2(__float2half(scale)); - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - const float * mp = mask ? 
(const float*)(mask + (ir % ne31)*nb31) : nullptr; - extern __shared__ char data_flash_attn_shmem[]; - half2* pq2 = (half2*)data_flash_attn_shmem; - half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D); - half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D); - - const int tiih = lane_id % tph; // thread index in head - const int hiiw = lane_id / tph; // head index in warp - - const int D2 = D / 2; // number of half2 to store head_dim row + half * pq = (half *) (data_flash_attn_shmem + 0*D); + half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); + half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); + half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); + half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D); + + for (int i = 0; i < L2; ++i) { + // load heads from Q to shared memory + for (int j = warp_id; j < Q; j += n_warps) { + if (iq1 + j < ne01) { + pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id]; + } else { + pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); + } + } - // load R heads from Q to shared memory - for (int i = 0; i < D2/tph; ++i) { - if (warp_id == 0) { - pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + // zero out shared memory + for (int j = 0; j < Q; ++j) { + ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); } + } - ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0); + if (lane_id < C) { + for (int j = 0; j < Q; ++j) { + ss[j*T + 0 + lane_id] = 0.0; + } } + __syncthreads(); - half2 S = make_half2(0.0, 0.0); + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + half S[8] = { 0.0 }; +#if 0 half2 M = make_half2(-INFINITY, -INFINITY); + const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; + for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0); if (__hisinf(mv.x) == -1) { // mv == -INFINITY @@ -6302,6 +6304,7 @@ static __global__ void flash_attn_ext_f16( dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]); } } +#endif } @@ -10296,18 +10299,19 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - const int nwarps = 32; - const int nhpw = 2; // heads per warp + const int nwarps = Q->ne[1] < 4 ? 
4 : 2; + const int nqpb = 2; // queries per block + const int ncpw = 32; // cache values per warp (does not work for other values) - dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]); - dim3 block_dim(32 * nwarps, 1, 1); + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); - int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2); + int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2); printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); switch (Q->ne[0]) { case 64: - flash_attn_ext_f16<64, 2> + flash_attn_ext_f16<64, 8, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10324,7 +10328,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 80: - flash_attn_ext_f16<80, 2> + flash_attn_ext_f16<80, 8, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10341,7 +10345,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 128: - flash_attn_ext_f16<128, 2> + flash_attn_ext_f16<128, 8, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key From 6fea843b246409a3c4b26156745a89e4ba01029b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:59:41 +0200 Subject: [PATCH 26/58] metal : add parallel reduce version (disabled) --- ggml-metal.m | 2 +- ggml-metal.metal | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b6762e6d9158..cf7880c822db5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! 
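                    // for reference, with the shared-memory expression used at this point,
                    //   smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2),
                    // an illustrative case of ne00 = 128 with nsg = 4, nqptg = 8, ncpsg = 32 needs
                    //   8*(128 + 4*(128 + 32))*2 = 12288 bytes of threadgroup memory,
                    // which the GGML_ASSERT against maxThreadgroupMemoryLength guards below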
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index c9e4dcfe99cd4..6eb2825df558b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2230,7 +2230,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - // TODO: try parallel reduce +#if 1 if (sgitg == 0) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2261,6 +2261,46 @@ kernel void kernel_flash_attn_ext_f16( } } } +#else + // parallel reduce + // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps + { + half S = { 0.0h }; + half M = { -INFINITY }; + + for (int64_t sg = nsg/2; sg > 0; sg /= 2) { + if (sgitg >= sg) { + continue; + } + + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +#endif simdgroup_barrier(mem_flags::mem_threadgroup); From 0a481fe1a9f7f8618cd64744af7c3d5900ac4a8e Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Fri, 26 Jan 2024 20:14:02 -0500 Subject: [PATCH 27/58] integrate tensor cores --- ggml-cuda.cu | 249 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 170 insertions(+), 79 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1b11a34bb16a2..5cb0656063499 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -104,6 +104,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -621,6 +622,14 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } +static __device__ __forceinline__ __half warp_reduce_sum(__half x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -642,6 +651,19 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +static __device__ __forceinline__ half warp_reduce_max(half x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + (void) x; + bad_arch(); +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -6112,6 +6134,10 @@ static __global__ void flash_attn_f32( } } +typedef nvcuda::wmma::fragment half16x16_a; +typedef 
nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_acc; + // based on metal version template // D head size, Q queries per block, C cache items per blocks static __global__ void flash_attn_ext_f16( @@ -6152,17 +6178,17 @@ static __global__ void flash_attn_ext_f16( const int D2 = D/2; const int N4 = WARP_SIZE; const int L2 = (D2 + N4 - 1)/N4; - const int D8 = D/8; + const int D16 = D/16; const int T = D + n_warps*(D + 1*C); // shared memory size per query in half - const int T2 = T/2; // shared memory size per query in half2 + const int T2 = T/2; // shared memory size per query in half2 - const half2 scale_h = __half2half2(__float2half(scale)); + const half scale_h = __float2half(scale); extern __shared__ char data_flash_attn_shmem[]; - - half * pq = (half *) (data_flash_attn_shmem + 0*D); - half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); + // pq + half * pq = (half *) (data_flash_attn_shmem + 0*D); + half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D); @@ -6191,120 +6217,185 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + { + half S[Q] = { 0.0 }; + half M[Q] = { -INFINITY }; - half S[8] = { 0.0 }; -#if 0 - half2 M = make_half2(-INFINITY, -INFINITY); + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; - const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr; + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; - for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { - const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0); - if (__hisinf(mv.x) == -1) { // mv == -INFINITY - continue; - } + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; - half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; - half2 s2 = make_half2(0.0, 0.0); + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; -#pragma unroll - for (int i = 0; i < D2/tph; ++i) { - s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // TODO: this can be improved + float * mp[Q]; + + { + const int ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + for (int j = 0; j < Q; ++j) { + if (iq1 + j < ne01) { + mp[j] = (float *)(mask + ((ir + j)%ne31) * nb31); + } else { + mp[j] = nullptr; + } + } } - ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y); + for (int iic = C*warp_id; iic < ne11; iic += C*n_warps) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; - __syncthreads(); + for (int j = 0; j < Q; ++j) { + const float mc = mp[j] ? 
mp[j][iic + lane_id] : -INFINITY; + smc = warp_reduce_max(max(smc, mc)); + } - if (tiih == 0) { - half2 s = make_half2(0.0, 0.0); + if (smc == -INFINITY) { + continue; + } + } -#pragma unroll - for (int i = 0; i < tph; ++i) { - s += ss[hiiw*tph + i]; + // Q*K^T + { + half16x16_a mq{}; + half16x16_b mk{}; + half16x16_acc mqk{}; + + for (int cc = 0; cc < C/16; ++cc) { + nvcuda::wmma::fill_fragment(mqk, 0); // re fetch + + const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for(int i = 0; i < D16;i ++) { + nvcuda::wmma::load_matrix_sync(mq, pq + i*16, T); + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + nvcuda::wmma::mma_sync(mqk, mq, mk, mqk); + } + + nvcuda::wmma::store_matrix_sync(ss + 16*cc, mqk, T, nvcuda::wmma::mem_col_major); + } } - s = s*scale_h + mv; // s*scale + mv + // online softmax + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = lane_id; - half2 m = M; + const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]); - M = __hmax2(M, s); + half m = M[j]; - half2 ms = h2exp(m - M); - half2 vs = h2exp(s - M); + M[j] = warp_reduce_max(__hmax(M[j], s)); - S = S*ms + vs; + const half ms = __hisinf(m) == -1 ? 0.0 : hexp(m - M[j]); + const half vs = __hisinf(s) == -1 ? 0.0 : hexp(s - M[j]); - ss[2*hiiw + 0] = ms; - ss[2*hiiw + 1] = vs; - } + S[j] = S[j]*ms + warp_reduce_sum(vs); - __syncthreads(); + ss[j*T + p] = vs; + } - half2 ms = ss[2*hiiw + 0]; - half2 vs = ss[2*hiiw + 1]; + __syncthreads(); -#pragma unroll - for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs; + // (Q*K^T)*V + { + half16x16_acc mqkv{}; + half16x16_a mqk{}; + half16x16_b mv{}; + + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(mqkv, 0); + + for (int cc = 0; cc < C/16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + nvcuda::wmma::load_matrix_sync(mqk, ss + cc*16, T); + nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half)); + + nvcuda::wmma::mma_sync(mqkv, mqk, mv, mqkv); + } + + nvcuda::wmma::store_matrix_sync(ps + i*16, mqkv, T, nvcuda::wmma::mem_col_major); + } + } } - } - if (tiih == 0) { - ss[2*hiiw + 0] = S; - ss[2*hiiw + 1] = M; + for (int64_t j = 0; j < Q; ++j) { + if (lane_id == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } } __syncthreads(); // reduce the warps + // TODO: try parallel reduce if (warp_id == 0) { - for (int sg = 1; sg < nwraps; ++sg) { - half2 S0 = ss[ 2*hiiw + 0]; - half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0]; + half S = 0.0; + half M = -INFINITY; - half2 M0 = ss[ 2*hiiw + 1]; - half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1]; + for (int64_t sg = 1; sg < n_warps; ++sg) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - M = __hmax2(M0, M1); + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - half2 ms0 = h2exp(M0 - M); - half2 ms1 = h2exp(M1 - M); + M = __hmax(M0, M1); - S = S0*ms0 + S1*ms1; + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); - if (tiih == 0) { - ss[2*hiiw + 0] = S; - ss[2*hiiw + 1] = M; - } + S = S0*ms0 + S1*ms1; - for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1; - } - } + if (lane_id == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } - for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S); 
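            // the merge below uses the streaming-softmax identity: for two partial
            // results (S0, M0) and (S1, M1) accumulated over disjoint blocks of the
            // KV cache, the combined values are M = max(M0, M1) and
            // S = S0*exp(M0 - M) + S1*exp(M1 - M); the partial output rows are
            // rescaled by exp(M0 - M) and exp(M1 - M) before being added, and the
            // final division by S is deferred to the write-back into dst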
+ for (int64_t i = 0; i < L2; ++i) { + ps2[j*T2 + N4*i + lane_id] = ps2[j*T2 + N4*i + lane_id]*__half2half2(ms0) + ps2[j*T2 + sg*(D + 1*C)/4 + N4*i + lane_id]*__half2half2(ms1); + } + } } } __syncthreads(); - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - float2 * dst2 = (float2 *) kqv; + if (warp_id == 0) { - for (int i = 0; i < D2/tph; ++i) { - dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]); + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { + half2 S = __half2half2(ss[j*T + 0]); + + for (int i = 0; i < L2; ++i) { + dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + N4*i + lane_id] = __half22float2(ps2[j*T2 + N4*i + lane_id]/S); + } } } -#endif + } @@ -10300,7 +10391,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * memcpy(&scale, KQV->op_params, sizeof(float)); const int nwarps = Q->ne[1] < 4 ? 4 : 2; - const int nqpb = 2; // queries per block + const int nqpb = 16; // queries per block const int ncpw = 32; // cache values per warp (does not work for other values) dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); @@ -10311,7 +10402,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * switch (Q->ne[0]) { case 64: - flash_attn_ext_f16<64, 8, 32> + flash_attn_ext_f16<64, 16, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10328,7 +10419,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 80: - flash_attn_ext_f16<80, 8, 32> + flash_attn_ext_f16<80, 16, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10345,7 +10436,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 128: - flash_attn_ext_f16<128, 8, 32> + flash_attn_ext_f16<128, 16, 32> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key From 2455a8d6c3b2e49cc19155aeb8e12438fd6a42fa Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sat, 27 Jan 2024 12:23:40 -0500 Subject: [PATCH 28/58] update impl --- ggml-cuda.cu | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5cb0656063499..ecfa98c4e4763 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6134,9 +6134,9 @@ static __global__ void flash_attn_f32( } } -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_acc; // based on metal version template // D head size, Q queries per block, C cache items per blocks @@ -6196,8 +6196,9 @@ static __global__ void flash_attn_ext_f16( for (int i = 0; i < L2; ++i) { // load heads from Q to shared memory for (int j = warp_id; j < Q; j += n_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id]; + pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]); } else { pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); } @@ -6218,8 +6219,8 @@ static __global__ void 
flash_attn_ext_f16( __syncthreads(); { - half S[Q] = { 0.0 }; - half M[Q] = { -INFINITY }; + half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros?? + half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization // assume K and V are same shape const int ne22 = ne12; @@ -6277,12 +6278,12 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { - half16x16_a mq{}; - half16x16_b mk{}; - half16x16_acc mqk{}; + half16x16_a mq; + half16x16_b mk; + half16x16_acc mqk; for (int cc = 0; cc < C/16; ++cc) { - nvcuda::wmma::fill_fragment(mqk, 0); // re fetch + nvcuda::wmma::fill_fragment(mqk, 0); const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -6297,8 +6298,8 @@ static __global__ void flash_attn_ext_f16( } // online softmax - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]); @@ -6311,6 +6312,10 @@ static __global__ void flash_attn_ext_f16( S[j] = S[j]*ms + warp_reduce_sum(vs); + for (int i = 0; i < L2; ++i) { + ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms); + } + ss[j*T + p] = vs; } @@ -6318,9 +6323,9 @@ static __global__ void flash_attn_ext_f16( // (Q*K^T)*V { - half16x16_acc mqkv{}; - half16x16_a mqk{}; - half16x16_b mv{}; + half16x16_acc mqkv; + half16x16_a mqk; + half16x16_b mv; for (int64_t i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(mqkv, 0); @@ -6353,7 +6358,7 @@ static __global__ void flash_attn_ext_f16( // TODO: try parallel reduce if (warp_id == 0) { half S = 0.0; - half M = -INFINITY; + half M = __float2half(-INFINITY); for (int64_t sg = 1; sg < n_warps; ++sg) { for (int64_t j = 0; j < Q; ++j) { @@ -6395,10 +6400,8 @@ static __global__ void flash_attn_ext_f16( } } } - } - template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -10366,7 +10369,7 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { - GGML_ASSERT(Q->type == GGML_TYPE_F16); + GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(mask->type == GGML_TYPE_F32); @@ -10390,7 +10393,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - const int nwarps = Q->ne[1] < 4 ? 4 : 2; + const int nwarps = Q->ne[1] < 4 ? 
12 : 4; const int nqpb = 16; // queries per block const int ncpw = 32; // cache values per warp (does not work for other values) From 77f6976a87f6d034cf0f7a77e14a011da7901911 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 13:15:00 +0200 Subject: [PATCH 29/58] metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments --- ggml-metal.m | 12 +-- ggml-metal.metal | 220 ++++++++++++++++++++++------------------------- 2 files changed, 110 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index eabc16f416645..a7e126bff5318 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,14 +2213,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + const int64_t ncpsg = 32; // cache values per simdgroup + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index 6eb2825df558b..b564f014de2b6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,6 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); +// ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, @@ -2038,39 +2039,45 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iq1 = tgpig[0]*Q; const int64_t D4 = D/4; - const int64_t N4 = N_SIMDWIDTH; - const int64_t L4 = (D4 + N4 - 1)/N4; const int64_t D8 = D/8; + const int64_t NW = N_SIMDWIDTH; + const int64_t L4 = (D4 + NW - 1)/NW; + const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = D + nsg*SH; // shared memory size per query in (half) + const int64_t T4 = T/4; // shared memory size per query in (half4) - threadgroup half * pq = (threadgroup half *) (shared + 0*D); - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); - threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + threadgroup half * sq = 
(threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; + sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; } else { - pq4[j*T4 + N4*i + tiisg] = 0.0h; + sq4[j*T4 + NW*i + tiisg] = 0.0h; } } + } - // zero out shared memory - for (int64_t j = 0; j < Q; ++j) { - ps4[j*T4 + N4*i + tiisg] = 0.0h; - } + // zero out lo + for (int64_t i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); } + // zero out shared memory SH if (tiisg < C) { for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 0 + tiisg] = 0.0h; + ss[j*T + tiisg] = 0.0h; + if (tiisg < Q) { + ss[j*T + C + tiisg] = 0.0h; + } } } @@ -2103,46 +2110,24 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; + // load the queries from shared memory into local memory simdgroup_half8x8 mq[D8]; for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], pq + i*8, T); + simdgroup_load(mq[i], sq + i*8, T); } - // TODO: this can be improved - device const float * mp[Q]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - { - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - for (int64_t j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); - } else { - mp[j] = nullptr; - } - } - } + // pointer to the mask + device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix simdgroup_half8x8 mscale(scale); - for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - // skip -INF blocks - // TODO: double-check this - { - float smc = -INFINITY; - - for (int64_t j = 0; j < Q; ++j) { - const float mc = mp[j] ? 
mp[j][iic + tiisg] : -INFINITY; - smc = simd_max(max(smc, mc)); - } - - if (smc == -INFINITY) { - continue; - } - } - + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { simdgroup_half8x8 mk; @@ -2150,7 +2135,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); @@ -2160,65 +2145,77 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask simdgroup_float8x8 mm; - simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } + // used to detect blocks full of -INF + half smax = -INFINITY; + // online softmax for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); const half s = ss[j*T + p]; - half m = M[j]; - + smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); + const half m = M[j]; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] *= ms; + // create an 8x8 diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; } + // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - simdgroup_barrier(mem_flags::mem_none); + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } - // (Q*K^T)*V + // O = diag(ms)*O { - simdgroup_half8x8 mv; + simdgroup_half8x8 mm; - simdgroup_half8x8 mp[C/8]; - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } + simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mqkv; + simdgroup_multiply(lo[i], mm, lo[i]); + } + } - simdgroup_load(mqkv, ps + i*8, T, 0, false); + // O = O + (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mp; + simdgroup_load(mp, ss + 8*cc, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D8; ++i) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); } - - simdgroup_store(mqkv, ps + i*8, T, 0, false); } } } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; @@ -2227,58 +2224,30 @@ kernel void kernel_flash_attn_ext_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps -#if 1 - if (sgitg == 0) { + // reduce the warps sequentially + for (int64_t sg = 1; sg < nsg; ++sg) { half S = { 0.0h }; half M = { -INFINITY }; - for (int64_t sg = 1; sg < nsg; 
++sg) { - for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - - M = max(M0, M1); - - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; - } + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; - } + // each simdgroup stores its output to shared memory, reusing sq4 + if (sgitg == sg) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } - } -#else - // parallel reduce - // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps - { - half S = { 0.0h }; - half M = { -INFINITY }; - for (int64_t sg = nsg/2; sg > 0; sg /= 2) { - if (sgitg >= sg) { - continue; - } + threadgroup_barrier(mem_flags::mem_threadgroup); + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; M = max(M0, M1); @@ -2290,28 +2259,47 @@ kernel void kernel_flash_attn_ext_f16( if (tiisg == 0) { ss[j*T + 0] = S; ss[j*T + 1] = M; - } - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } } - threadgroup_barrier(mem_flags::mem_threadgroup); + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; + + simdgroup_load(ms0, ss + C, T, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } } } -#endif - simdgroup_barrier(mem_flags::mem_threadgroup); + // store result to shared memory (reuse sq4) + if (sgitg == 0) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } device float4 * dst4 = (device float4 *) dst; + // final rescale with 1/S and store to global memory if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; } } } From ecc466a460abc7ad73df3b22a3e0957170bcf7b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 15:42:57 +0200 Subject: [PATCH 30/58] metal : add tests, fix scaling, support C > 32 --- ggml-metal.m | 6 ++-- ggml-metal.metal | 62 ++++++++++++++++++++------------------ tests/test-backend-ops.cpp | 14 ++++++--- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a7e126bff5318..484ef89398e7a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 
length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index b564f014de2b6..7b604eb61a177 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; const int64_t NW = N_SIMDWIDTH; - const int64_t L4 = (D4 + NW - 1)/NW; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) const int64_t T = D + nsg*SH; // shared memory size per query in (half) @@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16( // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; - for (int64_t i = 0; i < L4; ++i) { - // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { - sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; + sq4[j*T4 + i] = (half4) q4[i]; } else { - sq4[j*T4 + NW*i + tiisg] = 0.0h; + sq4[j*T4 + i] = 0.0h; } } } @@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } // zero out shared memory SH - if (tiisg < C) { - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; - if (tiisg < Q) { - ss[j*T + C + tiisg] = 0.0h; - } + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; } } @@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16( // online softmax for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; - - const half s = ss[j*T + p]; + const half m = M[j]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half m = M[j]; + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); + const half ms = exp(m - M[j]); - S[j] = S[j]*ms + simd_sum(vs); + S[j] = S[j]*ms; // create an 8x8 diagonal matrix for rescaling the output - if (p == j) { + if (tiisg == j) { ss[j*T + C + j] = ms; } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } // skip -INF blocks @@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq4 + // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16( } } - // store result to shared memory (reuse sq4) + // store result to shared memory (reuse sq) if (sgitg == 0) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c98bef7cf3a6..4093a52f2eef1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-5; + return 5e-4; } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) @@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 3a428a10973a751af72b55b9ef396de9c305c6ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 17:47:22 +0200 Subject: [PATCH 31/58] metal : improve precision --- ggml-metal.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 7b604eb61a177..b6b5fd997b93a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2120,7 +2120,7 @@ kernel void kernel_flash_attn_ext_f16( device 
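/*
 * The hunks below are the precision fixes named in the commit subject: the
 * diagonal scale matrix mscale is switched from half to float, and every exp()
 * gets an explicit -INFINITY guard. Without the guard, a row whose running
 * maximum M is still -INFINITY (fully masked so far) would evaluate
 * exp(-INF - (-INF)), i.e. exp(NaN), instead of contributing 0.
 */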
const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2163,7 +2163,7 @@ kernel void kernel_flash_attn_ext_f16( M[j] = simd_max(max(M[j], s)); } - const half ms = exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); S[j] = S[j]*ms; @@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = exp(s - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j] + simd_sum(vs); @@ -2255,8 +2255,8 @@ kernel void kernel_flash_attn_ext_f16( M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); S = S0*ms0 + S1*ms1; From 8612864108760897261d0d10101f68355899b03f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 18:10:16 +0200 Subject: [PATCH 32/58] ggml : fix f16 mad --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 6bba840d93d0c..fc0886aecf5a1 100644 --- a/ggml.c +++ b/ggml.c @@ -1344,12 +1344,12 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const // leftovers for (int i = np; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #else // scalar for (int i = 0; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #endif } From 134c81c78dfdeaca988ea2505cc6f0c0aec2d243 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 22:23:40 +0200 Subject: [PATCH 33/58] metal : minor --- ggml-metal.metal | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b6b5fd997b93a..ad6a4a318f4c3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { - simdgroup_half8x8 mk; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } @@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { @@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { - simdgroup_half8x8 mv; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mp; simdgroup_load(mp, ss + 8*cc, T, 0, false); @@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < D8; ++i) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mv; simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); From 
1db22d7032fd55a612e400164cb70ad238bbc055 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 23:08:31 +0200 Subject: [PATCH 34/58] metal : support Q > 8 --- examples/batched-bench/batched-bench.cpp | 2 +- ggml-metal.m | 7 ++- ggml-metal.metal | 80 +++++++++++++++--------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 7924db267401c..4992b57f6f9db 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.n_batch = 2048; ctx_params.mul_mat_q = mmq; ctx_params.n_threads = params.n_threads; diff --git a/ggml-metal.m b/ggml-metal.m index ef799ef57b643..a0dd1d0df5bcb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index ad6a4a318f4c3..08c000cc4c027 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; + const int64_t Q8 = Q/8; const int64_t NW = N_SIMDWIDTH; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) @@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { @@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + lo[j][i] = make_filled_simdgroup_matrix(0.0h); + } } // zero out shared memory SH @@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; + simdgroup_half8x8 mq[Q8][D8]; - for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + } } const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; @@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mqk[Q8]; + for (int64_t j = 0; j < Q8; ++j) { + mqk[j] = 
make_filled_simdgroup_matrix(0.h); + } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + } } // mqk = mqk*scale + mask - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_float8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk, ss + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + } } } @@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms; - // create an 8x8 diagonal matrix for rescaling the output + // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } @@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); + simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mp; - simdgroup_load(mp, ss + 8*cc, T, 0, false); + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); for (int64_t i = 0; i < D8; ++i) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_half8x8 mv; - simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mv; + simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); + simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + } } } } @@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } @@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; - simdgroup_load(ms0, ss + C, T, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); } } } @@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16( 
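/*
 * Cross-simdgroup reduction, for reference: each simdgroup carries a partial
 * output O_i together with its own (S_i, M_i). Simdgroup 0 folds the others in
 * one at a time using the same rescaling trick as inside a block:
 *
 *   M   = max(M_0, M_1)
 *   ms0 = exp(M_0 - M),  ms1 = exp(M_1 - M)
 *   S   = S_0*ms0 + S_1*ms1
 *   O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
 *
 * and the final normalization by 1/S happens only once, when storing to dst.
 */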
// store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } From 4794821a31d5778b3398b8375d29fa63a539c8c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 16:44:55 +0200 Subject: [PATCH 35/58] tests : add ATTN tests --- tests/test-backend-ops.cpp | 70 +++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c18ff07ea4d21..0ce498e9e7dd4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case { } }; +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); + test_cases.emplace_back(new test_attn(64, 32, 512, 8)); + test_cases.emplace_back(new test_attn(64, 32, 512, 7)); + test_cases.emplace_back(new test_attn(64, 32, 512, 1)); + test_cases.emplace_back(new test_attn(80, 32, 512, 8)); + test_cases.emplace_back(new test_attn(80, 32, 512, 7)); + test_cases.emplace_back(new test_attn(80, 32, 512, 1)); + test_cases.emplace_back(new test_attn(128, 32, 512, 8)); + test_cases.emplace_back(new test_attn(128, 32, 512, 7)); + test_cases.emplace_back(new test_attn(128, 32, 512, 1)); + + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); + test_cases.emplace_back(new 
test_flash_attn_ext(64, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From abeaf0d90ee82096a0aba20785f1e37bd1f3aa41 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:12:24 +0200 Subject: [PATCH 36/58] metal : disable buffer allocation logs --- ggml-metal.m | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a0dd1d0df5bcb..a637f04875dbe 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2421,10 +2421,13 @@ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buff UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2434,10 +2437,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2471,8 +2479,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2549,7 +2556,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2572,7 +2579,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2581,8 +2589,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return 
ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } From c6c1132e5e6658b3c209433ed5ef75067ef31a2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:22:28 +0200 Subject: [PATCH 37/58] tests : more --- ggml-metal.m | 9 +++++++++ ggml-metal.metal | 3 +++ ggml.c | 5 ----- tests/test-backend-ops.cpp | 29 ++++++++++------------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a637f04875dbe..4b5fd0bb8fc58 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -137,7 +137,10 @@ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -505,7 +508,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute( switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 08c000cc4c027..be059d78f505a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16( template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template 
[[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/ggml.c b/ggml.c index e8a5fcfa485c1..57271a1ad43e3 100644 --- a/ggml.c +++ b/ggml.c @@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; - const int64_t P = nek1 - N; GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); @@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted @@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ce498e9e7dd4..f57e8ab1a853e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_attn(64, 32, 512, 8)); - test_cases.emplace_back(new test_attn(64, 32, 512, 7)); - test_cases.emplace_back(new test_attn(64, 32, 512, 1)); - test_cases.emplace_back(new test_attn(80, 32, 512, 8)); - test_cases.emplace_back(new test_attn(80, 32, 512, 7)); - test_cases.emplace_back(new test_attn(80, 32, 512, 1)); - test_cases.emplace_back(new test_attn(128, 32, 512, 8)); - test_cases.emplace_back(new test_attn(128, 32, 512, 7)); - test_cases.emplace_back(new test_attn(128, 32, 512, 1)); - - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); + for (int hs : { 64, 80, 96, 112, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 5fcb9c1c5af108056c8ad51fc1995de9d7707d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 19:46:22 +0200 Subject: [PATCH 38/58] metal : faster inner loop for C == 32 --- ggml-metal.metal | 59 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 
be059d78f505a..db4c7cfde0037 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2164,34 +2164,59 @@ kernel void kernel_flash_attn_ext_f16( half smax = -INFINITY; // online softmax - for (int64_t j = 0; j < Q; ++j) { - const half m = M[j]; + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - for (int64_t p = tiisg; p < C; p += NW) { + const half m = M[j]; const half s = ss[j*T + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms; + S[j] = S[j]*ms + simd_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; - for (int64_t p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } - S[j] = S[j] + simd_sum(vs); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + S[j] = S[j]*ms; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*T + C + j] = ms; + } + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 
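/*
 * Note on the branch structure in this hunk: when C equals the SIMD width (32),
 * every lane owns exactly one cache column, so the strided
 * `for (p = tiisg; p < C; p += NW)` loops collapse to a single iteration and
 * the max/exp/sum work can be fused into one pass. The general path keeps the
 * strided loops so the kernel also works for C > 32.
 */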
0.0h : exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } } From a1d5a12bc5ab9cde4d3db304d3882b99cca5e849 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Mon, 29 Jan 2024 13:15:33 -0500 Subject: [PATCH 39/58] fix compiler error --- ggml-cuda.cu | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ecfa98c4e4763..8fa21c97e031c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6134,6 +6134,7 @@ static __global__ void flash_attn_f32( } } +#if __CUDA_ARCH__ >= CC_VOLTA typedef nvcuda::wmma::fragment half16x16_a; typedef nvcuda::wmma::fragment half16x16_b; typedef nvcuda::wmma::fragment half16x16_acc; @@ -6185,13 +6186,13 @@ static __global__ void flash_attn_ext_f16( const half scale_h = __float2half(scale); - extern __shared__ char data_flash_attn_shmem[]; + extern __shared__ half __flash_attn_f16_shmem[]; // pq - half * pq = (half *) (data_flash_attn_shmem + 0*D); - half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D); - half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); - half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D); - half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D); + half * pq = (half *) (__flash_attn_f16_shmem + 0*D); + half2 * pq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); + half * ps = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D); + half2 * ps2 = (half2 *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D); + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 2*D); for (int i = 0; i < L2; ++i) { // load heads from Q to shared memory @@ -6217,7 +6218,7 @@ static __global__ void flash_attn_ext_f16( } __syncthreads(); - +#if 0 { half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros?? half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization @@ -6400,7 +6401,40 @@ static __global__ void flash_attn_ext_f16( } } } +#endif } +#else +template // D head size, Q queries per block, C cache items per blocks +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ kqv, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { + bad_arch(); + } +#endif template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, From d073e4f93337560e552f0d3de4b6b07bf13ef3f5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:45:32 +0200 Subject: [PATCH 40/58] metal : fix array initialization --- ggml-metal.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index db4c7cfde0037..41f6169de8abd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2084,8 +2084,8 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { 0.0h }; - half M[Q] = { -INFINITY }; + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... 
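/*
 * Rationale for the designated initializers above and below: with
 * `= { -INFINITY }` only element 0 is set and the remaining Q-1 entries are
 * value-initialized to 0, which silently breaks the running-maximum logic.
 * The GNU range-designator form `{ [0 ... Q-1] = x }`, accepted here by the
 * Metal compiler, initializes every element.
 */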
Q-1] = -INFINITY }; // assume K and V are same shape const int64_t ne22 = ne12; From 78df5527e4e9eafb181200384fbed80c8116042e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:46:49 +0200 Subject: [PATCH 41/58] tests : ifdef --- tests/test-backend-ops.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f57e8ab1a853e..07182c6d8aa63 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,6 +1726,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); +#if 0 for (int hs : { 64, 80, 96, 112, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -1736,6 +1737,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#else + for (int hs : { 128, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, 512 }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#endif #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 3b0f74b42859e557ed6def59aab98bdcff8f913a Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Tue, 30 Jan 2024 14:57:12 -0500 Subject: [PATCH 42/58] latest kernel update, wrong values --- ggml-cuda.cu | 362 ++++++++++++++++++++++----------- tests/test-flash-attention.cpp | 105 ++++------ 2 files changed, 285 insertions(+), 182 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5f6438048e040..5229e15d2774a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -125,6 +125,11 @@ #include "ggml.h" #include "ggml-backend-impl.h" +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. 
for which __hmax and __hmax2 are known to work (may be higher than needed) #define CC_PASCAL 600 @@ -679,7 +684,6 @@ static __device__ __forceinline__ half warp_reduce_max(half x) { return x; #else (void) x; - bad_arch(); #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } @@ -6156,16 +6160,17 @@ static __global__ void flash_attn_f32( #if __CUDA_ARCH__ >= CC_VOLTA typedef nvcuda::wmma::fragment half16x16_a; typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; typedef nvcuda::wmma::fragment half16x16_acc; // based on metal version -template // D head size, Q queries per block, C cache items per blocks +template // D head size, Q queries per block, C cache items per block static __global__ void flash_attn_ext_f16( const char* __restrict__ q, const char* __restrict__ k, const char* __restrict__ v, const char* __restrict__ mask, - float* __restrict__ kqv, + float* __restrict__ dst, float scale, int ne00, int ne01, @@ -6190,57 +6195,64 @@ static __global__ void flash_attn_ext_f16( const int warp_id = threadIdx.y; const int lane_id = threadIdx.x; - const int n_warps = blockDim.y; // number of warps + const int num_warps = blockDim.y; // number of warps const int iq3 = blockIdx.z; const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; const int D2 = D/2; - const int N4 = WARP_SIZE; - const int L2 = (D2 + N4 - 1)/N4; const int D16 = D/16; + const int Q16 = Q/16; + const int NW = WARP_SIZE; + const int SH = (C + D); // shared memory per simdgroup in (half) - const int T = D + n_warps*(D + 1*C); // shared memory size per query in half - const int T2 = T/2; // shared memory size per query in half2 - - const half scale_h = __float2half(scale); + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) extern __shared__ half __flash_attn_f16_shmem[]; // pq - half * pq = (half *) (__flash_attn_f16_shmem + 0*D); - half2 * pq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); - half * ps = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D); - half2 * ps2 = (half2 *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D); - half * ss = (half *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 2*D); - - for (int i = 0; i < L2; ++i) { - // load heads from Q to shared memory - for (int j = warp_id; j < Q; j += n_warps) { - const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory + for (int64_t j = warp_id; j < Q; j += num_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = lane_id; i < D2; i += NW) { if (iq1 + j < ne01) { - pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]); + sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { - pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); + sq2[j*T2 + i] = make_half2(0.0, 0.0); } } + } - // zero out shared memory - for (int j = 0; j < Q; ++j) { - ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); + // zero out lo + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + 
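/*
 * Layout recap for this CUDA port of the Metal kernel: lo[Q16][D16] keeps the
 * output O in 16x16 wmma accumulator fragments (the Metal version uses 8x8
 * simdgroup matrices), and shared memory is T = D + num_warps*SH halfs per
 * query row, where each warp's SH = C + Q region holds its attention scores
 * followed by the QxQ diagonal/reduction scratch.
 */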
nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } - if (lane_id < C) { - for (int j = 0; j < Q; ++j) { - ss[j*T + 0 + lane_id] = 0.0; + // zero out shared memory SH + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = lane_id; i < SH; i += NW) { + ss[j*T + i] = 0.0; } } __syncthreads(); -#if 0 + { - half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros?? - half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization + float S[Q]; + float M[Q]; + + for(int i = 0; i < Q;i ++) { + S[i] = 0.0f; + M[i] = -INFINITY; + } // assume K and V are same shape const int ne22 = ne12; @@ -6265,162 +6277,252 @@ static __global__ void flash_attn_ext_f16( const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; - // TODO: this can be improved - float * mp[Q]; - - { - const int ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - for (int j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (float *)(mask + ((ir + j)%ne31) * nb31); - } else { - mp[j] = nullptr; - } + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } - for (int iic = C*warp_id; iic < ne11; iic += C*n_warps) { - // skip -INF blocks - // TODO: double-check this - { - float smc = -INFINITY; - - for (int j = 0; j < Q; ++j) { - const float mc = mp[j] ? mp[j][iic + lane_id] : -INFINITY; - smc = warp_reduce_max(max(smc, mc)); - } + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - if (smc == -INFINITY) { - continue; - } - } + // pointer to the mask + const float * mp = (const float *) (mask + (ir%ne31)*nb31); + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { - half16x16_a mq; - half16x16_b mk; - half16x16_acc mqk; - for (int cc = 0; cc < C/16; ++cc) { - nvcuda::wmma::fill_fragment(mqk, 0); + half16x16_acc mqk[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + for (int64_t i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose - for(int i = 0; i < D16;i ++) { - nvcuda::wmma::load_matrix_sync(mq, pq + i*16, T); - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - nvcuda::wmma::mma_sync(mqk, mq, mk, mqk); + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } } - nvcuda::wmma::store_matrix_sync(ss + 16*cc, mqk, T, nvcuda::wmma::mem_col_major); + // mqk = mqk*scale + mask + for (int64_t j = 0; j < Q16; ++j) { + const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc; + int64_t msk_ne_row = nb31/sizeof(float); + for (uint32_t i = 0; i < mqk[j].num_elements; i++) { + int msk_col = i % 16; + int msk_row = i / 16; + mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]); + } + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major); + } } } + // used to detect blocks full of -INF + float smax = -INFINITY; + // online softmax - for (int j = 0; j < Q; ++j) { - const int p = lane_id; + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = lane_id; - const half s = 
ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]); + const float m = M[j]; + const float s = __half2float(ss[j*T + p]); - half m = M[j]; + smax = warp_reduce_max(max(smax, s)); + M[j] = warp_reduce_max(max(M[j], s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); + const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]); + const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]); - const half ms = __hisinf(m) == -1 ? 0.0 : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? 0.0 : hexp(s - M[j]); + S[j] = S[j]*ms + warp_reduce_sum(vs); - S[j] = S[j]*ms + warp_reduce_sum(vs); + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = __float2half(ms); + } - for (int i = 0; i < L2; ++i) { - ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms); + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = __float2half(vs); } + } else { + for (int64_t j = 0; j < Q; ++j) { + const float m = M[j]; + + for (int64_t p = lane_id; p < C; p += NW) { + const float s = __half2float(ss[j*T + p]); + + smax = warp_reduce_max(max(smax, s)); + M[j] = warp_reduce_max(max(M[j], s)); + } + + const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]); - ss[j*T + p] = vs; + S[j] = S[j]*ms; + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + for (int64_t p = lane_id; p < C; p += NW) { + const float s = ss[j*T + p]; + + const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]); + + S[j] = S[j] + warp_reduce_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = __float2half(vs); + } + } } - __syncthreads(); + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } - // (Q*K^T)*V - { - half16x16_acc mqkv; - half16x16_a mqk; - half16x16_b mv; + // O = diag(ms)*O + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b zro; + + nvcuda::wmma::fill_fragment(zro, 0.0); + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(mqkv, 0); + nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]); + } + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C/16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - for (int cc = 0; cc < C/16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D16; ++i) { + half16x16_b mk; + nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half)); - nvcuda::wmma::load_matrix_sync(mqk, ss + cc*16, T); - nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half)); + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mv; + nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T); - nvcuda::wmma::mma_sync(mqkv, mqk, mv, mqkv); + nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]); + } } - - nvcuda::wmma::store_matrix_sync(ps + i*16, mqkv, T, nvcuda::wmma::mem_col_major); } } } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (lane_id == 0) { - ss[j*T + 0] = S[j]; - ss[j*T + 1] = M[j]; + ss[j*T + 0] = __float2half(S[j]); + ss[j*T + 1] = __float2half(M[j]); } } } - __syncthreads(); + // reduce the warps sequentially + for (int64_t sg = 1; sg < num_warps; ++sg) { + float S = 0.0f; + float M = -INFINITY; - // reduce the warps - // TODO: try parallel reduce - if (warp_id == 0) { - half S = 0.0; - half M = __float2half(-INFINITY); + __syncthreads(); - for (int64_t sg = 1; sg 
< n_warps; ++sg) { + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + const float S0 = __half2float(ss[j*T + 0]); + const float S1 = __half2float(ss[j*T + sg*SH + 0]); - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + const float M0 = __half2float(ss[j*T + 1]); + const float M1 = __half2float(ss[j*T + sg*SH + 1]); - M = __hmax(M0, M1); + M = max(M0, M1); - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); + const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M); + const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M); S = S0*ms0 + S1*ms1; if (lane_id == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*T + 0] = __float2half(S); + ss[j*T + 1] = __float2half(M); + + ss[j*T + C + j ] = __float2half(ms0); + ss[j*T + C + j + sg*SH] = __float2half(ms1); } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < L2; ++i) { - ps2[j*T2 + N4*i + lane_id] = ps2[j*T2 + N4*i + lane_id]*__half2half2(ms0) + ps2[j*T2 + sg*(D + 1*C)/4 + N4*i + lane_id]*__half2half2(ms1); + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, t2); + + // t <- lo + for (uint32_t k = 0; k < t.num_elements; k++) { + t.x[k] = lo[j][i].x[k]; + } + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); } } } } - __syncthreads(); + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major); + } + } + } - float2 * dst2 = (float2 *) kqv; + float2 * dst2 = (float2 *) dst; + // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int j = 0; j < Q && iq1 + j < ne01; ++j) { - half2 S = __half2half2(ss[j*T + 0]); + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = __half2float(ss[j*T + 0]); - for (int i = 0; i < L2; ++i) { - dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + N4*i + lane_id] = __half22float2(ps2[j*T2 + N4*i + lane_id]/S); + for (int64_t i = lane_id; i < D2; i += NW) { + dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]); + dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S; + dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S; } } } -#endif } #else template // D head size, Q queries per block, C cache items per blocks @@ -6451,7 +6553,6 @@ static __global__ void flash_attn_ext_f16( int ne1, int ne2, int ne3) { - bad_arch(); } #endif @@ -10446,9 +10547,9 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - const int nwarps = Q->ne[1] < 4 ? 
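/*
 * The replacement heuristic below mirrors the Metal side: for small batches
 * (Q->ne[1] <= nqpb) the KV dimension is split across more warps,
 * MAX(4, MIN(K->ne[1]/ncpw, 32)), so there is still enough parallel work per
 * block; larger batches fall back to 4 warps.
 */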
12 : 4; const int nqpb = 16; // queries per block const int ncpw = 32; // cache values per warp (does not work for other values) + const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); @@ -10457,6 +10558,23 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); switch (Q->ne[0]) { + case 16: + flash_attn_ext_f16<16, 16, 32> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + (const char *) src3_extra->data_device[g_main_device], // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask->ne[1], mask->nb[1], + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; case 64: flash_attn_ext_f16<64, 16, 32> <<>> ( diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index 74167ed86fc84..5d83eeabd791a 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -2,8 +2,6 @@ #include "ggml-alloc.h" #include "ggml-backend.h" -#define GGML_USE_CUBLAS - #ifdef GGML_USE_CUBLAS #include "ggml-cuda.h" #endif @@ -22,6 +20,7 @@ struct test_model { struct ggml_tensor * q; struct ggml_tensor * k; struct ggml_tensor * v; + struct ggml_tensor * msk; ggml_backend_t backend = NULL; ggml_backend_buffer_t buffer = NULL; struct ggml_context * ctx = NULL; @@ -102,59 +101,38 @@ float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0 return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); } -void load_model(test_model & model, bool use_gpu = false) { - float Query[30] = { // [3, 4, 2] - // z0 - 2, 4, 2, - 4, 2, 1, - 4, 1, 3, - 4, 2, 2, - - // z1 - 2, 1, 1, - 4, 2, 1, - 1, 1, 3, - 4, 2, 1 - }; +void load_model(test_model & model, int head_dim, int batch_size, int kv_size, int num_heads) { + float* query = new float[head_dim * batch_size * num_heads]; + float* key = new float[head_dim * kv_size * num_heads]; + float* value = new float[head_dim * kv_size * num_heads]; + float* mask = new float[kv_size * batch_size]; - float Key[24] = { // [3, 4, 2] - // z0 - 2, 4, 2, - 4, 2, 1, - 4, 2, 3, - 1, 2, 1, - - // z1 - 3, 1, 3, - 4, 2, 1, - 1, 1, 2, - 4, 3, 1 - }; + for(int i = 0; i < head_dim*batch_size*num_heads;i ++) { + query[i] = i % 3 ? 2.0f : 1.5f; + } - float Value[24] = { // [4, 3, 2] - // z0 - 2, 4, 2, 1, - 2, 1, 4, 2, - 1, 4, 2, 3, + for(int i = 0; i < head_dim*kv_size*num_heads;i ++) { + key[i] = i % 3 ? 2.3f : 2.8f; + value[i] = i % 3 ? 3.5f : 1.5f; + } - // z1 - 1, 4, 2, 1, - 2, 1, 1, 2, - 1, 4, 3, 3, - }; + for(int i = 0; i < batch_size*kv_size;i ++) { + mask[i] = i % 3 ? 
1.0f : 0.0f; + } size_t buffer_size = 0; { - buffer_size += 30 * ggml_type_sizef(GGML_TYPE_F32); // tensor q - buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor k - buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor v + buffer_size += head_dim * batch_size * num_heads * ggml_type_sizef(GGML_TYPE_F32); // tensor q + buffer_size += head_dim * kv_size * num_heads * ggml_type_sizef(GGML_TYPE_F16); // tensor k + buffer_size += head_dim * kv_size * num_heads * ggml_type_sizef(GGML_TYPE_F16); // tensor v + buffer_size += batch_size * kv_size * ggml_type_sizef(GGML_TYPE_F32); // tensor mask buffer_size += 1024; } printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); - int num_tensors = 3; + int num_tensors = 4; struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, /*.mem_buffer =*/ NULL, @@ -163,12 +141,10 @@ void load_model(test_model & model, bool use_gpu = false) { // initialize the backend #ifdef GGML_USE_CUBLAS - if (use_gpu) { - fprintf(stderr, "%s: using CUDA backend\n", __func__); - model.backend = ggml_backend_cuda_init(0); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); } #endif @@ -183,9 +159,10 @@ void load_model(test_model & model, bool use_gpu = false) { model.ctx = ggml_init(params); // create tensors - model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2); - model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2); - model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 4, 3, 2); + model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, head_dim, batch_size, num_heads); + model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads); + model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads); + model.msk = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, kv_size, batch_size); // create a allocator ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); @@ -194,12 +171,18 @@ void load_model(test_model & model, bool use_gpu = false) { ggml_allocr_alloc(alloc, model.q); ggml_allocr_alloc(alloc, model.k); ggml_allocr_alloc(alloc, model.v); + ggml_allocr_alloc(alloc, model.msk); + + ggml_fp16_t* k_f16 = new ggml_fp16_t[head_dim * kv_size * num_heads]; + ggml_fp16_t* v_f16 = new ggml_fp16_t[head_dim * kv_size * num_heads]; - ggml_backend_tensor_set(model.q, Query, 0, ggml_nbytes(model.q)); - ggml_backend_tensor_set(model.k, Key, 0, ggml_nbytes(model.k)); - ggml_backend_tensor_set(model.v, Value, 0, ggml_nbytes(model.v)); + ggml_fp32_to_fp16_row(key, k_f16, head_dim * kv_size * num_heads); + ggml_fp32_to_fp16_row(value, v_f16, head_dim * kv_size * num_heads); - ggml_allocr_free(alloc); + ggml_backend_tensor_set(model.q, query, 0, ggml_nbytes(model.q)); + ggml_backend_tensor_set(model.k, k_f16, 0, ggml_nbytes(model.k)); + ggml_backend_tensor_set(model.v, v_f16, 0, ggml_nbytes(model.v)); + ggml_backend_tensor_set(model.msk, mask, 0, ggml_nbytes(model.msk)); } struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { @@ -218,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a struct ggml_cgraph * gf = ggml_new_graph(ctx0); 
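/*
 * build_graph (below) either runs the fused ggml_flash_attn_ext op with
 * scale = 1/sqrt(head_dim), or, when naive_attn is set, a reference graph
 * assembled from separate ggml ops starting with K*Q, so the fused and
 * unfused outputs can be compared on the same Q/K/V/mask test data.
 */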
if(!model.naive_attn) { - struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false); + struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0])); ggml_build_forward_expand(gf, result); } else { struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); @@ -350,8 +333,7 @@ int main(int argc, char ** argv) ggml_time_init(); - - load_model(model, true); + load_model(model, 16, 32, 128, 2); ggml_backend_buffer_t buf_compute; // for compute struct ggml_allocr * allocr = NULL; @@ -385,7 +367,10 @@ int main(int argc, char ** argv) if(i > 0 && (i % result->ne[0] == 0)) { printf("\n"); } - printf("%2.6f ", data[i]); + if(i > 0 && (i % (result->ne[0] * result->ne[2]) == 0)) { + printf("\n\n"); + } + printf("%2.4f ", data[i]); } } From b1479dfbc574dc2b0ea8a7426f44011f73a118fc Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 31 Jan 2024 12:28:48 -0500 Subject: [PATCH 43/58] fix kernel --- ggml-cuda.cu | 103 ++++++++++++++++++--------------- tests/test-flash-attention.cpp | 2 +- 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5229e15d2774a..fe24935a49937 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6158,9 +6158,9 @@ static __global__ void flash_attn_f32( } #if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; typedef nvcuda::wmma::fragment half16x16_acc; // based on metal version @@ -6204,15 +6204,15 @@ static __global__ void flash_attn_ext_f16( const int D16 = D/16; const int Q16 = Q/16; const int NW = WARP_SIZE; - const int SH = (C + D); // shared memory per simdgroup in (half) + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) extern __shared__ half __flash_attn_f16_shmem[]; // pq - half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data - half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix half16x16_acc lo[Q16][D16]; @@ -6249,7 +6249,7 @@ static __global__ void flash_attn_ext_f16( float S[Q]; float M[Q]; - for(int i = 0; i < Q;i ++) { + for(int i = 0; i < Q; i++) { S[i] = 0.0f; M[i] = -INFINITY; } @@ -6288,7 +6288,7 @@ static __global__ void flash_attn_ext_f16( const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; // pointer to the mask - const float * mp = (const float *) (mask + (ir%ne31)*nb31); + const float * mp = mask ? 
(const float *) (mask + (ir%ne31)*nb31) : nullptr; // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -6305,7 +6305,7 @@ static __global__ void flash_attn_ext_f16( for (int64_t i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); for (int64_t j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); @@ -6314,14 +6314,14 @@ static __global__ void flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q16; ++j) { - const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc; - int64_t msk_ne_row = nb31/sizeof(float); + // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc; + // int64_t msk_ne_row = nb31/sizeof(float); for (uint32_t i = 0; i < mqk[j].num_elements; i++) { - int msk_col = i % 16; - int msk_row = i / 16; - mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]); + // int msk_col = i % 16; + // int msk_row = i / 16; + mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]); } - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); } } } @@ -6370,11 +6370,11 @@ static __global__ void flash_attn_ext_f16( // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { - ss[j*T + C + j] = ms; + ss[j*T + C + j] = __float2half(ms); } for (int64_t p = lane_id; p < C; p += NW) { - const float s = ss[j*T + p]; + const float s = __half2float(ss[j*T + p]); const float vs = s == -INFINITY ? 
0.0f : expf(s - M[j]); @@ -6393,14 +6393,18 @@ static __global__ void flash_attn_ext_f16( // O = diag(ms)*O for (int64_t j = 0; j < Q16; ++j) { - half16x16_a mm; - half16x16_b zro; + // half16x16_a mm; + // half16x16_b zro; - nvcuda::wmma::fill_fragment(zro, 0.0); - nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + // nvcuda::wmma::fill_fragment(zro, 0.0); + // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]); + //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]); + for (uint32_t k = 0; k < 16*16; k++) { + half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16]; + lo[j][i].x[k] = tmp * lo[j][i].x[k]; + } } } @@ -6444,7 +6448,7 @@ static __global__ void flash_attn_ext_f16( if (warp_id == sg) { for (int64_t j = 0; j < Q16; ++j) { for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } } @@ -6487,13 +6491,13 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, t2); - - // t <- lo - for (uint32_t k = 0; k < t.num_elements; k++) { - t.x[k] = lo[j][i].x[k]; - } + // store temporally 'lo' data + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + // load 'lo' data into t + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); } } @@ -6504,22 +6508,20 @@ static __global__ void flash_attn_ext_f16( if (warp_id == 0) { for (int64_t j = 0; j < Q16; ++j) { for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major); + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } } - float2 * dst2 = (float2 *) dst; + // float2 * dst2 = (float2 *) dst; // final rescale with 1/S and store to global memory if (warp_id == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const float S = __half2float(ss[j*T + 0]); - for (int64_t i = lane_id; i < D2; i += NW) { - dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]); - dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S; - dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S; + for (int64_t i = lane_id; i < D; i += NW) { + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S; } } } @@ -10526,13 +10528,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); - GGML_ASSERT(mask->type == GGML_TYPE_F32); + if(mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F32); + } GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); GGML_ASSERT(K->backend == GGML_BACKEND_GPU); GGML_ASSERT(V->backend == GGML_BACKEND_GPU); - GGML_ASSERT(mask->backend == GGML_BACKEND_GPU); + if(mask) { + GGML_ASSERT(mask->backend == GGML_BACKEND_GPU); + } GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); ggml_cuda_set_device(g_main_device); @@ -10541,7 +10547,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * 
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; - ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra; + ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr; ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; float scale; @@ -10549,13 +10555,14 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nqpb = 16; // queries per block const int ncpw = 32; // cache values per warp (does not work for other values) - const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4; + // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4; + const int nwarps = 1; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); - int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2); - printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); + int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale); switch (Q->ne[0]) { case 16: @@ -10564,12 +10571,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key (const char *) src2_extra->data_device[g_main_device], // Value - (const char *) src3_extra->data_device[g_main_device], // Mask + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask (float *) dst_extra->data_device[g_main_device], // dst scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask->ne[1], mask->nb[1], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] @@ -10581,12 +10588,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key (const char *) src2_extra->data_device[g_main_device], // Value - (const char *) src3_extra->data_device[g_main_device], // Mask + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask (float *) dst_extra->data_device[g_main_device], // dst scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask->ne[1], mask->nb[1], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] @@ -10598,12 +10605,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key (const char *) src2_extra->data_device[g_main_device], // Value - (const char *) src3_extra->data_device[g_main_device], // Mask + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask (float *) dst_extra->data_device[g_main_device], // dst scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask->ne[1], mask->nb[1], + mask ? 
mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] @@ -10615,12 +10622,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key (const char *) src2_extra->data_device[g_main_device], // Value - (const char *) src3_extra->data_device[g_main_device], // Mask + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask (float *) dst_extra->data_device[g_main_device], // dst scale, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask->ne[1], mask->nb[1], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], K->nb[1], K->nb[2], K->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index 5d83eeabd791a..d4457a53e5b4b 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a struct ggml_cgraph * gf = ggml_new_graph(ctx0); if(!model.naive_attn) { - struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0])); + struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0])); ggml_build_forward_expand(gf, result); } else { struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); From 8ad92dc1ec9aa6549c68900daa7ab93b57fa3ae5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 31 Jan 2024 19:17:16 +0200 Subject: [PATCH 44/58] ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext --- ggml-cuda.cu | 20 +++++++++---------- ggml-metal.m | 6 ++++++ ggml-metal.metal | 40 ++++++++++++++++++-------------------- ggml.c | 13 +++++++++---- ggml.h | 12 +++++++----- llama.cpp | 40 ++++++++++++++++++++++---------------- tests/test-backend-ops.cpp | 10 +++++----- 7 files changed, 79 insertions(+), 62 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e565957421795..c57a031e4060c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds if (need_check && col_data + 0 >= ncols_data) { val.x = -INFINITY; } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f); } if (need_check && col_data + WARP_SIZE >= ncols_data) { val.y = -INFINITY; } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + val.y = x[ix + WARP_SIZE]*scale + (y ? 
__half2float(y[iy + WARP_SIZE]) : 0.0f); } if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { vals[col_smem] = val; @@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } template -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); } @@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con } } -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max( #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f32_cuda(src0_dd, src1 ? 
(const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } (void) dst; diff --git a/ggml-metal.m b/ggml-metal.m index 15e5568f960f1..e00069624551f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + int nth = 32; // SIMD width id pipeline = nil; @@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute( id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index b2e40715d4f2d..04c1aaf9cdfb9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -349,9 +349,9 @@ kernel void kernel_sum_rows( } kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -366,9 +366,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float lmax = -INFINITY; @@ -435,14 +435,14 @@ kernel void kernel_soft_max( } kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +452,15 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const half4 * pmask = src1 != src0 ? 
(device const half4 *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +486,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16( } } - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - // pointer to the mask - device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); + simdgroup_half8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q8; ++j) { - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); diff --git a/ggml.c b/ggml.c index 466a8cdec3c9d..59a4c05a12ffe 100644 --- a/ggml.c +++ b/ggml.c @@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl( bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); @@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } @@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp[i]); + } } #ifndef NDEBUG @@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(V16, 0, D*sizeof(ggml_fp16_t)); - const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? 
(ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? mp[ic] : 0.0f; + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } diff --git a/ggml.h b/ggml.h index a83ff8035f9ea..74ce1abd4d500 100644 --- a/ggml.h +++ b/ggml.h @@ -1646,11 +1646,13 @@ extern "C" { struct ggml_tensor * v, bool masked); - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch, 1, 1] - // res: [n_embd, n_head, n_batch, 1] !! permuted !! +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 1f8ecc19b4e0c..fe25839669efc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4721,7 +4721,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -4905,7 +4905,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5026,7 +5026,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5148,7 +5148,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -5245,7 +5245,7 @@ struct 
llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { @@ -5448,7 +5448,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5538,7 +5538,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -5631,7 +5631,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5731,7 +5731,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5854,7 +5854,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5968,7 +5968,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, 
n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6089,7 +6089,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6211,7 +6211,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6318,7 +6318,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -6416,7 +6416,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6524,7 +6524,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -10250,7 +10250,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - cparams.n_batch = params.n_batch; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. 
ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -10430,6 +10433,9 @@ struct llama_context * llama_new_context_with_model( ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); + // zero-out the input buffer to prevent NaNs in padded tensors + ggml_backend_buffer_clear(ctx->buf_input, 0); + LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_input), ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0f31c00f9672c..b1b30b91c9c6b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1101,7 +1101,7 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * b = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); return out; } @@ -1472,7 +1472,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } @@ -1506,7 +1506,7 @@ struct test_attn : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); struct ggml_tensor * cur; @@ -1793,7 +1793,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1915,7 +1915,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); From 0afe47fa5fdda0ff9191ca70241a9fe88364d8cc Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 31 Jan 2024 15:43:42 
-0500 Subject: [PATCH 45/58] fix naive implementation --- tests/test-flash-attention.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index d4457a53e5b4b..1f779b0d4f08a 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -207,7 +207,9 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0])); kq = ggml_soft_max(ctx0, kq); - kq = ggml_mul_mat(ctx0, model.v, kq); + kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq); + kq = ggml_permute (ctx0, kq, 0, 2, 1, 3); + //kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]); ggml_build_forward_expand(gf, kq); } From fd878f71ed370eb34b85f89e27f07821a9b2c10b Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Wed, 31 Jan 2024 16:22:11 -0500 Subject: [PATCH 46/58] cuda: mask as fp16 --- ggml-cuda.cu | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 35e2af0f45869..86afb01338f73 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6529,7 +6529,7 @@ static __global__ void flash_attn_ext_f16( const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; // pointer to the mask - const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr; + const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -6555,12 +6555,9 @@ static __global__ void flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q16; ++j) { - // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc; - // int64_t msk_ne_row = nb31/sizeof(float); for (uint32_t i = 0; i < mqk[j].num_elements; i++) { - // int msk_col = i % 16; - // int msk_row = i / 16; - mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]); + // TODO: process mask + mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; } nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); } @@ -9216,7 +9213,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec( src1_dfloat = src1_dfloat_a.alloc(ne00); ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, stream); + ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream); } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion @@ -10891,19 +10888,18 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); - if(mask) { - GGML_ASSERT(mask->type == GGML_TYPE_F32); - } GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); GGML_ASSERT(K->backend == GGML_BACKEND_GPU); GGML_ASSERT(V->backend == GGML_BACKEND_GPU); - if(mask) { - GGML_ASSERT(mask->backend == GGML_BACKEND_GPU); - } GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big"); + 
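    // [editor's note - illustrative sketch, not part of the patch] The asserts above move the
    // padding requirement to the caller: the mask must be F16 and its second dimension padded
    // (to a multiple of 8 here, tightened to 16 in the next commit; ggml.h uses GGML_KQ_MASK_PAD).
    // A host-side sketch of preparing such a mask, assuming an unpadded visibility rule
    // visible(i, j) (hypothetical helper) and an F16 tensor msk of shape
    // [n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD)]; padded rows can stay 0, matching the
    // zeroed input buffer in llama.cpp:
    //
    //     const int n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD);
    //     std::vector<float>       mask_f32(n_kv * n_batch_pad, 0.0f);   // padded rows remain 0
    //     std::vector<ggml_fp16_t> mask_f16(n_kv * n_batch_pad);
    //     for (int j = 0; j < n_batch; ++j) {
    //         for (int i = 0; i < n_kv; ++i) {
    //             mask_f32[j*n_kv + i] = visible(i, j) ? 0.0f : -INFINITY;
    //         }
    //     }
    //     ggml_fp32_to_fp16_row(mask_f32.data(), mask_f16.data(), n_kv * n_batch_pad);
    //     ggml_backend_tensor_set(msk, mask_f16.data(), 0, ggml_nbytes(msk));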
ggml_cuda_set_device(g_main_device); const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; @@ -10925,7 +10921,6 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * dim3 block_dim(32, nwarps, 1); int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale); switch (Q->ne[0]) { case 16: From 71b69aa7fd0aee89c4750d230bee7a4601d8fc1f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 09:40:56 +0200 Subject: [PATCH 47/58] cuda : fix flash_attn kernel to produce same results as CPU --- ggml-cuda.cu | 66 +++++++++++++++++++++++--------------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 86afb01338f73..0d23c12445c30 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16( const int D16 = D/16; const int Q16 = Q/16; const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) + const int SH = (C + 2*Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) @@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16( } } - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - // pointer to the mask const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { @@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q16; ++j) { - for (uint32_t i = 0; i < mqk[j].num_elements; i++) { - // TODO: process mask - mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; - } + half16x16_a mqka; + half16x16_acc mm; + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm); nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); } } @@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16( // O = diag(ms)*O for (int64_t j = 0; j < Q16; ++j) { - // half16x16_a mm; - // half16x16_b zro; + half16x16_a mm; + half16x16_b lob; - // nvcuda::wmma::fill_fragment(zro, 0.0); - // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); for (int64_t i = 0; i < D16; ++i) { - //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]); - for (uint32_t k = 0; k < 16*16; k++) { - half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16]; - lo[j][i].x[k] = tmp * lo[j][i].x[k]; - } + // convert accumulator to matrix_b + // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 
Q, T); + + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); } } @@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, t2); - // store temporally 'lo' data - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - // load 'lo' data into t - nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); } } @@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big"); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_set_device(g_main_device); const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; @@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nqpb = 16; // queries per block const int ncpw = 32; // cache values per warp (does not work for other values) - // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4; - const int nwarps = 1; + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + const int nwarps = Q->ne[1] <= nqpb ? 
MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); - int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling + // try to avoid this + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2); + switch (Q->ne[0]) { case 16: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b1b30b91c9c6b..e632142a74a13 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (int hs : { 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, 512 }) { + for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) { test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); } From 2c04beeb81cce6f868c743634d4a5a74b47531c8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 14:03:03 +0200 Subject: [PATCH 48/58] cuda : avoid extra QxQ matrix in shared memory --- ggml-cuda.cu | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0d23c12445c30..bdd50e2b6be4c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16( const int D16 = D/16; const int Q16 = Q/16; const int NW = WARP_SIZE; - const int SH = (C + 2*Q); // shared memory per simdgroup in (half) + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) @@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16( half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + + half16x16_acc zr; half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory @@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16( } } + nvcuda::wmma::fill_fragment(zr, 0.0); + // zero out lo for (int64_t j = 0; j < Q16; ++j) { for (int64_t i = 0; i < D16; ++i) { @@ -6648,13 +6652,15 @@ static __global__ void flash_attn_ext_f16( for (int64_t i = 0; i < D16; ++i) { // convert accumulator to matrix_b - // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T); + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::fill_fragment(lo[j][i], 0.0); nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); } + + // restore zeros + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); } // O = O + (Q*K^T)*V @@ -10928,14 +10934,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int ncpw = 32; // cache values per warp (does not work for other values) const int nwarps_max = 8; // TODO: we don't want to launch too much warps. 
how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); - // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling - // try to avoid this - const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2); + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); switch (Q->ne[0]) { From 9a5c2a1681d3979d071fff0f1a9abece57f0841f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 15:00:47 +0200 Subject: [PATCH 49/58] cuda : switch to F16 scalars + tune warps for RTX 2060 --- ggml-cuda.cu | 94 ++++++++++++++++++++------------------ tests/test-backend-ops.cpp | 14 +++++- 2 files changed, 61 insertions(+), 47 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bdd50e2b6be4c..330fc6290effa 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6491,8 +6491,8 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); { - float S[Q]; - float M[Q]; + half S[Q]; + half M[Q]; for(int i = 0; i < Q; i++) { S[i] = 0.0f; @@ -6579,67 +6579,68 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - float smax = -INFINITY; + half smax = -INFINITY; // online softmax if (C == 32) { for (int64_t j = 0; j < Q; ++j) { const int64_t p = lane_id; - const float m = M[j]; - const float s = __half2float(ss[j*T + p]); + const half m = M[j]; + const half s = ss[j*T + p]; - smax = warp_reduce_max(max(smax, s)); - M[j] = warp_reduce_max(max(M[j], s)); + smax = warp_reduce_max(__hmax(smax, s)); + M[j] = warp_reduce_max(__hmax(M[j], s)); - const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]); - const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]); + const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]); + const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); // create a QxQ diagonal matrix for rescaling the output if (p == j) { - ss[j*T + C + j] = __float2half(ms); + ss[j*T + C + j] = ms; } // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = __float2half(vs); + ss[j*T + p] = vs; } } else { for (int64_t j = 0; j < Q; ++j) { - const float m = M[j]; + const half m = M[j]; for (int64_t p = lane_id; p < C; p += NW) { - const float s = __half2float(ss[j*T + p]); + const half s = ss[j*T + p]; - smax = warp_reduce_max(max(smax, s)); - M[j] = warp_reduce_max(max(M[j], s)); + smax = warp_reduce_max(__hmax(smax, s)); + M[j] = warp_reduce_max(__hmax(M[j], s)); } - const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]); + const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]); S[j] = S[j]*ms; // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { - ss[j*T + C + j] = __float2half(ms); + ss[j*T + C + j] = ms; } for (int64_t p = lane_id; p < C; p += NW) { - const float s = __half2float(ss[j*T + p]); + const half s = ss[j*T + p]; - const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]); + const half vs = __hisinf(s) ? 
0.0f : expf(s - M[j]); S[j] = S[j] + warp_reduce_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = __float2half(vs); + ss[j*T + p] = vs; } } } + // skip -INF blocks - if (smax == -INFINITY) { + if (__hisinf(smax)) { continue; } @@ -6686,16 +6687,16 @@ static __global__ void flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (lane_id == 0) { - ss[j*T + 0] = __float2half(S[j]); - ss[j*T + 1] = __float2half(M[j]); + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; } } } // reduce the warps sequentially for (int64_t sg = 1; sg < num_warps; ++sg) { - float S = 0.0f; - float M = -INFINITY; + half S = 0.0f; + half M = -INFINITY; __syncthreads(); @@ -6713,25 +6714,25 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { for (int64_t j = 0; j < Q; ++j) { - const float S0 = __half2float(ss[j*T + 0]); - const float S1 = __half2float(ss[j*T + sg*SH + 0]); + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - const float M0 = __half2float(ss[j*T + 1]); - const float M1 = __half2float(ss[j*T + sg*SH + 1]); + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; - M = max(M0, M1); + M = __hmax(M0, M1); - const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M); - const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M); + const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M); + const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M); S = S0*ms0 + S1*ms1; if (lane_id == 0) { - ss[j*T + 0] = __float2half(S); - ss[j*T + 1] = __float2half(M); + ss[j*T + 0] = S; + ss[j*T + 1] = M; - ss[j*T + C + j ] = __float2half(ms0); - ss[j*T + C + j + sg*SH] = __float2half(ms1); + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } } @@ -6774,10 +6775,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { - const float S = __half2float(ss[j*T + 0]); + const half S = ss[j*T + 0]; for (int64_t i = lane_id; i < D; i += NW) { - dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S; + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } } @@ -10930,12 +10931,15 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * float scale; memcpy(&scale, KQV->op_params, sizeof(float)); - const int nqpb = 16; // queries per block - const int ncpw = 32; // cache values per warp (does not work for other values) +#define NQPB 16 +#define NCPW 32 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4; + const int nwarps = Q->ne[1] <= nqpb ? 
MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); @@ -10945,7 +10949,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * switch (Q->ne[0]) { case 16: - flash_attn_ext_f16<16, 16, 32> + flash_attn_ext_f16<16, NQPB, NCPW> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10962,7 +10966,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 64: - flash_attn_ext_f16<64, 16, 32> + flash_attn_ext_f16<64, NQPB, NCPW> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10979,7 +10983,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 80: - flash_attn_ext_f16<80, 16, 32> + flash_attn_ext_f16<80, NQPB, NCPW> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key @@ -10996,7 +11000,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * ); break; case 128: - flash_attn_ext_f16<128, 16, 32> + flash_attn_ext_f16<128, NQPB, NCPW> <<>> ( (const char *) src0_extra->data_device[g_main_device], // Query (const char *) src1_extra->data_device[g_main_device], // Key diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e632142a74a13..ff207e21b8ec3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,9 +572,19 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 1 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + n_runs = 1000; + int n_nodes = gf->n_nodes; + for (int i = 1; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); @@ -2199,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); -#if 0 - for (int hs : { 64, 80, 96, 112, 128, 256, }) { +#if 1 + for (int hs : { 64, 80, 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From ac26f2702806744126b8edeb3d42dab1ce91cae1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:12:56 +0200 Subject: [PATCH 50/58] cuda : increase C to 128 for better performance --- ggml-cuda.cu | 59 +++++++++++++++++++++----------------- ggml.c | 2 +- llama.cpp | 3 +- tests/test-backend-ops.cpp | 2 +- 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 330fc6290effa..e7bf95bd1d616 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6495,8 +6495,8 @@ static __global__ void flash_attn_ext_f16( half M[Q]; for(int i = 0; i < Q; i++) { - S[i] = 0.0f; - M[i] = -INFINITY; + S[i] = __float2half(0.0f); + M[i] = __float2half(-INFINITY); } // assume K and V are same shape @@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = -INFINITY; + half smax = __float2half(-INFINITY); // online softmax if (C == 32) { @@ 
-6592,8 +6592,8 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(__hmax(smax, s)); M[j] = warp_reduce_max(__hmax(M[j], s)); - const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]); - const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]); + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); @@ -6612,33 +6612,38 @@ static __global__ void flash_attn_ext_f16( for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); } - const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]); + smax = warp_reduce_max(smax); + M[j] = warp_reduce_max(M[j]); - S[j] = S[j]*ms; + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { ss[j*T + C + j] = ms; } + // local sum + half ls = 0.0f; + for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]); + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); - S[j] = S[j] + warp_reduce_sum(vs); + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + S[j] = S[j]*ms + warp_reduce_sum(ls); } } - // skip -INF blocks if (__hisinf(smax)) { continue; @@ -6669,15 +6674,19 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < C/16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + half16x16_b mk[D16]; for (int64_t i = 0; i < D16; ++i) { - half16x16_b mk; - nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + } - for (int64_t j = 0; j < Q16; ++j) { - half16x16_a mv; - nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T); + half16x16_a mv[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + } - nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]); + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } } @@ -6695,8 +6704,8 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int64_t sg = 1; sg < num_warps; ++sg) { - half S = 0.0f; - half M = -INFINITY; + half S = __float2half(0.0f); + half M = __float2half(-INFINITY); __syncthreads(); @@ -6722,8 +6731,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M); - const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M); + const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); S = S0*ms0 + S1*ms1; @@ -6770,8 +6779,6 @@ static __global__ void flash_attn_ext_f16( } } - // float2 * dst2 = (float2 *) dst; - // final rescale with 1/S and store to global memory if (warp_id == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { @@ -9637,7 +9644,7 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded! 
float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); @@ -10932,7 +10939,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * memcpy(&scale, KQV->op_params, sizeof(float)); #define NQPB 16 -#define NCPW 32 +#define NCPW 128 const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) diff --git a/ggml.c b/ggml.c index 59a4c05a12ffe..ebd9c6b341080 100644 --- a/ggml.c +++ b/ggml.c @@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } bool is_node = false; diff --git a/llama.cpp b/llama.cpp index fe25839669efc..2330efff57bd3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6881,7 +6881,8 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext) + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ff207e21b8ec3..e23384eee27c2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 64, 80, 128, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 9240a84c73f4666bdd738e9f98fc3e7e35b3c6ef Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 1 Feb 2024 11:07:10 -0500 Subject: [PATCH 51/58] fix mask nullptr --- ggml-cuda.cu | 6 ++++-- tests/test-flash-attention.cpp | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e7bf95bd1d616..1095b1914d58c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6566,13 +6566,15 @@ static __global__ void flash_attn_ext_f16( for (int64_t j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; + if(mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } // convert accumulator to matrix_a nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); - nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm); + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? 
mm : zr); nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); } } diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index 1f779b0d4f08a..a60bdeb7385ed 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -248,7 +248,7 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe callback_userdata ud { true, - 1e-7, + 5e-4, model.backend, backend_cpu }; From 8d7a6066991a1e84a29b26a81c1c4ea576c4cefd Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 1 Feb 2024 11:29:50 -0500 Subject: [PATCH 52/58] don't require LLAMA_CUDA_F16 to compile --- ggml-cuda.cu | 43 +++++++++++-------------------------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1095b1914d58c..15b8fd78459f0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -661,12 +661,17 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -static __device__ __forceinline__ __half warp_reduce_sum(__half x) { +static __device__ __forceinline__ half warp_reduce_sum(half x) { +#ifdef __CUDA_ARCH__ >= CC_VOLTA #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); + x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x); } return x; +#else + (void) x; + NO_DEVICE_CODE; +#endif } static __device__ __forceinline__ float warp_reduce_max(float x) { @@ -6403,6 +6408,7 @@ typedef nvcuda::wmma::fragment half16x16_b; typedef nvcuda::wmma::fragment half16x16_bT; typedef nvcuda::wmma::fragment half16x16_acc; +#endif // based on metal version template // D head size, Q queries per block, C cache items per block @@ -6433,6 +6439,7 @@ static __global__ void flash_attn_ext_f16( int ne1, int ne2, int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA const int warp_id = threadIdx.y; const int lane_id = threadIdx.x; @@ -6791,38 +6798,10 @@ static __global__ void flash_attn_ext_f16( } } } -} #else -template // D head size, Q queries per block, C cache items per blocks -static __global__ void flash_attn_ext_f16( - const char* __restrict__ q, - const char* __restrict__ k, - const char* __restrict__ v, - const char* __restrict__ mask, - float* __restrict__ kqv, - float scale, - int ne00, - int ne01, - int ne02, - int ne03, - int ne10, - int ne11, - int ne12, - int ne13, - int ne31, - int nb31, - int nb01, - int nb02, - int nb03, - int nb11, - int nb12, - int nb13, - int ne0, - int ne1, - int ne2, - int ne3) { - } + NO_DEVICE_CODE; #endif +} template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, From 19e0b8eab38080d550026c9fb6d95cc40d875da7 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 1 Feb 2024 13:02:33 -0500 Subject: [PATCH 53/58] #ifdef -> #if + fix check -inf --- CMakeLists.txt | 3 ++- ggml-cuda.cu | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15a1101aa415e..477b2b5eb8dfd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -371,6 +371,7 @@ if (LLAMA_CUBLAS) #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work endif() endif() + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-lineinfo") message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() @@ -729,7 +730,7 @@ endif() set(CUDA_CXX_FLAGS "") if (LLAMA_CUBLAS) - set(CUDA_FLAGS ${CXX_FLAGS} 
-use_fast_math) + set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math -lineinfo) if (NOT MSVC) list(APPEND CUDA_FLAGS -Wno-pedantic) endif() diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 15b8fd78459f0..572c8c5ae21ab 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -662,7 +662,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { } static __device__ __forceinline__ half warp_reduce_sum(half x) { -#ifdef __CUDA_ARCH__ >= CC_VOLTA +#if __CUDA_ARCH__ >= CC_VOLTA #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x); @@ -6601,8 +6601,8 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(__hmax(smax, s)); M[j] = warp_reduce_max(__hmax(M[j], s)); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); @@ -6628,7 +6628,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { @@ -6641,7 +6641,7 @@ static __global__ void flash_attn_ext_f16( for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); ls += vs; @@ -6654,7 +6654,7 @@ static __global__ void flash_attn_ext_f16( } // skip -INF blocks - if (__hisinf(smax)) { + if (__hisinf(smax) == -1) { continue; } @@ -6740,8 +6740,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) == -1 ? 
__float2half(0.0f) : hexp(M1 - M); S = S0*ms0 + S1*ms1; From cae985cfb7791ccea4052e931ee8276cf2b5b70d Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 1 Feb 2024 13:05:17 -0500 Subject: [PATCH 54/58] cmake: remove unused changes --- CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 477b2b5eb8dfd..15a1101aa415e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -371,7 +371,6 @@ if (LLAMA_CUBLAS) #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work endif() endif() - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-lineinfo") message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() @@ -730,7 +729,7 @@ endif() set(CUDA_CXX_FLAGS "") if (LLAMA_CUBLAS) - set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math -lineinfo) + set(CUDA_FLAGS ${CXX_FLAGS} -use_fast_math) if (NOT MSVC) list(APPEND CUDA_FLAGS -Wno-pedantic) endif() From 53621e31ce16bc318081eef2955c277ba1b6d788 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Thu, 1 Feb 2024 17:23:17 -0500 Subject: [PATCH 55/58] refactor flash_attn function + improve tests --- ggml-cuda.cu | 105 +++++++++++---------------------- tests/test-flash-attention.cpp | 35 ++++++----- 2 files changed, 54 insertions(+), 86 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 572c8c5ae21ab..c6605fe95ee8c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10891,33 +10891,34 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c } -inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { - GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(K->type == GGML_TYPE_F16); - GGML_ASSERT(V->type == GGML_TYPE_F16); - GGML_ASSERT(KQV->type == GGML_TYPE_F32); +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(src2->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); - GGML_ASSERT(K->backend == GGML_BACKEND_GPU); - GGML_ASSERT(V->backend == GGML_BACKEND_GPU); - GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src2->backend == GGML_BACKEND_GPU); + GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); - GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(ne01, 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); ggml_cuda_set_device(g_main_device); const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; - ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; - ggml_tensor_extra_gpu * src3_extra = mask ? 
(ggml_tensor_extra_gpu *) mask->extra : nullptr; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + const char * query_data = (const char *) ((ggml_tensor_extra_gpu *) src0->extra)->data_device[g_main_device]; + const char * key_data = (const char *) ((ggml_tensor_extra_gpu *) src1->extra)->data_device[g_main_device]; + const char * value_data = (const char *) ((ggml_tensor_extra_gpu *) src2->extra)->data_device[g_main_device]; + const char * mask_data = src3 ? (const char *) ((ggml_tensor_extra_gpu *) src3->extra)->data_device[g_main_device] : nullptr; + float * qkv_data = (float *) ((ggml_tensor_extra_gpu *) dst->extra)->data_device[g_main_device]; float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + memcpy(&scale, dst->op_params, sizeof(float)); #define NQPB 16 #define NCPW 128 @@ -10927,81 +10928,45 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2; + const int nwarps = ne01 <= nqpb ? MAX(2, MIN(ne11/ncpw, nwarps_max)) : 2; - dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 blocks_num((ne01 + nqpb - 1) / nqpb, ne02, ne03); dim3 block_dim(32, nwarps, 1); - const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + const size_t shmem = nqpb*(ne00 + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - switch (Q->ne[0]) + switch (ne00) { case 16: flash_attn_ext_f16<16, NQPB, NCPW> <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + query_data, key_data, value_data, mask_data, qkv_data, scale, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], + src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; case 64: flash_attn_ext_f16<64, NQPB, NCPW> <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + query_data, key_data, value_data, mask_data, qkv_data, scale, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], + src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; case 80: flash_attn_ext_f16<80, NQPB, NCPW> <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + query_data, key_data, value_data, mask_data, qkv_data, scale, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], + src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; case 128: flash_attn_ext_f16<128, NQPB, NCPW> <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + query_data, key_data, value_data, mask_data, qkv_data, scale, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], + src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; default: diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index a60bdeb7385ed..bb99170d64646 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -105,7 +105,7 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i float* query = new float[head_dim * batch_size * num_heads]; float* key = new float[head_dim * kv_size * num_heads]; float* value = new float[head_dim * kv_size * num_heads]; - float* mask = new float[kv_size * batch_size]; + float* mask = new float[kv_size * GGML_PAD(batch_size, GGML_KQ_MASK_PAD)]; for(int i = 0; i < head_dim*batch_size*num_heads;i ++) { query[i] = i % 3 ? 2.0f : 1.5f; @@ -116,8 +116,8 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i value[i] = i % 3 ? 3.5f : 1.5f; } - for(int i = 0; i < batch_size*kv_size;i ++) { - mask[i] = i % 3 ? 1.0f : 0.0f; + for(int i = 0; i < GGML_PAD(batch_size, GGML_KQ_MASK_PAD)*kv_size;i ++) { + mask[i] = i % 3 ? 
1.0f : 1.5f; } size_t buffer_size = 0; @@ -125,7 +125,7 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i buffer_size += head_dim * batch_size * num_heads * ggml_type_sizef(GGML_TYPE_F32); // tensor q buffer_size += head_dim * kv_size * num_heads * ggml_type_sizef(GGML_TYPE_F16); // tensor k buffer_size += head_dim * kv_size * num_heads * ggml_type_sizef(GGML_TYPE_F16); // tensor v - buffer_size += batch_size * kv_size * ggml_type_sizef(GGML_TYPE_F32); // tensor mask + buffer_size += GGML_PAD(batch_size, GGML_KQ_MASK_PAD) * kv_size * ggml_type_sizef(GGML_TYPE_F16); // tensor mask buffer_size += 1024; } @@ -162,7 +162,7 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, head_dim, batch_size, num_heads); model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads); model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads); - model.msk = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, kv_size, batch_size); + model.msk = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F16, kv_size, GGML_PAD(batch_size, GGML_KQ_MASK_PAD)); // create a allocator ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); @@ -175,14 +175,16 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i ggml_fp16_t* k_f16 = new ggml_fp16_t[head_dim * kv_size * num_heads]; ggml_fp16_t* v_f16 = new ggml_fp16_t[head_dim * kv_size * num_heads]; + ggml_fp16_t* m_f16 = new ggml_fp16_t[GGML_PAD(batch_size, GGML_KQ_MASK_PAD) * kv_size]; ggml_fp32_to_fp16_row(key, k_f16, head_dim * kv_size * num_heads); ggml_fp32_to_fp16_row(value, v_f16, head_dim * kv_size * num_heads); + ggml_fp32_to_fp16_row(mask, m_f16, GGML_PAD(batch_size, GGML_KQ_MASK_PAD) * kv_size); ggml_backend_tensor_set(model.q, query, 0, ggml_nbytes(model.q)); ggml_backend_tensor_set(model.k, k_f16, 0, ggml_nbytes(model.k)); ggml_backend_tensor_set(model.v, v_f16, 0, ggml_nbytes(model.v)); - ggml_backend_tensor_set(model.msk, mask, 0, ggml_nbytes(model.msk)); + ggml_backend_tensor_set(model.msk, m_f16, 0, ggml_nbytes(model.msk)); } struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { @@ -201,12 +203,11 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a struct ggml_cgraph * gf = ggml_new_graph(ctx0); if(!model.naive_attn) { - struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0])); + struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0])); ggml_build_forward_expand(gf, result); } else { struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q); - kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0])); - kq = ggml_soft_max(ctx0, kq); + kq = ggml_soft_max_ext(ctx0, kq, model.msk, 1.0f / sqrtf(model.q->ne[0])); kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq); kq = ggml_permute (ctx0, kq, 0, 2, 1, 3); //kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]); @@ -226,7 +227,7 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe // allocate tensors ggml_allocr_alloc_graph(allocr, gf); - int n_threads = 1; + int n_threads = 6; if (ggml_backend_is_cpu(model.backend)) { ggml_backend_cpu_set_n_threads(model.backend, n_threads); @@ -335,7 +336,8 @@ int main(int argc, char ** argv) 
ggml_time_init(); - load_model(model, 16, 32, 128, 2); + //load_model(model, 16, 32, 128, 2); + load_model(model, 64, 2048, 4096, 32); ggml_backend_buffer_t buf_compute; // for compute struct ggml_allocr * allocr = NULL; @@ -361,15 +363,16 @@ int main(int argc, char ** argv) ggml_backend_synchronize(model.backend); printf("computing time: %.4f ms\n", (ggml_time_us() - compute_time_us__) / 1000.0); float* data = new float[ggml_nelements(result)]; - ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result)); - printf("\nPerforming test:\n"); + printf("\nPerforming test (%zu):\n", ggml_nelements(result)); + + int elements = ggml_nelements(result) > 1024 ? 1024 : ggml_nelements(result); - for(int i = 0; i < ggml_nelements(result); i ++) { - if(i > 0 && (i % result->ne[0] == 0)) { + for(int i = 0; i < elements; i ++) { + if(i > 0 && (i % 16 == 0)) { printf("\n"); } - if(i > 0 && (i % (result->ne[0] * result->ne[2]) == 0)) { + if(i > 0 && (i % (16 * 32) == 0)) { printf("\n\n"); } printf("%2.4f ", data[i]); From 674d5ac72d8d7233ee6df57f8f521d7445e3778c Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 3 Feb 2024 11:11:17 +0100 Subject: [PATCH 56/58] =?UTF-8?q?unroll=202=20loops,=20int64=5Ft=20->=20in?= =?UTF-8?q?t,=20309=20=C2=B5s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ggml-cuda.cu | 82 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c6605fe95ee8c..b811aefe8d63b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6467,10 +6467,22 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { +#pragma unroll + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { +#pragma unroll + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6482,15 +6494,15 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i = lane_id; i < SH; i += NW) { ss[j*T + i] = 0.0; } } @@ -6531,8 +6543,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6549,28 +6561,28 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { 
nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; if(mp) { @@ -6592,8 +6604,8 @@ static __global__ void flash_attn_ext_f16( // online softmax if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half m = M[j]; const half s = ss[j*T + p]; @@ -6615,10 +6627,10 @@ static __global__ void flash_attn_ext_f16( ss[j*T + p] = vs; } } else { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6638,7 +6650,7 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6659,13 +6671,13 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); @@ -6684,17 +6696,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6703,7 +6715,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6712,7 +6724,7 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { + for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); half M = __float2half(-INFINITY); @@ -6720,8 +6732,8 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T 
+ i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6731,7 +6743,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6755,7 +6767,7 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6764,7 +6776,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, t2); @@ -6781,8 +6793,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6790,10 +6802,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i = lane_id; i < D; i += NW) { dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } From a1f9ffe7b0551f59bc4651ed0934b5f701beea98 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sat, 3 Feb 2024 11:35:36 -0500 Subject: [PATCH 57/58] bring optimizations from gg/flash-attn --- ggml-cuda.cu | 233 +++++++++++++++++---------------- tests/test-backend-ops.cpp | 8 +- tests/test-flash-attention.cpp | 23 ++-- 3 files changed, 138 insertions(+), 126 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c6605fe95ee8c..dbe51a42b05ca 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6418,7 +6418,7 @@ static __global__ void flash_attn_ext_f16( const char* __restrict__ v, const char* __restrict__ mask, float* __restrict__ dst, - float scale, + half scale, int ne00, int ne01, int ne02, @@ -6448,29 +6448,43 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; + const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; half16x16_acc lo[Q16][D16]; // load 
heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { +#pragma unroll + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { +#pragma unroll + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6482,15 +6496,20 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + ss[j*T + i] = 0.0; } } @@ -6501,7 +6520,7 @@ static __global__ void flash_attn_ext_f16( half S[Q]; half M[Q]; - for(int i = 0; i < Q; i++) { + for (int i = 0; i < Q; ++i) { S[i] = __float2half(0.0f); M[i] = __float2half(-INFINITY); } @@ -6531,8 +6550,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6543,37 +6562,43 @@ static __global__ void flash_attn_ext_f16( // prepare diagonal scale matrix half16x16_b mscale; for (int i = 0; i < 16; ++i) { - ss[i*T + i] = __float2half(scale); + ss[i*T + i] = scale; } nvcuda::wmma::load_matrix_sync(mscale, ss, T); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; - if(mp) { + + if (mp) { nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } @@ -6588,90 +6613,72 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = __float2half(-INFINITY); + half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax - if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const half m = M[j]; - const half m = M[j]; - const half s = ss[j*T + p]; + for (int p0 = 0; p0 < C2; p0 += NW) { + const int p = p0 + 
lane_id; - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); - - S[j] = S[j]*ms + warp_reduce_sum(vs); - - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + const half2 s = ss2[j*T2 + p]; - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); } - } else { - for (int64_t j = 0; j < Q; ++j) { - const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { - const half s = ss[j*T + p]; + M[j] = warp_reduce_max(M[j]); - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } + const half ms = hexp(m - M[j]); - smax = warp_reduce_max(smax); - M[j] = warp_reduce_max(M[j]); + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + // local sum + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } + for (int p0 = 0; p0 < C2; p0 += NW) { + const int p = p0 + lane_id; - // local sum - half ls = 0.0f; + const half2 s = ss2[j*T2 + p]; - for (int64_t p = lane_id; p < C; p += NW) { - const half s = ss[j*T + p]; + const half2 vs = h2exp(s - M2); - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); + ls += vs; - ls += vs; + // the P matrix from the paper (Q rows, C columns) + ss2[j*T2 + p] = vs; + } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } + ls = warp_reduce_sum(ls); - S[j] = S[j]*ms + warp_reduce_sum(ls); - } + S[j] = S[j]*ms + ls.x + ls.y; } + smax = warp_reduce_max(smax); + // skip -INF blocks - if (__hisinf(smax) == -1) { + if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { continue; } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } // restore zeros @@ -6684,17 +6691,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6703,7 +6710,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { 
+ for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6712,16 +6719,13 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { - half S = __float2half(0.0f); - half M = __float2half(-INFINITY); - + for (int sg = 1; sg < num_warps; ++sg) { __syncthreads(); // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6731,31 +6735,29 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = lane_id; j < Q; j += NW) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; const half M0 = ss[j*T + 1]; const half M1 = ss[j*T + sg*SH + 1]; - M = __hmax(M0, M1); + const half M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; - if (lane_id == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*T + 0] = S; + ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; - } + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6764,10 +6766,9 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(t2, 0.0); + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, t2); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); @@ -6781,8 +6782,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6790,10 +6791,15 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } @@ -10917,8 +10923,9 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor const char * mask_data = src3 ? 
(const char *) ((ggml_tensor_extra_gpu *) src3->extra)->data_device[g_main_device] : nullptr; float * qkv_data = (float *) ((ggml_tensor_extra_gpu *) dst->extra)->data_device[g_main_device]; - float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float scale_; + memcpy(&scale_, dst->op_params, sizeof(float)); + half scale = __float2half(scale_); #define NQPB 16 #define NCPW 128 @@ -10926,9 +10933,11 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) +GGML_ASSERT(NQPB <= 32); + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = ne01 <= nqpb ? MAX(2, MIN(ne11/ncpw, nwarps_max)) : 2; + const int nwarps = ne01 <= nqpb ? MAX(2, MIN(ne11/ncpw, nwarps_max)) : 1; dim3 blocks_num((ne01 + nqpb - 1) / nqpb, ne02, ne03); dim3 block_dim(32, nwarps, 1); @@ -10941,32 +10950,32 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * src0, const ggml_tensor flash_attn_ext_f16<16, NQPB, NCPW> <<>> ( query_data, key_data, value_data, mask_data, qkv_data, scale, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], - src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3 ? src3->ne[1] : 0, + src3 ? src3->nb[1] : 0, nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; case 64: flash_attn_ext_f16<64, NQPB, NCPW> <<>> ( query_data, key_data, value_data, mask_data, qkv_data, scale, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], - src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3 ? src3->ne[1] : 0, + src3 ? src3->nb[1] : 0, nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; case 80: flash_attn_ext_f16<80, NQPB, NCPW> <<>> ( query_data, key_data, value_data, mask_data, qkv_data, scale, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], - src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3 ? src3->ne[1] : 0, + src3 ? src3->nb[1] : 0, nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; case 128: flash_attn_ext_f16<128, NQPB, NCPW> <<>> ( query_data, key_data, value_data, mask_data, qkv_data, scale, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3->ne[1], - src3->nb[1], nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, src3 ? src3->ne[1] : 0, + src3 ? src3->nb[1] : 0, nb01, nb02, nb03, nb11, nb12, nb13, ne0, ne1, ne2, ne3 ); break; default: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e23384eee27c2..e4076b49c180d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,13 +572,13 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 
1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 1 +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } #else - n_runs = 1000; int n_nodes = gf->n_nodes; + n_runs = 1000; for (int i = 1; i < n_runs; i++) { for (int j = 0; j < n_nodes; j++) { gf->nodes[gf->n_nodes++] = gf->nodes[j]; @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 128, 64, 80, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { @@ -2224,7 +2224,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (int hs : { 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) { + for (int nb : { 1, 2, 4, 8, 512 }) { test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); } diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp index bb99170d64646..5f700e6aadc23 100644 --- a/tests/test-flash-attention.cpp +++ b/tests/test-flash-attention.cpp @@ -108,16 +108,19 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i float* mask = new float[kv_size * GGML_PAD(batch_size, GGML_KQ_MASK_PAD)]; for(int i = 0; i < head_dim*batch_size*num_heads;i ++) { - query[i] = i % 3 ? 2.0f : 1.5f; + float q = (1.0f * i) / (head_dim*batch_size*num_heads); + query[i] = q * 2.0f - 1.0f; } for(int i = 0; i < head_dim*kv_size*num_heads;i ++) { - key[i] = i % 3 ? 2.3f : 2.8f; - value[i] = i % 3 ? 3.5f : 1.5f; + float q = (1.0f * i) / (head_dim*kv_size*num_heads); + key[i] = -(q * 2.0f - 1.0f); + value[i] = (1.0f - q) * 2.0f - 1.0f; } for(int i = 0; i < GGML_PAD(batch_size, GGML_KQ_MASK_PAD)*kv_size;i ++) { - mask[i] = i % 3 ? 
1.0f : 1.5f; + float q = (1.0f * i) / (GGML_PAD(batch_size, GGML_KQ_MASK_PAD)*kv_size); + mask[i] = q * 1.5f; } size_t buffer_size = 0; @@ -159,10 +162,10 @@ void load_model(test_model & model, int head_dim, int batch_size, int kv_size, i model.ctx = ggml_init(params); // create tensors - model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, head_dim, batch_size, num_heads); - model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads); - model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads); - model.msk = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F16, kv_size, GGML_PAD(batch_size, GGML_KQ_MASK_PAD)); + model.q = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, head_dim, batch_size, num_heads, 1); + model.k = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads, 1); + model.v = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads, 1); + model.msk = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, kv_size, GGML_PAD(batch_size, GGML_KQ_MASK_PAD), 1, 1); // create a allocator ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); @@ -211,7 +214,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq); kq = ggml_permute (ctx0, kq, 0, 2, 1, 3); //kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]); - ggml_build_forward_expand(gf, kq); + ggml_build_forward_expand(gf, ggml_cont(ctx0, kq)); } // delete the temporally context used to build the graph @@ -337,7 +340,7 @@ int main(int argc, char ** argv) ggml_time_init(); //load_model(model, 16, 32, 128, 2); - load_model(model, 64, 2048, 4096, 32); + load_model(model, 64, 512, 128*128, 32); ggml_backend_buffer_t buf_compute; // for compute struct ggml_allocr * allocr = NULL; From f659f575e010f4291bc9ddd2914aa41c8a652975 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sat, 3 Feb 2024 11:42:15 -0500 Subject: [PATCH 58/58] fix merge conflicts --- ggml-cuda.cu | 68 ---------------------------------------------------- 1 file changed, 68 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6945d169a3770..dbe51a42b05ca 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6568,16 +6568,12 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns -<<<<<<< HEAD for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { const int ic = ic0 + warp_id*C; if (ic >= ne11) { break; } -======= - for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { @@ -6620,14 +6616,8 @@ static __global__ void flash_attn_ext_f16( half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax -<<<<<<< HEAD for (int j = 0; j < Q; ++j) { const half m = M[j]; -======= - if (C == 32) { - for (int j = 0; j < Q; ++j) { - const int p = lane_id; ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; @@ -6663,50 +6653,10 @@ static __global__ void flash_attn_ext_f16( // the P matrix from the paper (Q rows, C columns) ss2[j*T2 + p] = vs; } -<<<<<<< HEAD ls = warp_reduce_sum(ls); S[j] = S[j]*ms + ls.x + ls.y; -======= - } else { - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - - for (int p = lane_id; p < C; p += NW) { - const half s = ss[j*T + p]; - - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } - - smax = warp_reduce_max(smax); - 
M[j] = warp_reduce_max(M[j]); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - - // local sum - half ls = 0.0f; - - for (int p = lane_id; p < C; p += NW) { - const half s = ss[j*T + p]; - - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } - - S[j] = S[j]*ms + warp_reduce_sum(ls); - } ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e } smax = warp_reduce_max(smax); @@ -6770,12 +6720,6 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { -<<<<<<< HEAD -======= - half S = __float2half(0.0f); - half M = __float2half(-INFINITY); - ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e __syncthreads(); // each simdgroup stores its output to shared memory, reusing sq @@ -6791,11 +6735,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { -<<<<<<< HEAD for (int j = lane_id; j < Q; j += NW) { -======= - for (int j = 0; j < Q; ++j) { ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6827,10 +6767,6 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); for (int i = 0; i < D16; ++i) { -<<<<<<< HEAD -======= - nvcuda::wmma::fill_fragment(t2, 0.0); ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, zr); @@ -6858,16 +6794,12 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; -<<<<<<< HEAD for (int i0 = 0; i0 < D; i0 += NW) { const int i = i0 + lane_id; if (i >= D) { break; } -======= - for (int i = lane_id; i < D; i += NW) { ->>>>>>> 8b51ab447b074dbabc007743613aa93e6a4c028e dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } }
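
The recurring arithmetic in patches 49-58 above - a running maximum M and running sum S per query row, rescaling of the partial output by exp(M_old - M_new) whenever a new KV chunk raises the maximum, and a final 1/S normalization - is the standard streaming ("online") softmax that the flash_attn_ext_f16 kernel tiles with WMMA fragments. A minimal scalar CPU sketch of that bookkeeping follows for reference; the function name, chunking, and row-major layout are illustrative assumptions, and this is not the CUDA kernel itself.

    // Illustrative CPU reference of the online-softmax accumulation used by the
    // flash attention kernel: visit the KV cache in chunks of C columns, keep a
    // running max M and running sum S, rescale partial results when M grows,
    // and divide by S at the end. Scalar, single query row, for clarity only.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void attn_row_online(const float * q, const float * k, const float * v,
                                float * out, int d, int n_kv, float scale, int C) {
        std::vector<float> acc(d, 0.0f);   // running (un-normalized) output row
        float M = -INFINITY;               // running max of the logits seen so far
        float S = 0.0f;                    // running sum of exp(logit - M)

        for (int ic = 0; ic < n_kv; ic += C) {
            const int ce = std::min(ic + C, n_kv);

            // logits q*K^T*scale for this chunk, and the new running max
            std::vector<float> s(ce - ic);
            float m_new = M;
            for (int j = ic; j < ce; ++j) {
                float dot = 0.0f;
                for (int i = 0; i < d; ++i) {
                    dot += q[i] * k[j*d + i];
                }
                s[j - ic] = dot * scale;
                m_new = std::max(m_new, s[j - ic]);
            }

            // rescale previously accumulated results (the diag(ms)*O step)
            const float ms = (M == -INFINITY) ? 0.0f : std::exp(M - m_new);
            S *= ms;
            for (int i = 0; i < d; ++i) {
                acc[i] *= ms;
            }
            M = m_new;

            // accumulate softmax(QK^T)*V contributions from this chunk
            for (int j = ic; j < ce; ++j) {
                const float p = std::exp(s[j - ic] - M);
                S += p;
                for (int i = 0; i < d; ++i) {
                    acc[i] += p * v[j*d + i];
                }
            }
        }

        // final rescale with 1/S, matching the kernel's last step
        for (int i = 0; i < d; ++i) {
            out[i] = acc[i] / S;
        }
    }

    int main() {
        const int d = 4, n_kv = 8;
        std::vector<float> q(d, 0.5f), k(d*n_kv), v(d*n_kv), out(d);
        for (int i = 0; i < d*n_kv; ++i) { k[i] = 0.01f*i; v[i] = 1.0f - 0.01f*i; }
        attn_row_online(q.data(), k.data(), v.data(), out.data(),
                        d, n_kv, 1.0f/std::sqrt((float)d), /*C=*/4);
        for (int i = 0; i < d; ++i) { printf("%f\n", out[i]); }
        return 0;
    }

Comparing this sketch against a direct softmax over the full row is a quick way to sanity-check the rescaling logic; the chunked result should match the one-pass result up to floating-point rounding, which is the invariant the kernel relies on when it reduces the per-warp S and M values.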