8000 Fix and optimize functionary chat handler by jeffrey-fong · Pull Request #1282 · abetlen/llama-cpp-python · GitHub
[go: up one dir, main page]

Skip to content

Fix and optimize functionary chat handler #1282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 65 additions & 66 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,13 +1596,15 @@ def prepare_messages_for_inference(
function_call = (
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
)
else:
function_call = "auto"

prompt = prepare_messages_for_inference(
messages, tokenizer, version, functions, tools
)

# If no tools/functions are provided
if function_call is None and (functions is None or len(functions) == 0):
if function_call == "none" or functions is None or len(functions) == 0:
if version == "v1":
stop = END_ASSISTANT_TOKEN
else:
Expand Down Expand Up @@ -1630,6 +1632,7 @@ def prepare_messages_for_inference(
logits_processor=logits_processor,
grammar=grammar,
)
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore

assert stream is False # TODO: support stream mode
Expand Down Expand Up @@ -1692,13 +1695,12 @@ def create_completion(stop):

return completion

content = ""
function_calls, function_bodies = [], []

if version == "v1":
# If no or "auto" tool_choice/function_call
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
if isinstance(function_call, str) and function_call == "auto":
stops = ["\n", END_ASSISTANT_TOKEN]
# If tool_choice/function_call is "none"
elif isinstance(function_call, str) and function_call == "none":
Expand Down Expand Up @@ -1747,70 +1749,67 @@ def create_completion(stop):
else:
function_bodies.append(completion_text.strip())
else:
# Loop until all parallel function calls are generated
while True:
# If no or "auto" tool_choice/function_call
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
grammar = None
stops = CONTENT_TOKEN
# If tool_choice/function_call is "none"
elif isinstance(function_call, str) and function_call == "none":
prompt = (
prepare_messages_for_inference(messages, tokenizer, version, [], [])
+ "all\n<|content|>"
)
stops = STOP_TOKEN
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
stops = STOP_TOKEN
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
else:
prompt = prompt
stops = STOP_TOKEN

# If tool_choice/function_call is "none"
if isinstance(function_call, str) and function_call == "none":
prompt = (
prepare_messages_for_inference(messages, tokenizer, version, [], [])
+ "all\n<|content|>"
)
stops = [STOP_TOKEN, FROM_TOKEN]
completion = create_completion(stop=stops)
completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
stops = [STOP_TOKEN, FROM_TOKEN]
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]

# If the generation does not involve a function call
if prompt.endswith("all\n<|content|>") and not completion_text.startswith(
"all"
):
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# Generate model response if the model decides not to call any function
elif prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all"):
prompt += completion_text + CONTENT_TOKEN
completion = create_completion(stop=STOP_TOKEN)
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# Generate parameters if model decides to call a function
elif prompt.endswith(RECIPIENT_TOKEN):
function_calls.append(completion_text[:-1])
grammar = get_grammar(function_calls[-1])
completion = create_completion(stop=[STOP_TOKEN, "\n"])
function_bodies.append(completion["choices"][0]["text"].strip())
prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
function_bodies.append(completion_text.strip())
# If "auto" or no tool_choice/function_call
elif isinstance(function_call, str) and function_call == "auto":
while True:
# Generate function name first
grammar = None

# Try to generate the beginning of next turn
# If empty completion, break from loop
next_turn_completion_text = create_completion(
stop=[STOP_TOKEN, RECIPIENT_TOKEN]
)["choices"][0]["text"]
if len(next_turn_completion_text) > 0:
prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
stops = CONTENT_TOKEN
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
function_name = completion_text.strip()
if function_name == "all":
prompt += "all\n<|content|>"
else:
break
# Break from loop if tool_choice/function_call is provided as a dict
else:
function_bodies.append(completion_text.strip())
break
function_call = completion_text.strip()
prompt += f"{function_call}\n<|content|>"
function_calls.append(function_call)
grammar = get_grammar(function_call)
# Generate content
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
if function_name == "all":
content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
content = content.lstrip()
# Check whether the model wants to generate another turn
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
A540 else:
break
else:
function_bodies.append(completion_text.strip())
# Check whether the model wants to generate another turn
prompt += completion_text.strip()
grammar = None
completion = create_completion(stop=stops)
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
prompt += "\n<|from|>assistant\n<|recipient|>"
else:
break

assert "usage" in completion
assert len(function_calls) > 0
assert len(function_calls) == len(function_bodies)

tool_calls = []
Expand Down Expand Up @@ -1843,14 +1842,14 @@ def create_completion(stop):
"index": 0,
"message": {
"role": "assistant",
"content": None,
"content": None if content == "" else content,
"function_call": {
"name": tool_calls[0]["function"]["name"],
"arguments": tool_calls[0]["function"]["arguments"],
},
"tool_calls": tool_calls,
} if len(tool_calls) > 0 else None,
"tool_calls": tool_calls if len(tool_calls) > 0 else None,
},
"finish_reason": "tool_calls",
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
}
],
usage=completion["usage"],
Expand Down
0