diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7e9a6af23..ac431b56a 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -12,6 +12,7 @@
 import warnings
 import contextlib
 import multiprocessing
+import threading
 
 from typing import (
     Any,
@@ -544,6 +545,8 @@ def free_lora_adapter():
 
         self._sampler = None
 
+        self.inference_stopped = threading.Event()
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         return self._ctx.ctx
@@ -907,10 +910,26 @@ def generate(
         sample_idx = self.n_tokens + len(tokens) - 1
         tokens = list(tokens)
 
+        self.inference_stopped.clear()
+
         # Eval and sample
         while True:
+            if hasattr(self, "abort_callback") and callable(self.abort_callback):
+
+                if self.abort_callback():
+                    print("Aborting outer loop from callback.")
+                    #self.inference_stopped.set()
+                    return
+
             self.eval(tokens)
             while sample_idx < self.n_tokens:
+
+                if hasattr(self, "abort_callback") and callable(self.abort_callback):
+                    if self.abort_callback():
+                        print("Aborting generation from Python callback.")
+                        #self.inference_stopped.set()
+                        return  # Exit generation cleanly
+
                 token = self.sample(
                     top_k=top_k,
                     top_p=top_p,
@@ -1364,6 +1383,11 @@ def logit_bias_processor(
                 finish_reason = "stop"
                 break
 
+            if hasattr(self, "abort_callback") and callable(self.abort_callback):
+                if self.abort_callback():
+                    finish_reason = "abort"
+                    break
+
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(
@@ -1510,7 +1534,7 @@ def logit_bias_processor(
                 finish_reason = "length"
                 break
 
-        if stopping_criteria is not None and stopping_criteria(
+        if len(self._scores) > 0 and finish_reason != "abort" and stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
             text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
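
For context, a minimal usage sketch of the abort hook this patch introduces, assuming the patch is applied and that `abort_callback` is simply assigned as an attribute on the `Llama` instance (the new code only probes for it with `hasattr`/`callable`). The model path and the 2-second timer below are illustrative assumptions, not part of the patch:

```python
# Minimal sketch of driving the abort hook added by the patch above.
# Assumptions: the patch is applied; "./model.gguf" is a stand-in path.
import threading

from llama_cpp import Llama

stop_event = threading.Event()

llm = Llama(model_path="./model.gguf", verbose=False)  # hypothetical local model

# The patched generate()/_create_completion() loops poll this callable and
# stop early (or set finish_reason = "abort") once it returns True.
llm.abort_callback = stop_event.is_set

# Signal the abort from another thread after roughly two seconds.
threading.Timer(2.0, stop_event.set).start()

# generate() samples indefinitely; with the patch it returns once the
# callback fires, so this loop terminates instead of running forever.
for token in llm.generate(llm.tokenize(b"Once upon a time"), top_k=40, top_p=0.95, temp=0.8):
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="", flush=True)
```

When going through `create_completion` instead of `generate`, the same callback surfaces as `finish_reason = "abort"` per the hunk at `@@ -1364,6 +1383,11 @@`, and the final stopping-criteria check is skipped for aborted runs.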