From de5e629acf134dd948b94a2dafdd5b6f39cb4627 Mon Sep 17 00:00:00 2001
From: Dmitry Ivanov
Date: Wed, 14 May 2025 23:38:50 +0300
Subject: [PATCH 1/2] added abort_callback

---
 llama_cpp/llama.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7e9a6af23..90666b93e 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -909,8 +909,19 @@ def generate(
 
         # Eval and sample
         while True:
+            if hasattr(self, "abort_callback") and callable(self.abort_callback):
+                if self.abort_callback():
+                    print("Aborting outer loop from callback.")
+                    return
+
             self.eval(tokens)
             while sample_idx < self.n_tokens:
+
+                if hasattr(self, "abort_callback") and callable(self.abort_callback):
+                    if self.abort_callback():
+                        print("Aborting generation from Python callback.")
+                        return  # Exit generation cleanly
+
                 token = self.sample(
                     top_k=top_k,
                     top_p=top_p,

From 0e3109aecdc1101fecb1d9f517c69eacb4ec6feb Mon Sep 17 00:00:00 2001
From: Dmitry Ivanov
Date: Sun, 25 May 2025 23:26:35 +0300
Subject: [PATCH 2/2] abort_callback also stops create_completion loop

---
 llama_cpp/llama.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 90666b93e..ac431b56a 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -12,6 +12,7 @@
 import warnings
 import contextlib
 import multiprocessing
+import threading
 
 from typing import (
     Any,
@@ -544,6 +545,8 @@ def free_lora_adapter():
 
         self._sampler = None
 
+        self.inference_stopped = threading.Event()
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         return self._ctx.ctx
@@ -907,11 +910,14 @@ def generate(
         sample_idx = self.n_tokens + len(tokens) - 1
         tokens = list(tokens)
 
+        self.inference_stopped.clear()
+
         # Eval and sample
         while True:
             if hasattr(self, "abort_callback") and callable(self.abort_callback):
                 if self.abort_callback():
                     print("Aborting outer loop from callback.")
+                    self.inference_stopped.set()
                     return
 
             self.eval(tokens)
@@ -920,6 +926,7 @@ def generate(
                 if hasattr(self, "abort_callback") and callable(self.abort_callback):
                     if self.abort_callback():
                         print("Aborting generation from Python callback.")
+                        self.inference_stopped.set()
                         return  # Exit generation cleanly
 
                 token = self.sample(
@@ -1375,6 +1382,11 @@ def logit_bias_processor(
                     finish_reason = "stop"
                     break
 
+            if hasattr(self, "abort_callback") and callable(self.abort_callback):
+                if self.abort_callback():
+                    finish_reason = "abort"
+                    break
+
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(
@@ -1521,7 +1533,7 @@ def logit_bias_processor(
             finish_reason = "length"
             break
 
-        if stopping_criteria is not None and stopping_criteria(
+        if len(self._scores) > 0 and finish_reason != "abort" and stopping_criteria is not None and stopping_criteria(
            self._input_ids, self._scores[-1, :]
         ):
             text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
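
Usage note (not part of the patches): the abort hook is discovered with hasattr, so enabling it is just a matter of assigning a zero-argument callable that returns True when generation should stop. Below is a minimal, hypothetical sketch of cancelling a completion from another thread; the model path, prompt, and 5-second timeout are placeholders, and stop.is_set is simply one convenient callable with the right shape.

    import threading
    import time

    from llama_cpp import Llama

    # Hypothetical model path; substitute a real GGUF file.
    llm = Llama(model_path="models/example.gguf")

    stop = threading.Event()
    # The patched generate()/_create_completion() loops poll this callable
    # once per iteration and stop cleanly when it returns True.
    llm.abort_callback = stop.is_set

    def cancel_after(seconds: float) -> None:
        time.sleep(seconds)
        stop.set()  # the next poll of abort_callback returns True

    threading.Thread(target=cancel_after, args=(5.0,), daemon=True).start()

    result = llm.create_completion("Write a very long story:", max_tokens=4096)
    # If the callback fired mid-generation, finish_reason is "abort".
    print(result["choices"][0]["finish_reason"])

When generate() itself bails out via the callback, it also sets llm.inference_stopped, so other threads can wait on that Event to confirm the abort actually took effect rather than assuming it from the timeout alone.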