From 7e13b5f18e41a6c21ba6b6525f288902950faff3 Mon Sep 17 00:00:00 2001
From: jeffrey-fong
Date: Mon, 6 May 2024 19:28:31 +0800
Subject: [PATCH] implement code interpreter feature for functionary

---
 llama_cpp/llama_chat_format.py | 28 ++++++++++++++++++----------
 llama_cpp/llama_types.py       |  4 ++--
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 3ab94e0d3..490d0850d 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -1718,6 +1718,7 @@ def functionary_v1_v2_chat_handler(
     **kwargs,  # type: ignore
 ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
     SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
+    PYTHON_SYSTEM_MESSAGE = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""

     tokenizer = llama.tokenizer_
     assert hasattr(
@@ -1860,9 +1861,14 @@ def prepare_messages_for_inference(
                 )
             )

+        # Default system message
+        if version == "v2" and tools is not None and any([tool["type"] == "code_interpreter" for tool in tools]):
+            sys_msg = PYTHON_SYSTEM_MESSAGE
+        else:
+            sys_msg = SYSTEM_MESSAGE
         all_messages.append(
             llama_types.ChatCompletionRequestSystemMessage(
-                role="system", content=SYSTEM_MESSAGE
+                role="system", content=sys_msg
             )
         )

@@ -2124,7 +2130,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                     )
                 else:
                     prompt += f"{function_name}\n<|content|>"
-                    grammar = get_grammar(function_name)
+                    grammar = get_grammar(function_name) if function_name != "python" else None
                     tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
                     if tools is not None:
                         func_call_dict = {
@@ -2322,7 +2328,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                 prompt = prompt
                 stops = ["\n", END_ASSISTANT_TOKEN]

-            completion = create_completion(stop=stops)
+            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
             completion_text = completion["choices"][0]["text"]
             completion_tokens += completion["usage"]["completion_tokens"]

@@ -2349,7 +2355,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                     completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
                 )
                 grammar = get_grammar(function_calls[-1])
-                completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
+                completion = create_completion(prompt=prompt, stop=END_FUNCTION_CALL_TOKEN, grammar=grammar)
                 completion_tokens += completion["usage"]["completion_tokens"]
                 function_bodies.append(completion["choices"][0]["text"].strip())
         # If the prompt involves a function call, just append generated parameters to function_bodies
@@ -2363,7 +2369,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                 function_calls.append(function_call)
                 grammar = get_grammar(function_call)
                 stops = [STOP_TOKEN, FROM_TOKEN]
-                completion = create_completion(stop=stops)
+                completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                 completion_text = completion["choices"][0]["text"]
                 completion_tokens += completion["usage"]["completion_tokens"]
                 function_bodies.append(completion_text.strip())
@@ -2373,12 +2379,14 @@ def generate_streaming(tools, functions, function_call, prompt):
                     # Generate function name first
                     grammar = None
                     stops = CONTENT_TOKEN
-                    completion = create_completion(stop=stops)
+                    completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                     completion_text = completion["choices"][0]["text"]
                     completion_tokens += completion["usage"]["completion_tokens"]
                     function_name = completion_text.strip()
-                    if function_name == "all":
-                        prompt += "all\n<|content|>"
+                    if function_name in ["all", "python"]:
+                        prompt += f"{function_name}\n<|content|>"
+                        if function_name == "python":
+                            function_calls.append("python")
                     else:
                         function_call = completion_text.strip()
                         prompt += f"{function_call}\n<|content|>"
@@ -2386,7 +2394,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                         grammar = get_grammar(function_call)
                     # Generate content
                     stops = [RECIPIENT_TOKEN, STOP_TOKEN]
-                    completion = create_completion(stop=stops)
+                    completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                     completion_text = completion["choices"][0]["text"]
                     completion_tokens += completion["usage"]["completion_tokens"]
                     if function_name == "all":
@@ -2413,7 +2421,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                     # Check whether the model wants to generate another turn
                     prompt += completion_text.strip()
                     grammar = None
-                    completion = create_completion(stop=stops)
+                    completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                     completion_tokens += completion["usage"]["completion_tokens"]
                     if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
                         prompt += "\n<|from|>assistant\n<|recipient|>"
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index 4677785ae..5b5afd924 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -258,8 +258,8 @@ class ChatCompletionToolFunction(TypedDict):


 class ChatCompletionTool(TypedDict):
-    type: Literal["function"]
-    function: ChatCompletionToolFunction
+    type: Literal["function", "code_interpreter"]
+    function: NotRequired[ChatCompletionToolFunction]


 class ChatCompletionNamedToolChoiceFunction(TypedDict):
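
Usage note (not part of the patch): a minimal sketch of how the new tool type would be exercised end to end, mirroring the functionary setup documented in the llama-cpp-python README. The repo id and GGUF filename below are illustrative placeholders; the "code_interpreter" tool type, the functionary-v2 handler behavior, and the "python" recipient handling are what this patch adds.

    from llama_cpp import Llama
    from llama_cpp.llama_tokenizer import LlamaHFTokenizer

    # Placeholder model repo/file; any functionary v2 GGUF should work.
    llm = Llama.from_pretrained(
        repo_id="meetkai/functionary-small-v2.4-GGUF",
        filename="functionary-small-v2.4.Q4_0.gguf",
        chat_format="functionary-v2",
        tokenizer=LlamaHFTokenizer.from_pretrained("meetkai/functionary-small-v2.4-GGUF"),
        n_ctx=4096,
    )

    # With this patch, "code_interpreter" is a valid tool type and carries no
    # "function" schema (hence NotRequired in llama_types.py). The handler then
    # swaps in PYTHON_SYSTEM_MESSAGE, and generations addressed to the "python"
    # recipient run without a JSON grammar (grammar=None) so the model can emit
    # free-form Python.
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Use Python to compute 2**10."}],
        tools=[{"type": "code_interpreter"}],
        tool_choice="auto",
    )
    print(response["choices"][0]["message"])

The patch only changes prompt construction and decoding; actually executing the generated code (and feeding its output back as a tool message) remains the caller's responsibility.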