llama : custom attention mask + parallel decoding + no context swaps by ggerganov · Pull Request #3228 · ggml-org/llama.cpp · GitHub

llama : custom attention mask + parallel decoding + no context swaps #3228


Merged
57 commits merged on Sep 28, 2023
Changes from 1 commit

Commits (57)
c5df72e
tests : verify that RoPE is "additive"
ggerganov Sep 17, 2023
3b4bab6
llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
ggerganov Sep 17, 2023
1fb033f
ggml : ggml_rope now takes a vector with positions instead of n_past
ggerganov Sep 17, 2023
fad5693
metal : add rope_f16 kernel + optimize cpy kernels
ggerganov Sep 17, 2023
d29e769
llama : unified KV cache + batch inference API
ggerganov Sep 18, 2023
58bb511
Merge branch 'master' into custom-attention-mask
ggerganov Sep 18, 2023
9f42e75
llama : add new llama_decode() API that works with llama_batch
ggerganov Sep 18, 2023
6952a46
llama : add cell_max heuristic for more efficient kv_cache
ggerganov Sep 18, 2023
4d76d76
llama : extend llama_kv_cache API
ggerganov Sep 18, 2023
f015b26
llama : more robust cell_max heuristic + wip shift
ggerganov Sep 18, 2023
86c90e3
metal : disable concurrency optimization
ggerganov Sep 18, 2023
0cbf3bf
llama : add llama_kv_cache_shift_seq + no more context swaps
ggerganov Sep 18, 2023
7c1bdd0
llama : apply K-cache roping for Falcon and Baichuan
ggerganov Sep 18, 2023
1f17ea6
speculative : fix KV cache management
ggerganov Sep 18, 2023
0161372
parallel : example for serving multiple users in parallel
ggerganov Sep 18, 2023
466b513
parallel : disable hot-plug to avoid cache fragmentation
ggerganov Sep 18, 2023
897cacc
fixes : speculative KV cache + llama worst-case graph
ggerganov Sep 18, 2023
fa0e677
llama : extend batch API to select which logits to output
ggerganov Sep 18, 2023
daf4c6d
llama : fix worst case graph build
ggerganov Sep 19, 2023
7e2b997
ggml-cuda : update rope implementation for parallel decoding (#3254)
slaren Sep 19, 2023
25bd254
make : add parallel to build + fix static functions in llama.cpp
ggerganov Sep 19, 2023
467e307
simple : fix token counting
ggerganov Sep 19, 2023
36714e1
parallel : various improvements
ggerganov Sep 19, 2023
ddad227
llama : fix cell_max logic + rename functions
ggerganov Sep 19, 2023
806d397
parallel : try smaller batches when the KV cache is fragmented
ggerganov Sep 19, 2023
16090a5
parallel : fix sequence termination criteria
ggerganov Sep 19, 2023
d37081a
llama : silence errors KV cache errors
ggerganov Sep 19, 2023
82e20e9
parallel : remove new line from prompt
ggerganov Sep 19, 2023
4b5f3cd
parallel : process system prompt once + configurable paramters + llam…
ggerganov Sep 19, 2023
8a9aca3
parallel : remove question with short answers
ggerganov Sep 19, 2023
eed3fd4
parallel : count cache misses
ggerganov Sep 19, 2023
6028879
parallel : print misses on each request
ggerganov Sep 19, 2023
7b7472e
parallel : minor
ggerganov Sep 19, 2023
e1067ef
llama : fix n_kv to never become 0
ggerganov Sep 20, 2023
a1327c7
parallel : rename hot-plug to continuous-batching
ggerganov Sep 20, 2023
addae65
llama : improve llama_batch API + simplify parallel example
ggerganov Sep 20, 2023
b377bf2
simple : add parallel decoding support
ggerganov Sep 20, 2023
db0fc2d
simple : improve comments + free batch
ggerganov Sep 20, 2023
e04dc51
ggml-cuda : add rope f16, restore performance with parallel decoding …
slaren Sep 20, 2023
5420696
llama : disable MPI for now
ggerganov Sep 20, 2023
2f3a46f
train : make KQ_pos memory buffer permanent via dummy scale op
ggerganov Sep 20, 2023
1be2b8c
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
slaren Sep 20, 2023
ee1d670
parallel : fix bug (extra BOS) + smaller token_prev array
ggerganov Sep 20, 2023
ded9b43
parallel : fix cases where the input prompts can overflow the batch
ggerganov Sep 20, 2023
b2debf6
parallel : add disabled experimental batch chunking in powers of two
ggerganov Sep 20, 2023
5a3369d
llama : llama.h formatting + comments
ggerganov Sep 21, 2023
8845160
simple : add README.md
ggerganov Sep 21, 2023
c1596f6
llama : fix kv cache heuristic when context is less than 32
ggerganov Sep 27, 2023
2585690
Merge branch 'master' into custom-attention-mask
ggerganov Sep 28, 2023
4ad0676
parallel : fix crash when `-n -1`
ggerganov Sep 28, 2023
e946379
llama : simplify returns if/else branches
ggerganov Sep 28, 2023
4c72ab1
metal : use mm kernels for batch size > 2
ggerganov Sep 28, 2023
d008733
examples : utilize new llama_get_logits_ith()
ggerganov Sep 28, 2023
a207561
examples : add example for batched decoding
ggerganov Sep 28, 2023
2b8830a
examples : do not eval prompt 2 times (close #3348)
ggerganov Sep 28, 2023
ce2d995
server : clear the KV cache beyond n_past before llama_decode
ggerganov Sep 28, 2023
c5650ed
server : avoid context swaps by shifting the KV cache
ggerganov Sep 28, 2023
parallel : disable hot-plug to avoid cache fragmentation
ggerganov committed Sep 18, 2023
commit 466b513851ff8ec73889ce6414b8a15d570f77c7
examples/parallel/parallel.cpp: 91 changes (60 additions, 31 deletions)
@@ -28,7 +28,7 @@ static std::string trim(const std::string & str) {
}

static std::string k_system = R"(
Transcript of a dialog, where the User interacts with an Assistant.
Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, what is the temperature outside?
@@ -59,6 +59,9 @@ struct client {

llama_token sampled;

int64_t t_start_prompt;
int64_t t_start_gen;

int32_t n_prompt = 0;
int32_t n_decoded = 0;
int32_t i_batch = -1;
@@ -133,33 +136,47 @@ int main(int argc, char ** argv) {

for (auto & client : clients) {
if (client.seq_id == -1) {
client.seq_id = g_seq_id;
client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = k_system + client.input + "\nAssistant:";
client.response = "";
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);

std::vector<llama_token> prompt_tokens;
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);

for (size_t i = 0; i < prompt_tokens.size(); ++i) {
batch_token.push_back(prompt_tokens[i]);
batch_pos.push_back(i);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
continue;
}

batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
client.n_decoded += 1;
client.i_batch = batch_token.size() - 1;
}

if (batch_token.empty()) {
// all sequences have ended - clear the entire KV cache
llama_kv_cache_rm_tokens(ctx, -1, -1);

for (auto & client : clients) {
if (client.seq_id == -1) {
client.seq_id = g_seq_id;
client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0;

client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = k_system + client.input + "\nAssistant:";
client.response = "";
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);

std::vector<llama_token> prompt_tokens;
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);

for (size_t i = 0; i < prompt_tokens.size(); ++i) {
batch_token.push_back(prompt_tokens[i]);
batch_pos.push_back(i);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
}
client.n_prompt = prompt_tokens.size();
client.n_decoded = prompt_tokens.size();
client.i_batch = batch_token.size() - 1;

g_seq_id += 1;
}
client.n_prompt = prompt_tokens.size();
client.n_decoded = prompt_tokens.size();
client.i_batch = batch_token.size() - 1;

g_seq_id += 1;
} else {
batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
client.n_decoded += 1;
client.i_batch = batch_token.size() - 1;
}
}

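The restructured loop above is the core of this commit: only tokens sampled for already-running sequences are added to the shared batch, and new prompts are admitted only when that batch comes up empty, i.e. when every sequence has finished. At that point the entire KV cache is cleared with llama_kv_cache_rm_tokens(ctx, -1, -1) and all idle clients are restarted together, so their cells stay contiguous instead of being hot-plugged into a partially used, fragmented cache. The toy simulation below sketches just this control flow; it uses no llama.cpp calls and its names and numbers are purely illustrative.

```cpp
#include <cstdio>
#include <cstdlib>
#include <vector>

// Toy simulation of the scheduling policy from this commit: active clients each
// contribute one token per step to a shared batch; new sequences are admitted
// only when *all* clients are idle, at which point the (simulated) KV cache is
// cleared and every client starts a fresh prompt at once. A client that finishes
// early simply waits - no client is ever hot-plugged into a half-full cache.
struct toy_client {
    int seq_id   = -1;  // -1 = idle
    int n_remain =  0;  // tokens left to "generate"
};

int main() {
    std::vector<toy_client> clients(4);
    int g_seq_id = 0;

    for (int step = 0; step < 20; ++step) {
        int batch_size = 0;

        // 1) only already-running sequences contribute to the batch
        for (auto & c : clients) {
            if (c.seq_id == -1) continue;
            batch_size++;
            if (--c.n_remain == 0) c.seq_id = -1;   // sequence finished
        }

        // 2) admit new sequences only when the batch came up empty
        if (batch_size == 0) {
            printf("step %2d: all sequences ended -> clear KV cache, start %zu new sequences\n",
                   step, clients.size());
            for (auto & c : clients) {
                c.seq_id   = g_seq_id++;
                c.n_remain = 1 + rand() % 5;        // pretend response length
            }
            continue;
        }

        printf("step %2d: decoded a batch of %d tokens\n", step, batch_size);
    }
    return 0;
}
```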
@@ -188,6 +205,10 @@ int main(int argc, char ** argv) {

const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);

if (client.t_start_gen == 0) {
client.t_start_gen = ggml_time_us();
}

// remember which tokens were sampled - used for repetition penalties during sampling
client.last_tokens.erase(client.last_tokens.begin());
client.last_tokens.push_back(id);
@@ -199,7 +220,10 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());

if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos) {
// basic reverse prompt
const size_t pos = client.response.find("User:");
if (pos != std::string::npos) {
client.response = client.response.substr(0, pos);
@@ -211,13 +235,18 @@ int main(int argc, char ** argv) {

n_tokens_total += client.n_decoded - client.n_prompt;

printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput: %s\nResponse: %s\n\n",
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
(double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
client.input.c_str(), ::trim(client.response).c_str());
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
::trim(client.input).c_str(),
::trim(client.response).c_str());

client.seq_id = -1;
}

client.i_batch = -1;
}
}

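The reworked printf above reports three per-client rates derived from the timing fields added to struct client: PP (prompt processing) divides the prompt tokens by the span from t_start_prompt to t_start_gen, which is stamped when the first token is sampled; TG (text generation) divides the generated tokens by the span from t_start_gen to t_main_end; AVG divides all decoded tokens by the full span since t_start_prompt. A minimal, self-contained sketch of the same arithmetic, using microsecond timestamps as ggml_time_us() returns; the struct and the numbers below are illustrative.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative per-client timing record, mirroring the fields added in this commit.
struct client_timing {
    int32_t n_prompt;        // tokens in the prompt
    int32_t n_decoded;       // prompt tokens + generated tokens
    int64_t t_start_prompt;  // us, when the prompt was submitted
    int64_t t_start_gen;     // us, when the first token was sampled
    int64_t t_end;           // us, when the sequence finished
};

int main() {
    // Hypothetical numbers: 100-token prompt, 60 generated tokens.
    const client_timing c = { 100, 160, 0, 250000, 1000000 };

    const double pp  = 1e6 * (double)  c.n_prompt                / (c.t_start_gen - c.t_start_prompt); // prompt tokens/s
    const double tg  = 1e6 * (double) (c.n_decoded - c.n_prompt) / (c.t_end       - c.t_start_gen);    // generated tokens/s
    const double avg = 1e6 * (double)  c.n_decoded               / (c.t_end       - c.t_start_prompt); // overall tokens/s

    printf("PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s\n", pp, tg, avg);
    return 0;
}
```

For the hypothetical numbers above this prints PP 400.00 t/s, TG 80.00 t/s, AVG 160.00 t/s.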
llama.cpp: 4 changes (4 additions, 0 deletions)
@@ -2606,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);

//printf("n_kv = %d\n", n_kv);

const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);

auto & buf_compute = lctx.buf_compute;
@@ -4052,6 +4054,8 @@ static bool llama_eval_internal(
batch.seq_id = seq_id.data();
}

kv_self.head = 0;

if (!llama_kv_cache_find_slot(kv_self, batch)) {
return false;
}
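The kv_self.head = 0 addition above resets the cache position right before llama_kv_cache_find_slot runs, presumably so that the search for a free span of cells always starts at the front of the cache. Together with the parallel example clearing the whole cache once all sequences end, this keeps occupied cells packed at the beginning. The toy allocator below illustrates why the starting point of a first-fit search matters; it is a simplified stand-in, not the actual llama_kv_cache_find_slot implementation.

```cpp
#include <cstdio>
#include <vector>

// Toy KV cell array: a cell is either free (pos == -1) or holds a token position.
struct kv_cache_toy {
    std::vector<int> cells;  // pos per cell, -1 = free
    int head = 0;            // where the next slot search starts
};

// First-fit search for n_tokens contiguous free cells, starting at cache.head.
// Returns the start index, or -1 if no contiguous span is available.
static int find_slot(kv_cache_toy & cache, int n_tokens) {
    const int n = (int) cache.cells.size();
    for (int start = cache.head; start + n_tokens <= n; ++start) {
        bool ok = true;
        for (int i = 0; i < n_tokens; ++i) {
            if (cache.cells[start + i] != -1) { ok = false; break; }
        }
        if (ok) {
            return start;
        }
    }
    return -1;
}

int main() {
    kv_cache_toy cache;
    cache.cells.assign(16, -1);

    // Simulate a leftover sequence occupying cells 4..7 (fragmentation).
    for (int i = 4; i < 8; ++i) cache.cells[i] = i;

    // With the search starting at 0, a 4-token batch still fits in the gap at the front.
    cache.head = 0;
    printf("slot for 4 tokens, search from 0: %d\n", find_slot(cache, 4));

    // If the search instead starts past the fragment, the front gap is never reused.
    cache.head = 8;
    printf("slot for 4 tokens, search from 8: %d\n", find_slot(cache, 4));
    return 0;
}
```

Running it prints slot 0 for the search that starts at the front (the gap before the leftover fragment is reused) and slot 8 for the search that starts past it (the front gap is wasted).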