Add abort callback · Pints-AI/llama.cpp@f2770b8 · GitHub

Commit f2770b8

Add abort callback
1 parent c1ac54b commit f2770b8

2 files changed: +19 -6 lines changed

llama.cpp

Lines changed: 12 additions & 6 deletions

@@ -386,14 +386,18 @@ struct LLM_TN {
 // ggml helpers
 //
 
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads, llama_abort_callback abort_callback) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
     if (plan.work_size > 0) {
         buf.resize(plan.work_size);
         plan.work_data = buf.data();
     }
 
+    if (abort_callback) {
+        plan.abort_callback = abort_callback;
+    }
+
     ggml_graph_compute(graph, &plan);
 }
 
@@ -2902,10 +2906,10 @@ static bool llama_eval_internal(
             ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
         }
     } else {
-        ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads, nullptr);
     }
 #else
-    ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
+    ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads, nullptr);
 #endif
 
 #if GGML_USE_MPI
@@ -5198,7 +5202,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
 
             struct ggml_cgraph gf = ggml_build_forward(r);
 
-            ggml_graph_compute_helper(work_buffer, &gf, n_threads);
+            ggml_graph_compute_helper(work_buffer, &gf, n_threads, nullptr);
 
             // we won't need these tensors again, reset the context to save memory
             ggml_free(lora_ctx);
@@ -5240,6 +5244,8 @@ struct llama_context_params llama_context_default_params() {
         /*.rope_freq_scale             =*/ 1.0f,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
+        /*.abort_callback              =*/ nullptr,
+        /*.abort_callback_user_data    =*/ nullptr,
         /*.low_vram                    =*/ false,
         /*.mul_mat_q                   =*/ false,
         /*.f16_kv                      =*/ true,
@@ -5776,7 +5782,7 @@ void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_conte
 
     ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
     ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
-    ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
+    ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1, nullptr);
 
     ggml_free(cpy_ctx);
 
@@ -5886,7 +5892,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
     ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
     ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
-    ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
+    ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1, nullptr);
 
     ggml_free(cpy_ctx);
 }
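
The helper only overrides plan.abort_callback when a non-null callback is supplied, so the existing call sites above, which all pass nullptr for now, keep ggml's default behavior of running the graph to completion. Note also that the helper takes no user-data argument, so plan.abort_callback_data is left untouched. For reference, a rough sketch of the ggml plan this hooks into, abridged from the ggml.h of this era (an assumption for illustration, not part of this commit):

    // Sketch: ggml_graph_compute() checks abort_callback between graph
    // nodes and stops early when it returns true (fields abridged).
    struct ggml_cplan {
        size_t    work_size; // size of work buffer, calculated by ggml_graph_plan()
        uint8_t * work_data; // work buffer, to be allocated by the caller

        int n_threads;

        // abort ggml_graph_compute when true is returned
        bool (*abort_callback)(void * data);
        void * abort_callback_data;
    };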

llama.h

Lines changed: 7 additions & 0 deletions

@@ -121,6 +121,8 @@ extern "C" {
 
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
+    typedef bool (*llama_abort_callback)(void *ctx);
+
     struct llama_context_params {
         uint32_t seed;  // RNG seed, -1 for random
         int32_t  n_ctx; // text context
@@ -139,6 +141,11 @@ extern "C" {
         // context pointer passed to the progress callback
         void * progress_callback_user_data;
 
+        // called during llama_eval() to check if the evaluation should be aborted
+        llama_abort_callback abort_callback;
+        // context pointer passed to the abort callback
+        void * abort_callback_user_data;
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool low_vram;   // if true, reduce VRAM usage at the cost of performance
         bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
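
On the API side, a caller would populate the two new fields before creating a context. As of this commit llama_eval_internal still passes nullptr to the helper, so the fields are not yet consumed during evaluation; the sketch below shows the intended caller-side shape (the flag and function names are illustrative, not from the commit):

    #include <atomic>
    #include "llama.h"

    static std::atomic<bool> g_abort_requested{false};

    // Conforms to llama_abort_callback; returning true requests that
    // evaluation stop at the next abort check.
    static bool on_abort(void * /*ctx*/) {
        return g_abort_requested.load(std::memory_order_relaxed);
    }

    int main() {
        struct llama_context_params params = llama_context_default_params();
        params.abort_callback           = on_abort;
        params.abort_callback_user_data = nullptr; // unused by this callback
        // ... load a model and create a context with `params` as usual.
        // Another thread (UI event, watchdog) can set g_abort_requested
        // to true to ask a long evaluation to stop early.
        return 0;
    }

Since abort_callback_user_data is not forwarded into the ggml plan anywhere in this commit, the sketch uses a global flag rather than the ctx pointer.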
