server : separate the notion of position and KV tokens, remove prompt truncation by ngxson · Pull Request #13576 · ggml-org/llama.cpp

server : separate the notion of position and KV tokens, remove prompt truncation #13576


Open · wants to merge 7 commits into base: master

Changes from 1 commit
no more scan loop in n_kv_tokens()
ngxson committed May 16, 2025
commit 678d7b1569a4b0fcb15bf8d65ad9779a9bb77e9f
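
For context on the change: a multimodal image chunk in server_tokens can occupy a different number of sequence positions than the number of KV-cache tokens it consumes (the real counts come from mtmd_image_tokens_get_n_pos() and mtmd_image_tokens_get_n_tokens(), which need not be equal), so a single integer can no longer describe a cached prefix. Below is a minimal, self-contained sketch of that distinction; fake_image_chunk and the chunk sizes are invented for illustration, not values from the real API:

    #include <cstdio>

    // Hypothetical stand-in for an mtmd image chunk. In llama.cpp the real
    // counts come from mtmd_image_tokens_get_n_pos() and
    // mtmd_image_tokens_get_n_tokens(); they are not necessarily equal.
    struct fake_image_chunk {
        int n_pos;    // sequence positions the chunk occupies
        int n_tokens; // KV-cache cells the chunk actually consumes
    };

    int main() {
        // a prompt of 4 text tokens, one image chunk, then 2 more text tokens;
        // the chunk sizes below are made-up numbers
        fake_image_chunk img = { /*n_pos =*/ 16, /*n_tokens =*/ 64 };

        int n_pos = 4 + img.n_pos    + 2; // what n_pos() would report
        int n_kv  = 4 + img.n_tokens + 2; // what n_kv_tokens() would report

        // the two counts diverge as soon as an image chunk is present,
        // which is why get_common_prefix() now has to report both
        std::printf("positions = %d, kv tokens = %d\n", n_pos, n_kv);
        return 0;
    }
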
10 changes: 6 additions & 4 deletions tools/server/server.cpp
@@ -2099,10 +2099,11 @@ struct server_context {
         }

         // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-        int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
+        auto common_pos = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
+        int cur_lcs_len = common_pos.first; // position, not tokens

         // fraction of the common subsequence length compared to the current slot's prompt length
-        float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.n_kv_tokens());
+        float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.n_pos());

         // select the current slot if the criteria match
         if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
@@ -3094,8 +3095,9 @@ struct server_context {

                 if (slot.params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
-                    slot.n_kv_tokens = slot.cache_tokens.n_kv_tokens(slot.n_past);
+                    auto common_pos = slot.cache_tokens.get_common_prefix(prompt_tokens);
+                    slot.n_past = common_pos.first;
+                    slot.n_kv_tokens = common_pos.second;

                     // reuse chunks from the cached prompt by shifting their KV cache in the new position
                     if (params_base.n_cache_reuse > 0) {
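
A side note on the calling convention: since get_common_prefix() now returns a std::pair, a caller could also unpack it with C++17 structured bindings instead of .first/.second (the PR itself keeps the explicit form). A hypothetical, self-contained illustration; get_common_prefix_stub and its numbers are invented:

    #include <cstdio>
    #include <utility>

    // invented stand-in for slot.cache_tokens.get_common_prefix(prompt_tokens)
    static std::pair<int, size_t> get_common_prefix_stub() {
        return {128, 192}; // assumed: 128 shared positions backed by 192 KV tokens
    }

    int main() {
        // equivalent to reading common_pos.first / common_pos.second above
        auto [n_past, n_kv_tokens] = get_common_prefix_stub();
        std::printf("n_past = %d, n_kv_tokens = %zu\n", n_past, n_kv_tokens);
        return 0;
    }
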
35 changes: 10 additions & 25 deletions tools/server/utils.hpp
@@ -1162,26 +1162,8 @@ struct server_tokens {
         tokens[pos] = id;
     }

-    // if end_pos == -1, we count all positions
-    size_t n_kv_tokens(llama_pos end_pos = -1) const {
-        if (end_pos == -1) {
-            return n_kv;
-        } else {
-            size_t res = 0;
-            for (llama_pos i = 0; i < end_pos;) {
-                auto & t = tokens[i];
-                if (t == LLAMA_TOKEN_NULL) {
-                    auto & chunk = find_chunk(i);
-                    auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
-                    res += mtmd_image_tokens_get_n_tokens(img_tokens);
-                    i += mtmd_image_tokens_get_n_pos(img_tokens);
-                } else {
-                    res++;
-                    i++;
-                }
-            }
-            return res;
-        }
+    size_t n_kv_tokens() const {
+        return n_kv;
     }

     llama_pos n_pos() const {
@@ -1239,9 +1221,10 @@ struct server_tokens {
         return common_detokenize(ctx, text_tokens, special);
     }

-    // returns the position of the first token that is different
-    size_t get_common_prefix(const server_tokens & b) const {
+    // returns pair of <position, n_kv_tokens>
+    std::pair<llama_pos, size_t> get_common_prefix(const server_tokens & b) const {
         size_t max_idx = std::min(tokens.size(), b.tokens.size());
+        size_t n_tok = 0;
         for (size_t i = 0; i < max_idx; ++i) {
             auto & ai = tokens[i];
             auto & bi = b.tokens[i];
@@ -1260,17 +1243,19 @@
                 if (ai_id == bi_id && a_pos == b_pos) {
                     GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
                     i += a_pos - 1; // will be +1 by the for loop
+                    n_tok += mtmd_image_tokens_get_n_tokens(a_img);
                     continue;
                 } else {
-                    return i;
+                    return {i, n_tok};
                 }
             } else if (ai == bi) {
+                n_tok++;
                 continue;
             } else {
-                return i;
+                return {i, n_tok};
             }
         }
-        return max_idx; // all tokens are equal
+        return {max_idx, n_tok}; // all tokens are equal
     }

     // make sure all text tokens are within the vocab range
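
To see how the new scan accumulates both counts in one pass, here is a compressed, self-contained model of it. A plain -1 stands in for LLAMA_TOKEN_NULL, and the fixed constants IMG_N_POS and IMG_N_TOK stand in for the mtmd accessors and the per-chunk id/position checks the real code performs; all of these names and numbers are invented for illustration:

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // simplified model: -1 marks positions covered by an image chunk; the real
    // code stores LLAMA_TOKEN_NULL there and asks the mtmd_* accessors for the
    // per-chunk counts (it also verifies the chunks' ids and positions match)
    static const int TOKEN_NULL = -1;
    static const int IMG_N_POS  = 3; // assumed: each image spans 3 positions...
    static const int IMG_N_TOK  = 9; // ...and consumes 9 KV tokens

    // returns {first differing position, KV tokens in the common prefix}
    static std::pair<int, size_t> get_common_prefix(const std::vector<int> & a,
                                                    const std::vector<int> & b) {
        size_t max_idx = std::min(a.size(), b.size());
        size_t n_tok   = 0;
        for (size_t i = 0; i < max_idx; ++i) {
            if (a[i] == TOKEN_NULL && b[i] == TOKEN_NULL) {
                i     += IMG_N_POS - 1; // will be +1 by the for loop
                n_tok += IMG_N_TOK;     // count the chunk's KV tokens, not its positions
            } else if (a[i] == b[i]) {
                n_tok++;
            } else {
                return {(int) i, n_tok};
            }
        }
        return {(int) max_idx, n_tok};
    }

    int main() {
        // two prompts sharing 2 text tokens and 1 image chunk, then diverging
        std::vector<int> a = {10, 11, TOKEN_NULL, TOKEN_NULL, TOKEN_NULL, 12};
        std::vector<int> b = {10, 11, TOKEN_NULL, TOKEN_NULL, TOKEN_NULL, 99};
        auto [pos, n_tok] = get_common_prefix(a, b);
        std::printf("common prefix: %d positions, %zu kv tokens\n", pos, n_tok);
        // prints: common prefix: 5 positions, 11 kv tokens
        return 0;
    }
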