Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main · asusevski/llama-cpp-python@7d4a5ec
Commit 7d4a5ec

Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main

2 parents: bf64752 + 8a60c7b

File tree

1 file changed (+65, -66 lines)

llama_cpp/llama_chat_format.py

Lines changed: 65 additions & 66 deletions
@@ -1596,13 +1596,15 @@ def prepare_messages_for_inference(
         function_call = (
             tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
         )
+    else:
+        function_call = "auto"
 
     prompt = prepare_messages_for_inference(
         messages, tokenizer, version, functions, tools
     )
 
     # If no tools/functions are provided
-    if function_call is None and (functions is None or len(functions) == 0):
+    if function_call == "none" or functions is None or len(functions) == 0:
         if version == "v1":
             stop = END_ASSISTANT_TOKEN
         else:
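The practical effect of this hunk is that function_call now always carries a concrete value ("auto" by default) instead of possibly being None, and the early-exit guard keys off "none" or an empty function list. A minimal sketch of that normalization, with normalize_function_call as a hypothetical helper name (not code from this repo):

# Sketch of the tool_choice -> function_call normalization introduced above.
from typing import Any, Dict, Optional, Union

def normalize_function_call(
    tool_choice: Optional[Union[str, Dict[str, Any]]]
) -> Union[str, Dict[str, Any]]:
    if tool_choice is not None:
        # A dict tool_choice carries the target under its "function" key.
        return tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
    # "auto" is now the explicit default instead of None.
    return "auto"

assert normalize_function_call(None) == "auto"
assert normalize_function_call("none") == "none"
assert normalize_function_call(
    {"type": "function", "function": {"name": "get_weather"}}
) == {"name": "get_weather"}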
@@ -1630,6 +1632,7 @@ def prepare_messages_for_inference(
             logits_processor=logits_processor,
             grammar=grammar,
         )
+        completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
         return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore
 
     assert stream is False  # TODO: support stream mode
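The added .lstrip() trims whitespace the model tends to emit before its reply in the no-tools path. A toy illustration with a hand-built completion dict (hypothetical data, not library output):

# Chat templates often leave a leading space or newline before the reply,
# which the added .lstrip() now removes.
completion_or_completion_chunks = {"choices": [{"text": " \nHello there."}]}
completion_or_completion_chunks["choices"][0]["text"] = (
    completion_or_completion_chunks["choices"][0]["text"].lstrip()
)
assert completion_or_completion_chunks["choices"][0]["text"] == "Hello there."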
@@ -1692,13 +1695,12 @@ def create_completion(stop):
 
         return completion
 
+    content = ""
     function_calls, function_bodies = [], []
 
     if version == "v1":
         # If no or "auto" tool_choice/function_call
-        if function_call is None or (
-            isinstance(function_call, str) and function_call == "auto"
-        ):
+        if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
         # If tool_choice/function_call is "none"
         elif isinstance(function_call, str) and function_call == "none":
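Because function_call now defaults to "auto" rather than None (first hunk), the old two-clause guard collapses to a single equality test. A quick check, under the assumption that None can no longer reach this branch, that the old and new predicates agree on every remaining input:

# The old guard was `fc is None or (isinstance(fc, str) and fc == "auto")`.
for fc in ["auto", "none", {"name": "get_weather"}]:
    old = fc is None or (isinstance(fc, str) and fc == "auto")
    new = isinstance(fc, str) and fc == "auto"
    assert old == new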
@@ -1747,70 +1749,67 @@ def create_completion(stop):
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # Loop until all parallel function calls are generated
-        while True:
-            # If no or "auto" tool_choice/function_call
-            if function_call is None or (
-                isinstance(function_call, str) and function_call == "auto"
-            ):
-                grammar = None
-                stops = CONTENT_TOKEN
-            # If tool_choice/function_call is "none"
-            elif isinstance(function_call, str) and function_call == "none":
-                prompt = (
-                    prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                    + "all\n<|content|>"
-                )
-                stops = STOP_TOKEN
-            # If tool_choice/function_call is provided
-            elif isinstance(function_call, dict):
-                prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
-                stops = STOP_TOKEN
-                function_call = function_call["name"]
-                function_calls.append(function_call)
-                grammar = get_grammar(function_call)
-            else:
-                prompt = prompt
-                stops = STOP_TOKEN
-
+        # If tool_choice/function_call is "none"
+        if isinstance(function_call, str) and function_call == "none":
+            prompt = (
+                prepare_messages_for_inference(messages, tokenizer, version, [], [])
+                + "all\n<|content|>"
+            )
+            stops = [STOP_TOKEN, FROM_TOKEN]
+            completion = create_completion(stop=stops)
+            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
+            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
+        # If tool_choice/function_call is provided
+        elif isinstance(function_call, dict):
+            prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
+            function_call = function_call["name"]
+            function_calls.append(function_call)
+            grammar = get_grammar(function_call)
+            stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
-
-            # If the generation does not involve a function call
-            if prompt.endswith("all\n<|content|>") and not completion_text.startswith(
-                "all"
-            ):
-                return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
-            # Generate model response if the model decides not to call any function
-            elif prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all"):
-                prompt += completion_text + CONTENT_TOKEN
-                completion = create_completion(stop=STOP_TOKEN)
-                return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
-            # Generate parameters if model decides to call a function
-            elif prompt.endswith(RECIPIENT_TOKEN):
-                function_calls.append(completion_text[:-1])
-                grammar = get_grammar(function_calls[-1])
-                completion = create_completion(stop=[STOP_TOKEN, "\n"])
-                function_bodies.append(completion["choices"][0]["text"].strip())
-                prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
+            function_bodies.append(completion_text.strip())
+        # If "auto" or no tool_choice/function_call
+        elif isinstance(function_call, str) and function_call == "auto":
+            while True:
+                # Generate function name first
                 grammar = None
-
-                # Try to generate the beginning of next turn
-                # If empty completion, break from loop
-                next_turn_completion_text = create_completion(
-                    stop=[STOP_TOKEN, RECIPIENT_TOKEN]
-                )["choices"][0]["text"]
-                if len(next_turn_completion_text) > 0:
-                    prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
+                stops = CONTENT_TOKEN
+                completion = create_completion(stop=stops)
+                completion_text = completion["choices"][0]["text"]
+                function_name = completion_text.strip()
+                if function_name == "all":
+                    prompt += "all\n<|content|>"
                 else:
-                    break
-            # Break from loop if tool_choice/function_call is provided as a dict
-            else:
-                function_bodies.append(completion_text.strip())
-                break
+                    function_call = completion_text.strip()
+                    prompt += f"{function_call}\n<|content|>"
+                    function_calls.append(function_call)
+                    grammar = get_grammar(function_call)
+                # Generate content
+                stops = [RECIPIENT_TOKEN, STOP_TOKEN]
+                completion = create_completion(stop=stops)
+                completion_text = completion["choices"][0]["text"]
+                if function_name == "all":
+                    content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
+                    content = content.lstrip()
+                    # Check whether the model wants to generate another turn
+                    if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
+                        cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
+                        prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
+                    else:
+                        break
+                else:
+                    function_bodies.append(completion_text.strip())
+                    # Check whether the model wants to generate another turn
+                    prompt += completion_text.strip()
+                    grammar = None
+                    completion = create_completion(stop=stops)
+                    if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
+                        prompt += "\n<|from|>assistant\n<|recipient|>"
+                    else:
+                        break
 
     assert "usage" in completion
-    assert len(function_calls) > 0
     assert len(function_calls) == len(function_bodies)
 
     tool_calls = []
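The rewritten v2 branch replaces the old catch-all loop with three explicit cases ("none", a forced function dict, and "auto"); in the "auto" case it alternates between generating a recipient name and that recipient's content until no "<|from|> assistant" turn marker follows. A sketch of the turn layout the loop walks; render_turn is a hypothetical helper, and the token literals are the ones visible in the diff:

# One functionary-v2 turn as assembled by the loop above (sketch only).
FROM_TOKEN = "<|from|>"
RECIPIENT_TOKEN = "<|recipient|>"
CONTENT_TOKEN = "<|content|>"

def render_turn(recipient: str, body: str) -> str:
    # recipient "all" addresses the user directly; any other value is a
    # function name whose arguments follow after <|content|>.
    return f"{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}{recipient}\n{CONTENT_TOKEN}{body}"

# Parallel calls appear as consecutive turns: free text, then a function call.
prompt = render_turn("all", "Let me check the weather.\n")
prompt += render_turn("get_weather", '{"city": "Paris"}')
print(prompt)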
@@ -1843,14 +1842,14 @@ def create_completion(stop):
                 "index": 0,
                 "message": {
                     "role": "assistant",
-                    "content": None,
+                    "content": None if content == "" else content,
                     "function_call": {
                         "name": tool_calls[0]["function"]["name"],
                         "arguments": tool_calls[0]["function"]["arguments"],
-                    },
-                    "tool_calls": tool_calls,
+                    } if len(tool_calls) > 0 else None,
+                    "tool_calls": tool_calls if len(tool_calls) > 0 else None,
                 },
-                "finish_reason": "tool_calls",
+                "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
             }
         ],
         usage=completion["usage"],
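This hunk makes the response shape depend on whether any tool calls were actually produced; dropping the `assert len(function_calls) > 0` above is what makes the zero-call case reachable. A condensed sketch of the resulting choice object, with build_choice as a hypothetical helper:

# Sketch of the choice object this hunk builds (hypothetical helper).
def build_choice(content: str, tool_calls: list) -> dict:
    return {
        "index": 0,
        "message": {
            "role": "assistant",
            # Plain text is surfaced as content; tool-call-only turns keep None.
            "content": None if content == "" else content,
            "function_call": {
                "name": tool_calls[0]["function"]["name"],
                "arguments": tool_calls[0]["function"]["arguments"],
            } if len(tool_calls) > 0 else None,
            "tool_calls": tool_calls if len(tool_calls) > 0 else None,
        },
        "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
    }

# Plain-text reply: no tool-call metadata, finish_reason becomes "stop".
assert build_choice("Hi!", [])["finish_reason"] == "stop"
# Tool-call reply: content stays None, finish_reason stays "tool_calls".
call = {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
assert build_choice("", [call])["message"]["content"] is None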
