context : fix worst-case reserve outputs (#12545) · ggml-org/llama.cpp@2d77d88 · GitHub
Commit 2d77d88

context : fix worst-case reserve outputs (#12545)
ggml-ci
1 parent c95fa36 commit 2d77d88

File tree

  • src/llama-context.cpp

1 file changed: +21 -4 lines changed


src/llama-context.cpp

Lines changed: 21 additions & 4 deletions
@@ -294,10 +294,7 @@ llama_context::llama_context(
         // TODO: something cleaner
         const auto n_outputs_save = n_outputs;
 
-        // max number of outputs
-        n_outputs = n_tokens;
-
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
@@ -313,8 +310,15 @@ llama_context::llama_context(
         // reserve pp graph first so that buffers are only allocated once
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            // max number of outputs
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -326,20 +330,33 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
         {
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            n_outputs = ubatch_tg.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
+
             n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_tg  = ggml_graph_n_nodes(gf);
         }
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
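
For context, the change moves the worst-case n_outputs assignment out of the single spot ahead of the reservations and into each reservation block, so the value always matches the ubatch that is about to be built: n_tokens outputs for the prompt-processing (pp) graph and a single output for the token-generation (tg) graph. Below is a minimal standalone sketch of that pattern, not the actual llama.cpp implementation; fake_ubatch, reserve_graph, and the constants in main are made-up stand-ins for llama_ubatch, graph_init/graph_build, and ggml_backend_sched_reserve.

#include <cstdint>
#include <cstdio>

// made-up stand-in for llama_ubatch: only the fields relevant to the sketch
struct fake_ubatch {
    uint32_t n_tokens; // tokens in this micro-batch
    uint32_t n_seqs;   // sequences in this micro-batch
};

// worst-case number of outputs for the graph currently being reserved
static uint32_t n_outputs = 0;

// made-up stand-in for graph_init + graph_build + ggml_backend_sched_reserve
static bool reserve_graph(const fake_ubatch & ubatch, const char * name) {
    // every token in the ubatch could request an output, so the worst case is
    // n_tokens -- set it per ubatch, right before building the graph
    n_outputs = ubatch.n_tokens;

    std::printf("%s: reserving graph for n_tokens = %u, n_seqs = %u, n_outputs = %u\n",
                name, ubatch.n_tokens, ubatch.n_seqs, n_outputs);

    // a real implementation would build the graph and reserve backend buffers here
    return true;
}

int main() {
    const uint32_t n_tokens = 512; // hypothetical worst-case batch size
    const uint32_t n_seqs   = 4;   // hypothetical worst-case number of sequences

    // prompt-processing (pp) graph: a full batch, so up to n_tokens outputs
    const fake_ubatch ubatch_pp = { n_tokens, n_seqs };
    reserve_graph(ubatch_pp, "pp");

    // token-generation (tg) graph: a single token, so a single output
    const fake_ubatch ubatch_tg = { 1, n_seqs };
    reserve_graph(ubatch_tg, "tg");

    // reserve the pp graph again so the allocator keeps the larger buffers
    reserve_graph(ubatch_pp, "pp");

    return 0;
}

The point the commit enforces is that n_outputs is derived from the ubatch actually used for each reservation rather than being set once to n_tokens up front, so the tg graph is reserved with a single output instead of n_tokens.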
