@@ -294,10 +294,7 @@ llama_context::llama_context(
     // TODO: something cleaner
     const auto n_outputs_save = n_outputs;
 
-    // max number of outputs
-    n_outputs = n_tokens;
-
-    LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+    LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
     int n_splits_pp = -1;
     int n_nodes_pp  = -1;
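The net effect of this first hunk: `n_outputs` is no longer forced to `n_tokens` up front; each reservation block below sets it to match its own ubatch. The `n_outputs_save` kept as context above suggests the original value is restored afterwards, outside the hunks shown. A minimal sketch of that save/restore pattern, with a hypothetical class standing in for `llama_context`:

```cpp
// Minimal sketch (hypothetical class, not the real llama_context) of using a
// member as an implicit graph-build parameter with save/restore semantics.
#include <cstdint>

struct context_sketch {
    int32_t n_outputs = 0; // read by graph building while reserving

    void reserve_worst_case(int32_t n_tokens) {
        const auto n_outputs_save = n_outputs; // as in the diff

        n_outputs = n_tokens; // pp reservation: every token may produce an output
        // ... build + reserve the prompt-processing (pp) graph ...

        n_outputs = 1;        // tg reservation: single-token ubatch
        // ... build + reserve the token-generation (tg) graph ...

        n_outputs = n_outputs_save; // restore; the restore itself is outside this diff
    }
};
```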
@@ -313,8 +310,15 @@ llama_context::llama_context(
     // reserve pp graph first so that buffers are only allocated once
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+        // max number of outputs
+        n_outputs = ubatch_pp.n_tokens;
+
+        LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
         auto * gf = graph_init();
         graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
         if (!ggml_backend_sched_reserve(sched.get(), gf)) {
             throw std::runtime_error("failed to allocate compute pp buffers");
         }
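The positional `llama_ubatch` initializer is easy to misread. One plausible field-by-field reading, inferred from how the pp and tg initializers differ (the field names are assumptions, not confirmed by this diff):

```cpp
// Hypothetical annotation of the positional initializer above;
// the commented field names are inferred, not taken from this diff.
llama_ubatch ubatch_pp = {
    /*equal_seqs   =*/ true,              // all sequences get the same length
    /*n_tokens     =*/ n_tokens,          // worst-case prompt-processing batch
    /*n_seq_tokens =*/ n_tokens / n_seqs, // tokens per sequence
    /*n_seqs       =*/ n_seqs,
    /*token        =*/ &token,            // one dummy token id, reused
    /*embd         =*/ nullptr,           // no input embeddings
    /*pos          =*/ nullptr,           // positions not needed to reserve
    /*n_seq_id     =*/ nullptr,
    /*seq_id       =*/ nullptr,
    /*output       =*/ nullptr,
};
```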
@@ -326,20 +330,33 @@ llama_context::llama_context(
     // reserve with tg graph to get the number of splits and nodes
     {
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+        n_outputs = ubatch_tg.n_tokens;
+
+        LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
         auto * gf = graph_init();
         graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
         if (!ggml_backend_sched_reserve(sched.get(), gf)) {
             throw std::runtime_error("failed to allocate compute tg buffers");
         }
+
         n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
         n_nodes_tg  = ggml_graph_n_nodes(gf);
     }
 
     // reserve again with pp graph to avoid ggml-alloc reallocations during inference
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+        n_outputs = ubatch_pp.n_tokens;
+
+        LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
         auto * gf = graph_init();
         graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
         if (!ggml_backend_sched_reserve(sched.get(), gf)) {
             throw std::runtime_error("failed to allocate compute pp buffers");
         }
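The pp → tg → pp ordering matters: scheduler buffers are sized when a graph is reserved, so reserving the large pp graph first allocates them once at worst-case size, the tg pass only measures splits and nodes, and the final pp pass leaves the allocator sized for the big graph so real inference batches never trigger a reallocation. A condensed sketch of that ordering, with `build_graph` as a hypothetical stand-in for `graph_init()` + `graph_build()`:

```cpp
// Condensed sketch of the reservation order; build_graph is hypothetical.
#include "ggml-backend.h"
#include <cstdint>

static bool reserve_all(ggml_backend_sched_t sched,
                        ggml_cgraph * (*build_graph)(int32_t n_tokens),
                        int32_t n_tokens) {
    // pp first: buffers are allocated once, at worst-case size
    if (!ggml_backend_sched_reserve(sched, build_graph(n_tokens))) return false;
    // tg: only to read back the number of splits and nodes
    if (!ggml_backend_sched_reserve(sched, build_graph(1)))        return false;
    // pp again: leave ggml-alloc sized for the big graph, so inference
    // never forces a reallocation
    return ggml_backend_sched_reserve(sched, build_graph(n_tokens));
}
```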