8000 Implement code interpreter feature for functionary by jeffrey-fong · Pull Request #1433 · abetlen/llama-cpp-python · GitHub

Closed
28 changes: 18 additions & 10 deletions llama_cpp/llama_chat_format.py
```diff
@@ -1718,6 +1718,7 @@ def functionary_v1_v2_chat_handler(
     **kwargs,  # type: ignore
 ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
     SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
+    PYTHON_SYSTEM_MESSAGE = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""

     tokenizer = llama.tokenizer_
     assert hasattr(
```
```diff
@@ -1860,9 +1861,14 @@ def prepare_messages_for_inference(
                 )
             )

+        # Default system message
+        if version == "v2" and tools is not None and any([tool["type"] == "code_interpreter" for tool in tools]):
+            sys_msg = PYTHON_SYSTEM_MESSAGE
+        else:
+            sys_msg = SYSTEM_MESSAGE
         all_messages.append(
             llama_types.ChatCompletionRequestSystemMessage(
-                role="system", content=SYSTEM_MESSAGE
+                role="system", content=sys_msg
             )
         )
```

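For reference, a hedged sketch of the request shape that exercises the new branch; the weather tool, its name, and its schema are made-up examples, not part of this PR:

```python
# Any tool of type "code_interpreter" flips the system prompt to
# PYTHON_SYSTEM_MESSAGE; ordinary function tools still carry a schema.
tools = [
    {"type": "code_interpreter"},  # "function" key may be omitted (see llama_types.py below)
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical example tool
            "description": "Look up the current weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    },
]

# Mirrors the check added above
assert any(tool["type"] == "code_interpreter" for tool in tools)
```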
```diff
@@ -2124,7 +2130,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                 )
             else:
                 prompt += f"{function_name}\n<|content|>"
-                grammar = get_grammar(function_name)
+                grammar = get_grammar(function_name) if function_name != "python" else None
             tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
             if tools is not None:
                 func_call_dict = {
```
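The `None` grammar is the point of this hunk: `get_grammar` constrains a declared function's arguments with a GBNF grammar derived from its JSON schema, but the code interpreter emits raw Python source, which has to be sampled unconstrained. A rough sketch of that selection, assuming the same `llama_grammar` helpers the file's `get_grammar` builds on:

```python
import json
from llama_cpp import llama_grammar

def grammar_for(function_name, tools):
    """Sketch only: choose a grammar the way this hunk decides to."""
    if function_name == "python":
        return None  # free-form code; a JSON-schema grammar would mangle it
    for tool in tools or []:
        if tool["type"] == "function" and tool["function"]["name"] == function_name:
            gbnf = llama_grammar.json_schema_to_gbnf(json.dumps(tool["function"]["parameters"]))
            return llama_grammar.LlamaGrammar.from_string(gbnf)
    return None
```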
```diff
@@ -2322,7 +2328,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                 prompt = prompt
                 stops = ["\n", END_ASSISTANT_TOKEN]

-            completion = create_completion(stop=stops)
+            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
             completion_text = completion["choices"][0]["text"]
             completion_tokens += completion["usage"]["completion_tokens"]
```
```diff
@@ -2349,7 +2355,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                     completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
                 )
                 grammar = get_grammar(function_calls[-1])
-                completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
+                completion = create_completion(prompt=prompt, stop=END_FUNCTION_CALL_TOKEN, grammar=grammar)
                 completion_tokens += completion["usage"]["completion_tokens"]
                 function_bodies.append(completion["choices"][0]["text"].strip())
             # If the prompt involves a function call, just append generated parameters to function_bodies
```
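This hunk, the one above it, and the ones below all make the same mechanical change: `create_completion` now receives `prompt` and `grammar` as explicit arguments. A self-contained toy (names hypothetical, not the PR's actual helper) showing the shape; with explicit parameters, each call consumes exactly the prompt and grammar computed just before it, instead of relying on names captured from the enclosing scope:

```python
def make_create_completion(model_call):
    def create_completion(prompt, stop, grammar):
        # every input is a parameter, so the call site states its inputs in full
        return model_call(prompt=prompt, stop=stop, grammar=grammar)
    return create_completion

create_completion = make_create_completion(lambda **kw: kw)  # stand-in for llama.create_completion
out = create_completion(prompt="<|from|>user\n...", stop=["<|stop|>"], grammar=None)
assert out["stop"] == ["<|stop|>"]
```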
```diff
@@ -2363,7 +2369,7 @@ def generate_streaming(tools, functions, function_call, prompt):
             function_calls.append(function_call)
             grammar = get_grammar(function_call)
             stops = [STOP_TOKEN, FROM_TOKEN]
-            completion = create_completion(stop=stops)
+            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
             completion_text = completion["choices"][0]["text"]
             completion_tokens += completion["usage"]["completion_tokens"]
             function_bodies.append(completion_text.strip())
```
```diff
@@ -2373,20 +2379,22 @@ def generate_streaming(tools, functions, function_call, prompt):
                 # Generate function name first
                 grammar = None
                 stops = CONTENT_TOKEN
-                completion = create_completion(stop=stops)
+                completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                 completion_text = completion["choices"][0]["text"]
                 completion_tokens += completion["usage"]["completion_tokens"]
                 function_name = completion_text.strip()
-                if function_name == "all":
-                    prompt += "all\n<|content|>"
+                if function_name in ["all", "python"]:
+                    prompt += f"{function_name}\n<|content|>"
+                    if function_name == "python":
+                        function_calls.append("python")
                 else:
                     function_call = completion_text.strip()
                     prompt += f"{function_call}\n<|content|>"
                     function_calls.append(function_call)
                     grammar = get_grammar(function_call)
                 # Generate content
                 stops = [RECIPIENT_TOKEN, STOP_TOKEN]
-                completion = create_completion(stop=stops)
+                completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                 completion_text = completion["choices"][0]["text"]
                 completion_tokens += completion["usage"]["completion_tokens"]
                 if function_name == "all":
```
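For orientation, a hedged illustration of the functionary-v2 turn this branch parses when the model routes to the interpreter; the code content is invented, and the token layout follows the FROM/RECIPIENT/CONTENT tokens used throughout this handler:

```python
raw_turn = (
    "<|from|>assistant\n"
    "<|recipient|>python\n"  # the name phase above stops at <|content|>
    "<|content|>import math\nprint(math.sqrt(2))\n"
)
recipient = raw_turn.split("<|recipient|>")[1].split("\n")[0].strip()
assert recipient == "python"  # recorded as a call; the body is then sampled with grammar=None
```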
```diff
@@ -2413,7 +2421,7 @@ def generate_streaming(tools, functions, function_call, prompt):
             # Check whether the model wants to generate another turn
             prompt += completion_text.strip()
             grammar = None
-            completion = create_completion(stop=stops)
+            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
             completion_tokens += completion["usage"]["completion_tokens"]
             if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
                 prompt += "\n<|from|>assistant\n<|recipient|>"
```
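End to end, the feature is driven like any other tool call. A minimal usage sketch, assuming a local functionary-v2 GGUF plus its Hugging Face tokenizer; the path, repo id, context size, and prompt are illustrative, not prescribed by this PR:

```python
from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

llm = Llama(
    model_path="./functionary-small-v2.4.Q4_0.gguf",  # assumed local path
    chat_format="functionary-v2",
    tokenizer=LlamaHFTokenizer.from_pretrained("meetkai/functionary-small-v2.4-GGUF"),
    n_ctx=4096,
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Compute the 10th Fibonacci number."}],
    tools=[{"type": "code_interpreter"}],
    tool_choice="auto",
)
# tool_calls carries the generated Python source; executing it is up to the caller
print(response["choices"][0]["message"].get("tool_calls"))
```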
4 changes: 2 additions & 2 deletions llama_cpp/llama_types.py
```diff
@@ -258,8 +258,8 @@ class ChatCompletionToolFunction(TypedDict):


 class ChatCompletionTool(TypedDict):
-    type: Literal["function"]
-    function: ChatCompletionToolFunction
+    type: Literal["function", "code_interpreter"]
+    function: NotRequired[ChatCompletionToolFunction]


 class ChatCompletionNamedToolChoiceFunction(TypedDict):
```
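What the widened TypedDict admits, in isolation; these are simplified stand-ins for the real definitions, with `NotRequired` coming from typing_extensions as in this file:

```python
from typing_extensions import Literal, NotRequired, TypedDict

class ChatCompletionToolFunction(TypedDict):
    name: str
    description: NotRequired[str]
    parameters: dict

class ChatCompletionTool(TypedDict):
    type: Literal["function", "code_interpreter"]
    function: NotRequired[ChatCompletionToolFunction]

code_tool: ChatCompletionTool = {"type": "code_interpreter"}  # valid: "function" is now optional
fn_tool: ChatCompletionTool = {
    "type": "function",
    "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {}}},
}
```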