kv-cache : refactor + add llama_memory_state_i by ggerganov · Pull Request #13746 · ggml-org/llama.cpp · GitHub

kv-cache : refactor + add llama_memory_state_i #13746


Merged · 14 commits · May 31, 2025
Changes from 1 commit
llama : handle aborts and compute errors
ggml-ci
ggerganov committed May 30, 2025
commit 780bba94d84995b6607830a07f72cf742f16a032
2 changes: 2 additions & 0 deletions include/llama.h
@@ -677,12 +677,14 @@ extern "C" {

// Returns the smallest position present in the KV cache for the specified sequence
// This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
struct llama_context * ctx,
llama_seq_id seq_id);

// Returns the largest position present in the KV cache for the specified sequence
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
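
As a quick illustration of the guarantee documented in these comments, a caller-side sketch (not part of this PR; `ctx` and `seq_id` are assumed to be a valid context and an active sequence) might look like:

// Sketch: inspect the contiguous range of cached positions for one sequence.
const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

if (p_min == -1) {
    // the sequence has no entries in the KV cache
} else {
    // every position in [p_min, p_max] is guaranteed to be cached;
    // with a SWA cache p_min is typically non-zero once the window has slid
}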
111 changes: 78 additions & 33 deletions src/llama-context.cpp
@@ -6,9 +6,10 @@
#include "llama-model.h"
#include "llama-kv-cache.h"

#include <cinttypes>
#include <cstring>
#include <limits>
#include <stdexcept>
#include <cinttypes>

//
// llama_context
@@ -632,6 +633,49 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
auto * gf = graph_init();
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
if (ret) {
*ret = GGML_STATUS_FAILED;
}
return nullptr;
}

auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
if (ret) {
*ret = GGML_STATUS_FAILED;
}
return nullptr;
}

// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
if (ret) {
*ret = GGML_STATUS_ALLOC_FAILED;
}
return nullptr;
}

res->set_inputs(&ubatch);

const auto status = graph_compute(gf, ubatch.n_tokens > 1);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
if (ret) {
*ret = status;
}
return nullptr;
}

return res;
}

int llama_context::encode(llama_batch & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -703,26 +747,18 @@ int llama_context::encode(llama_batch & inp_batch) {
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
cparams.causal_attn = false;

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);

ggml_backend_sched_alloc_graph(sched.get(), gf);

res->set_inputs(&ubatch);
ggml_status status;
auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);

cparams.causal_attn = causal_attn_org;

const auto compute_status = graph_compute(gf, n_tokens > 1);
switch (compute_status) {
case GGML_STATUS_SUCCESS:
break;
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
if (!res) {
switch (status) {
case GGML_STATUS_ABORTED: return 2;
case GGML_STATUS_ALLOC_FAILED: return -2;
case GGML_STATUS_FAILED: return -3;
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
}
}

auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -942,25 +978,34 @@ int llama_context::decode(llama_batch & inp_batch) {
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
ggml_status status;
auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);

if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };

for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];

// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}

ggml_backend_sched_alloc_graph(sched.get(), gf);
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
continue;
}

res->set_inputs(&ubatch);
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

llama_kv_self_seq_rm(this, s, pos_min[s], -1);
}

const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
switch (status) {
case GGML_STATUS_ABORTED: return 2;
case GGML_STATUS_ALLOC_FAILED: return -2;
case GGML_STATUS_FAILED: return -3;
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
}
}

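The status-to-return-code mapping above (2 = aborted, -2 = allocation failure, -3 = compute failure) is what applications see through llama_encode()/llama_decode(). A minimal caller-side sketch (not part of this PR; `ctx` and `batch` are assumed to be already set up):

// Sketch: reacting to the error codes handled in this commit.
const int32_t ret = llama_decode(ctx, batch);

if (ret == 2) {
    // aborted via the abort callback; this commit removes the failed ubatch's
    // positions from the KV cache, so the caller can retry or stop cleanly
} else if (ret < 0) {
    // allocation or compute error (-2 / -3): treat as fatal for this context
} else if (ret > 0) {
    // other non-fatal condition (e.g. no KV slot could be found)
} else {
    // success
}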
12 changes: 9 additions & 3 deletions src/llama-context.h
@@ -89,6 +89,14 @@ struct llama_context {
int32_t il_start,
int32_t il_end);

// process a single ubatch with a specific graph type
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process(
const llama_ubatch & ubatch,
llm_graph_type gtype,
ggml_status * ret);

int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);

@@ -181,9 +189,7 @@ struct llama_context {
ggml_cgraph * graph_init();

// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
ggml_status graph_compute(ggml_cgraph * gf, bool batched);

// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs);
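For reference, the call pattern that encode() and decode() now follow, condensed into a fragment (placement inside a llama_context member function is hypothetical; the enum values and GGML_ABORT usage are taken from the diff above):

// inside a llama_context member function (hypothetical placement):
ggml_status status;
auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);

if (!res) {
    // map the failure status to the public error codes, as encode()/decode() do
    switch (status) {
        case GGML_STATUS_ABORTED:      return 2;   // abort callback fired
        case GGML_STATUS_ALLOC_FAILED: return -2;  // graph allocation failed
        case GGML_STATUS_FAILED:       return -3;  // graph build or compute failed
        case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
    }
}

// on success, read the output tensors from res (logits, embeddings, ...)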