llama : custom attention mask + parallel decoding + no context swaps by ggerganov · Pull Request #3228 · ggml-org/llama.cpp · GitHub
llama : custom attention mask + parallel decoding + no context swaps #3228


Merged (57 commits, Sep 28, 2023)

Commits (the diff below shows changes from 1 commit):
c5df72e
tests : verify that RoPE is "additive"
ggerganov Sep 17, 2023
3b4bab6
llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
ggerganov Sep 17, 2023
1fb033f
ggml : ggml_rope now takes a vector with positions instead of n_past
ggerganov Sep 17, 2023
fad5693
metal : add rope_f16 kernel + optimize cpy kernels
ggerganov Sep 17, 2023
d29e769
llama : unified KV cache + batch inference API
ggerganov Sep 18, 2023
58bb511
Merge branch 'master' into custom-attention-mask
ggerganov Sep 18, 2023
9f42e75
llama : add new llama_decode() API that works with llama_batch
ggerganov Sep 18, 2023
6952a46
llama : add cell_max heuristic for more efficient kv_cache
ggerganov Sep 18, 2023
4d76d76
llama : extend llama_kv_cache API
ggerganov Sep 18, 2023
f015b26
llama : more robust cell_max heuristic + wip shift
ggerganov Sep 18, 2023
86c90e3
metal : disable concurrency optimization
ggerganov Sep 18, 2023
0cbf3bf
llama : add llama_kv_cache_shift_seq + no more context swaps
ggerganov Sep 18, 2023
7c1bdd0
llama : apply K-cache roping for Falcon and Baichuan
ggerganov Sep 18, 2023
1f17ea6
speculative : fix KV cache management
ggerganov Sep 18, 2023
0161372
parallel : example for serving multiple users in parallel
ggerganov Sep 18, 2023
466b513
parallel : disable hot-plug to avoid cache fragmentation
ggerganov Sep 18, 2023
897cacc
fixes : speculative KV cache + llama worst-case graph
ggerganov Sep 18, 2023
fa0e677
llama : extend batch API to select which logits to output
ggerganov Sep 18, 2023
daf4c6d
llama : fix worst case graph build
ggerganov Sep 19, 2023
7e2b997
ggml-cuda : update rope implementation for parallel decoding (#3254)
slaren Sep 19, 2023
25bd254
make : add parallel to build + fix static functions in llama.cpp
ggerganov Sep 19, 2023
467e307
simple : fix token counting
ggerganov Sep 19, 2023
36714e1
parallel : various improvements
ggerganov Sep 19, 2023
ddad227
llama : fix cell_max logic + rename functions
ggerganov Sep 19, 2023
806d397
parallel : try smaller batches when the KV cache is fragmented
ggerganov Sep 19, 2023
16090a5
parallel : fix sequence termination criteria
ggerganov Sep 19, 2023
d37081a
llama : silence errors KV cache errors
ggerganov Sep 19, 2023
82e20e9
parallel : remove new line from prompt
ggerganov Sep 19, 2023
4b5f3cd
parallel : process system prompt once + configurable paramters + llam…
ggerganov Sep 19, 2023
8a9aca3
parallel : remove question with short answers
ggerganov Sep 19, 2023
eed3fd4
parallel : count cache misses
ggerganov Sep 19, 2023
6028879
parallel : print misses on each request
ggerganov Sep 19, 2023
7b7472e
parallel : minor
ggerganov Sep 19, 2023
e1067ef
llama : fix n_kv to never become 0
ggerganov Sep 20, 2023
a1327c7
parallel : rename hot-plug to continuous-batching
ggerganov Sep 20, 2023
addae65
llama : improve llama_batch API + simplify parallel example
ggerganov Sep 20, 2023
b377bf2
simple : add parallel decoding support
ggerganov Sep 20, 2023
db0fc2d
simple : improve comments + free batch
ggerganov Sep 20, 2023
e04dc51
ggml-cuda : add rope f16, restore performance with parallel decoding …
slaren Sep 20, 2023
5420696
llama : disable MPI for now
ggerganov Sep 20, 2023
2f3a46f
train : make KQ_pos memory buffer permanent via dummy scale op
ggerganov Sep 20, 2023
1be2b8c
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
slaren Sep 20, 2023
ee1d670
parallel : fix bug (extra BOS) + smaller token_prev array
ggerganov Sep 20, 2023
ded9b43
parallel : fix cases where the input prompts can overflow the batch
ggerganov Sep 20, 2023
b2debf6
parallel : add disabled experimental batch chunking in powers of two
ggerganov Sep 20, 2023
5a3369d
llama : llama.h formatting + comments
ggerganov Sep 21, 2023
8845160
simple : add README.md
ggerganov Sep 21, 2023
c1596f6
llama : fix kv cache heuristic when context is less than 32
ggerganov Sep 27, 2023
2585690
Merge branch 'master' into custom-attention-mask
ggerganov Sep 28, 2023
4ad0676
parallel : fix crash when `-n -1`
ggerganov Sep 28, 2023
e946379
llama : simplify returns if/else branches
ggerganov Sep 28, 2023
4c72ab1
metal : use mm kernels for batch size > 2
ggerganov Sep 28, 2023
d008733
examples : utilize new llama_get_logits_ith()
ggerganov Sep 28, 2023
a207561
examples : add example for batched decoding
ggerganov Sep 28, 2023
2b8830a
examples : do not eval prompt 2 times (close #3348)
ggerganov Sep 28, 2023
ce2d995
server : clear the KV cache beyond n_past before llama_decode
ggerganov Sep 28, 2023
c5650ed
server : avoid context swaps by shifting the KV cache
ggerganov Sep 28, 2023
metal : add rope_f16 kernel + optimize cpy kernels
ggerganov committed Sep 17, 2023
commit fad56936d484a48eede12f30d194a26e1ea9e6b1
36 changes: 23 additions & 13 deletions ggml-metal.m
@@ -100,7 +100,8 @@
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(rope_f32);
+ GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -261,7 +262,8 @@ @implementation GGMLMetalClass
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(rope_f32);
+ GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(rope);
+ GGML_METAL_DEL_KERNEL(rope_f32);
+ GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@@ -870,7 +873,7 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SOFT_MAX:
  {
- const int nth = 32;
+ const int nth = MIN(32, ne00);

  if (ne00%4 == 0) {
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
@@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = 512;
+ const int nth = MIN(512, ne00);

  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = 256;
+ const int nth = MIN(256, ne00);

  [encoder setComputePipelineState:ctx->pipeline_norm];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1171,6 +1174,8 @@
  {
  GGML_ASSERT((src0t == GGML_TYPE_F32));

+ const int nth = MIN(1024, ne00);
+
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
  const int n_head = ((int32_t *) dst->op_params)[1];
  float max_bias;
@@ -1204,15 +1209,15 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];

- const int nth = 32;
-
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ROPE:
  {
  GGML_ASSERT(ne10 == ne02);

- //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];

@@ -1221,7 +1226,12 @@
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

- [encoder setComputePipelineState:ctx->pipeline_rope];
+ switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+ default: GGML_ASSERT(false);
+ };
+
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1241,19 +1251,19 @@
  [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
- //[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
  [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
  [encoder setBytes:&mode length:sizeof( int) atIndex:21];
  [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
  [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];

- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = 32;
+ const int nth = MIN(1024, ne00);

  switch (src0t) {
  case GGML_TYPE_F32:
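
Two patterns recur in this file's hunks. First, each Metal kernel is tracked by name in three places (declared in the context, loaded at init, released at free), which is why replacing the single rope kernel with rope_f32/rope_f16 touches the DECL/ADD/DEL macro lists. Second, the const int nth = MIN(..., ne00) edits cap the threadgroup size at the row length ne00 so that short rows no longer launch more threads than there are elements. The block below is a minimal, hypothetical sketch of the bookkeeping the macros imply; the field names, error handling and non-ARC assumption are mine, not taken from the patch.

// rope_kernels_bookkeeping_sketch.m
//
// Hypothetical sketch, not the actual llama.cpp code: it only illustrates the
// bookkeeping that the GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL /
// GGML_METAL_DEL_KERNEL call sites above imply -- one MTLFunction /
// MTLComputePipelineState pair per kernel name, looked up by the
// "kernel_<name>" string that ggml-metal.metal exports via [[host_name]].
// Assumes compilation without ARC (so ObjC pointers can live in a plain C
// struct) and a local variable named `ctx`; field names are illustrative.
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// declare one function/pipeline pair per kernel name
#define GGML_METAL_DECL_KERNEL(name) \
    id<MTLFunction>             function_##name; \
    id<MTLComputePipelineState> pipeline_##name

struct ggml_metal_context_sketch {
    id<MTLDevice>  device;
    id<MTLLibrary> library;

    GGML_METAL_DECL_KERNEL(rope_f32); // -> function_rope_f32, pipeline_rope_f32
    GGML_METAL_DECL_KERNEL(rope_f16); // -> function_rope_f16, pipeline_rope_f16
};

// load "kernel_<name>" from the compiled Metal library and build its pipeline;
// the string must match a [[host_name("kernel_<name>")]] entry point
#define GGML_METAL_ADD_KERNEL(name)                                                 \
    do {                                                                            \
        NSError * err = nil;                                                        \
        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_" #name]; \
        ctx->pipeline_##name =                                                      \
            [ctx->device newComputePipelineStateWithFunction:ctx->function_##name   \
                                                       error:&err];                 \
    } while (0)

// release the pair again when the context is freed
#define GGML_METAL_DEL_KERNEL(name)      \
    do {                                 \
        [ctx->function_##name release];  \
        [ctx->pipeline_##name release];  \
    } while (0)
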
45 changes: 39 additions & 6 deletions ggml-metal.metal
@@ -853,6 +853,36 @@ kernel void kernel_alibi_f32(
  }
  }

+ typedef void (rope_t)(
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant float & freq_base,
+ constant float & freq_scale,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]);
+
+ template<typename T>
  kernel void kernel_rope(
  device const void * src0,
  device const int32_t * src1,
@@ -901,11 +931,11 @@ kernel void kernel_rope(
  const float cos_theta = cos(theta);
  const float sin_theta = sin(theta);

- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

- const float x0 = src[0];
- const float x1 = src[1];
+ const T x0 = src[0];
+ const T x1 = src[1];

  dst_data[0] = x0*cos_theta - x1*sin_theta;
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -920,8 +950,8 @@

  const int64_t i0 = ib*n_dims + ic/2;

- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

  const float x0 = src[0];
  const float x1 = src[n_dims/2];
@@ -933,6 +963,9 @@
  }
  }

+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
  kernel void kernel_cpy_f16_f16(
  device const half * src0,
  device half * dst,
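
On the host side, the two instantiations above become the library functions "kernel_rope_f32" and "kernel_rope_f16", and the encoder picks one pipeline per tensor type before dispatching one threadgroup per row, as the ggml-metal.m hunks show. The sketch below restates that flow in isolation; the encoder/pipeline parameters are placeholders, the elided setBuffer:/setBytes: calls are summarized in a comment, and the extra clamp against maxTotalThreadsPerThreadgroup is an assumed safeguard, not something the patch does.

// rope_dispatch_sketch.m -- hypothetical host-side dispatch for the templated
// rope kernel; mirrors the pattern of the ggml-metal.m hunks, not a verbatim copy.
#import <Metal/Metal.h>

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

typedef enum { SKETCH_TYPE_F32, SKETCH_TYPE_F16 } sketch_type; // stand-in for the ggml type enum

static void sketch_encode_rope(id<MTLComputeCommandEncoder> encoder,
                               id<MTLComputePipelineState>  pipeline_rope_f32,
                               id<MTLComputePipelineState>  pipeline_rope_f16,
                               sketch_type                  src0_type,
                               int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
    // pick the instantiation matching the tensor's element type,
    // like the switch on src0->type in the diff above
    id<MTLComputePipelineState> pipeline = nil;
    switch (src0_type) {
        case SKETCH_TYPE_F32: pipeline = pipeline_rope_f32; break;
        case SKETCH_TYPE_F16: pipeline = pipeline_rope_f16; break;
    }
    [encoder setComputePipelineState:pipeline];

    // ... setBuffer:/setBytes: calls for src0, src1 (positions), dst and the
    //     ne*/nb*/n_past/n_dims/mode/freq_* arguments would go here ...

    // clamp the threadgroup size: no more threads than row elements (ne00) and
    // no more than the pipeline supports; the fixed 1024 in the diff plays the
    // same role, the maxTotalThreadsPerThreadgroup query is an assumed extra check
    NSUInteger nth = MIN((NSUInteger) ne00, pipeline.maxTotalThreadsPerThreadgroup);
    nth = MIN(nth, (NSUInteger) 1024);

    // one threadgroup per row of the tensor, as in MTLSizeMake(ne01, ne02, ne03) above
    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03)
            threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
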