Bugfix: missing response_format for functionary and llava chat handlers · llama-cpp-python@b62c449 · GitHub

Commit b62c449

Bugfix: missing response_format for functionary and llava chat handlers
1 parent 80f4162 commit b62c449

File tree

1 file changed (+30, −12)

llama_cpp/llama_chat_format.py

Lines changed: 30 additions & 12 deletions
```diff
@@ -318,10 +318,11 @@ def basic_create_chat_completion(
     stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
     rstop = result.stop if isinstance(result.stop, list) else [result.stop]
     stop = stop + rstop
-
+
     if response_format is not None and response_format["type"] == "json_object":
-        print("hello world")
-        grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+        grammar = llama_grammar.LlamaGrammar.from_string(
+            llama_grammar.JSON_GBNF
+        )
 
     completion_or_chunks = llama.create_completion(
         prompt=prompt,
```
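For context, this is the path a caller exercises when requesting strict JSON output from a handler registered through the basic chat-completion helper. A minimal usage sketch, not part of the commit; the model path is a placeholder:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path
result = llm.create_chat_completion(
    messages=[{"role": "user", "content": "List three primes as a JSON array."}],
    response_format={"type": "json_object"},  # routes through the grammar branch above
)
print(result["choices"][0]["message"]["content"])
```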
```diff
@@ -577,6 +578,7 @@ def functionary_chat_handler(
     top_k: int = 40,
     stream: bool = False,
     stop: Optional[Union[str, List[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
     max_tokens: int = 256,
     presence_penalty: float = 0.0,
     frequency_penalty: float = 0.0,
```
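With the parameter added, the functionary handler accepts the same option instead of raising on the unexpected keyword. A hedged sketch, with a placeholder model path (`chat_format="functionary"` selects this handler):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/functionary.gguf", chat_format="functionary")
result = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Return the answer as JSON."}],
    response_format={"type": "json_object"},
)
```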
```diff
@@ -753,6 +755,10 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
     assert isinstance(function_call, str)
     assert stream is False  # TODO: support stream mode
 
+    if response_format is not None and response_format["type"] == "json_object":
+        with suppress_stdout_stderr(disable=llama.verbose):
+            grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
         object="chat.completion",
```
```diff
@@ -785,11 +791,11 @@ def __init__(self, clip_model_path: str, verbose: bool = False):
         self._llava_cpp = llava_cpp
         self.clip_model_path = clip_model_path
         self.verbose = verbose
-        self._clip_free = self._llava_cpp._libllava.clip_free # type: ignore
+        self._clip_free = self._llava_cpp._libllava.clip_free  # type: ignore
 
         with suppress_stdout_stderr(disable=self.verbose):
             self.clip_ctx = self._llava_cpp.clip_model_load(
-                self.clip_model_path.encode(), 0
+                self.clip_model_path.encode(), 0
             )
 
     def __del__(self):
```
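Usage sketch for the constructor touched here (whitespace-only changes): `Llava15ChatHandler` loads the CLIP projector that pairs with a llava GGUF model. Both paths are placeholders.

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

chat_handler = Llava15ChatHandler(clip_model_path="./models/mmproj.gguf")
llm = Llama(
    model_path="./models/llava.gguf",
    chat_handler=chat_handler,
    n_ctx=2048,  # image embeddings consume context, so leave headroom
)
```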
```diff
@@ -825,6 +831,9 @@ def __call__(
         top_k: int = 40,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
```
```diff
@@ -851,7 +860,6 @@ def __call__(
             if system_prompt != ""
             else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         )
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         user_role = "\nUSER:"
         assistant_role = "\nASSISTANT:"
         llama.reset()
```
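The deleted line unconditionally clobbered `system_prompt` with the default string, so a caller's system message was silently ignored; with it gone, the conditional above it takes effect. A sketch of what now behaves as expected, reusing `llm` from the sketch above:

```python
result = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a terse assistant."},  # now honored
        {"role": "user", "content": "Say hello."},
    ],
)
```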
```diff
@@ -890,11 +898,13 @@ def __call__(
                 ctypes.c_ubyte * len(data_array)
             ).from_buffer(data_array)
             with suppress_stdout_stderr(disable=self.verbose):
-                embed = self._llava_cpp.llava_image_embed_make_with_bytes(
-                    ctx_clip=self.clip_ctx,
-                    n_threads=llama.context_params.n_threads,
-                    image_bytes=c_ubyte_ptr,
-                    image_bytes_length=len(image_bytes),
+                embed = (
+                    self._llava_cpp.llava_image_embed_make_with_bytes(
+                        ctx_clip=self.clip_ctx,
+                        n_threads=llama.context_params.n_threads,
+                        image_bytes=c_ubyte_ptr,
+                        image_bytes_length=len(image_bytes),
+                    )
                 )
             try:
                 n_past = ctypes.c_int(llama.n_tokens)
```
```diff
@@ -917,9 +927,17 @@ def __call__(
                     f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
                 )
             )
+        assert llama.n_ctx() >= llama.n_tokens
         llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
+        assert llama.n_ctx() >= llama.n_tokens
 
-        prompt = llama.input_ids[:llama.n_tokens].tolist()
+        prompt = llama.input_ids[: llama.n_tokens].tolist()
+
+        if response_format is not None and response_format["type"] == "json_object":
+            with suppress_stdout_stderr(disable=self.verbose):
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF
+                )
 
         return _convert_completion_to_chat(
             llama.create_completion(
```
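End-to-end sketch of the llava path this hunk completes: an image message plus `response_format={"type": "json_object"}`, which now builds the JSON grammar before sampling. The URL is a placeholder, and `llm` is the handler-equipped instance from the earlier sketch:

```python
result = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "file:///tmp/cat.png"}},
                {"type": "text", "text": "Describe the image as a JSON object."},
            ],
        },
    ],
    response_format={"type": "json_object"},
)
print(result["choices"][0]["message"]["content"])
```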

0 commit comments