 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
 
-import llama_cpp.llama_types as llama_types
 import llama_cpp.llama as llama
+import llama_cpp.llama_types as llama_types
+import llama_cpp.llama_grammar as llama_grammar
 
 
 class LlamaChatCompletionHandler(Protocol):
@@ -25,6 +26,9 @@ def __call__(
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -37,7 +41,10 @@ def __call__(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         **kwargs,  # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
         ...
 
 
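The keyword-only `response_format` parameter added to the handler protocol mirrors the OpenAI chat API field of the same name. A minimal usage sketch, assuming the high-level `Llama.create_chat_completion` forwards it to the registered handler; the model path is a placeholder:

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path
result = llm.create_chat_completion(
    messages=[{"role": "user", "content": "List three colors as JSON."}],
    response_format={"type": "json_object"},  # routed to the handler below
)
print(result["choices"][0]["message"]["content"])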
@@ -169,6 +176,7 @@ class ChatFormatterResponse:
 class ChatFormatter(Protocol):
     def __call__(
         self,
+        *,
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
@@ -264,17 +272,24 @@ def _convert_completion_to_chat(
 def register_chat_format(name: str):
     def decorator(f: ChatFormatter):
         def basic_create_chat_completion(
+            *,
             llama: llama.Llama,
             messages: List[llama_types.ChatCompletionRequestMessage],
             functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
             function_call: Optional[
-                Union[str, llama_types.ChatCompletionFunctionCall]
+                llama_types.ChatCompletionRequestFunctionCall
             ] = None,
+            tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+            tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
             temperature: float = 0.2,
             top_p: float = 0.95,
             top_k: int = 40,
             stream: bool = False,
             stop: Optional[Union[str, List[str]]] = [],
+            seed: Optional[int] = None,
+            response_format: Optional[
+                llama_types.ChatCompletionRequestResponseFormat
+            ] = None,
             max_tokens: int = 256,
             presence_penalty: float = 0.0,
             frequency_penalty: float = 0.0,
@@ -286,8 +301,10 @@ def basic_create_chat_completion(
             model: Optional[str] = None,
             logits_processor: Optional[llama.LogitsProcessorList] = None,
             grammar: Optional[llama.LlamaGrammar] = None,
+            **kwargs,  # type: ignore
         ) -> Union[
-            llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
+            llama_types.CreateChatCompletionResponse,
+            Iterator[llama_types.CreateChatCompletionStreamResponse],
         ]:
             result = f(
                 messages=messages,
@@ -299,6 +316,10 @@ def basic_create_chat_completion(
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]
             stop = stop + rstop
+
+            if response_format is not None and response_format["type"] == "json_object":
+                print("hello world")
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
 
             completion_or_chunks = llama.create_completion(
                 prompt=prompt,
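When the caller asks for `json_object` output, the handler swaps in a grammar compiled from the bundled `JSON_GBNF` string, constraining sampling to token sequences that remain valid JSON. A standalone sketch of the same mechanism, using only names visible in this diff:

import llama_cpp.llama_grammar as llama_grammar

# Compile the packaged GBNF grammar for JSON; passing the result as
# grammar= to create_completion restricts decoding to parseable JSON.
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)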
@@ -307,6 +328,7 @@ def basic_create_chat_completion(
                 top_k=top_k,
                 stream=stream,
                 stop=stop,
+                seed=seed,
                 max_tokens=max_tokens,
                 presence_penalty=presence_penalty,
                 frequency_penalty=frequency_penalty,
@@ -319,7 +341,7 @@ def basic_create_chat_completion(
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
-            return _convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore
+            return _convert_completion_to_chat(completion_or_chunks, stream=stream)
 
         register_chat_completion_handler(name)(basic_create_chat_completion)
         return f
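For context, `register_chat_format` wraps a plain prompt formatter into a complete chat-completion handler via `basic_create_chat_completion`. A hedged sketch of a custom formatter plugging into the decorator; the format name and prompt template are illustrative, not part of this commit:

@register_chat_format("simple")  # hypothetical format name
def format_simple(
    *,
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    # Flatten the conversation into "role: content" lines and stop
    # generating when the model starts a new user turn.
    prompt = "".join(f"{m['role']}: {m['content']}\n" for m in messages)
    return ChatFormatterResponse(prompt=prompt + "assistant: ", stop=["user:"])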
@@ -727,7 +749,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
 
     assert "usage" in completion
     assert isinstance(function_call, str)
-    assert stream is False # TODO: support stream mode
+    assert stream is False  # TODO: support stream mode
 
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
@@ -759,7 +781,9 @@ def __init__(self, clip_model_path: str):
         self._llava_cpp = llava_cpp
         self.clip_model_path = clip_model_path
 
-        self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
+        self.clip_ctx = self._llava_cpp.clip_model_load(
+            self.clip_model_path.encode(), 0
+        )
 
     def __del__(self):
         if self.clip_ctx is not None:
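The reflowed `clip_model_load` call is behavior-neutral; the second argument is the verbosity level. A sketch of wiring this handler into a model, assuming the `Llama` constructor accepts a `chat_handler` argument, and noting that `logits_all=True` is required by the assert in `__call__` below; both paths are placeholders:

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

chat_handler = Llava15ChatHandler(clip_model_path="./models/mmproj.gguf")  # placeholder
llm = Llama(
    model_path="./models/llava-v1.5-7b.gguf",  # placeholder
    chat_handler=chat_handler,
    logits_all=True,  # the llava handler asserts this
)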
@@ -805,64 +829,108 @@ def __call__(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         **kwargs,  # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
-        assert llama.context_params.logits_all is True  # BUG: logits_all=True is required for llava
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert (
+            llama.context_params.logits_all is True
+        )  # BUG: logits_all=True is required for llava
         assert self.clip_ctx is not None
         system_prompt = _get_system_message(messages)
-        system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        system_prompt = (
+            system_prompt
+            if system_prompt != ""
+            else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        )
+        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         user_role = "\nUSER:"
         assistant_role = "\nASSISTANT:"
         llama.reset()
         llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
         for message in messages:
             if message["role"] == "user" and message["content"] is not None:
                 if isinstance(message["content"], str):
-                    llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(
+                            f"{user_role} {message['content']}".encode("utf8"),
+                            add_bos=False,
+                        )
+                    )
                 else:
                     assert isinstance(message["content"], list)
-                    llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
+                    )
                     for content in message["content"]:
                         if content["type"] == "text":
-                            llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
+                            llama.eval(
+                                llama.tokenize(
+                                    f"{content['text']}".encode("utf8"), add_bos=False
+                                )
+                            )
                         if content["type"] == "image_url":
-                            image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
+                            image_bytes = (
+                                self.load_image(content["image_url"]["url"])
+                                if isinstance(content["image_url"], dict)
+                                else self.load_image(content["image_url"])
+                            )
                             import array
-                            data_array = array.array('B', image_bytes)
-                            c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
-                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
+
+                            data_array = array.array("B", image_bytes)
+                            c_ubyte_ptr = (
+                                ctypes.c_ubyte * len(data_array)
+                            ).from_buffer(data_array)
+                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(
+                                ctx_clip=self.clip_ctx,
+                                n_threads=llama.context_params.n_threads,
+                                image_bytes=c_ubyte_ptr,
+                                image_bytes_length=len(image_bytes),
+                            )
                             # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
                             # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
                             try:
                                 n_past = ctypes.c_int(llama.n_tokens)
                                 n_past_p = ctypes.pointer(n_past)
-                                self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
+                                self._llava_cpp.llava_eval_image_embed(
+                                    ctx_llama=llama.ctx,
+                                    embed=embed,
+                                    n_batch=llama.n_batch,
+                                    n_past=n_past_p,
+                                )
                                 assert llama.n_ctx() >= n_past.value
                                 llama.n_tokens = n_past.value
                             finally:
                                 self._llava_cpp.llava_image_embed_free(embed)
             if message["role"] == "assistant" and message["content"] is not None:
-                llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
+                llama.eval(
+                    llama.tokenize(
+                        f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
+                    )
+                )
         llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
 
         prompt = llama._input_ids.tolist()
 
-        return _convert_completion_to_chat(llama.create_completion(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+        return _convert_completion_to_chat(
+            llama.create_completion(
+                prompt=prompt,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stream=stream,
+                stop=stop,
+                max_tokens=max_tokens,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                repeat_penalty=repeat_penalty,
+                tfs_z=tfs_z,
+                mirostat_mode=mirostat_mode,
+                mirostat_tau=mirostat_tau,
+                mirostat_eta=mirostat_eta,
+                model=model,
+                logits_processor=logits_processor,
+                grammar=grammar,
+            ),
             stream=stream,
-            stop=stop,
-            max_tokens=max_tokens,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            repeat_penalty=repeat_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        ), stream=stream)
+        )
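The message loop in this handler walks OpenAI-style multimodal content parts: string contents are tokenized directly, while list contents may mix `text` and `image_url` parts, with each image embedded via CLIP and spliced into the context. A sketch of the message shape it consumes; the URL is a placeholder:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]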