@@ -1596,13 +1596,15 @@ def prepare_messages_for_inference(
1596
1596
function_call = (
1597
1597
tool_choice if isinstance (tool_choice , str ) else tool_choice ["function" ]
1598
1598
)
1599
+ else :
1600
+ function_call = "auto"
1599
1601
1600
1602
prompt = prepare_messages_for_inference (
1601
1603
messages , tokenizer , version , functions , tools
1602
1604
)
1603
1605
1604
1606
# If no tools/functions are provided
1605
- if function_call is None and ( functions is None or len (functions ) == 0 ) :
1607
+ if function_call == "none" or functions is None or len (functions ) == 0 :
1606
1608
if version == "v1" :
1607
1609
stop = END_ASSISTANT_TOKEN
1608
1610
else :
@@ -1630,6 +1632,7 @@ def prepare_messages_for_inference(
1630
1632
logits_processor = logits_processor ,
1631
1633
grammar = grammar ,
1632
1634
)
1635
+ completion_or_completion_chunks ["choices" ][0 ]["text" ] = completion_or_completion_chunks ["choices" ][0 ]["text" ].lstrip ()
1633
1636
return _convert_completion_to_chat (completion_or_completion_chunks , stream = stream ) # type: ignore
1634
1637
1635
1638
assert stream is False # TODO: support stream mode
@@ -1692,13 +1695,12 @@ def create_completion(stop):
1692
1695
1693
1696
return completion
1694
1697
1698
+ content = ""
1695
1699
function_calls , function_bodies = [], []
1696
1700
1697
1701
if version == "v1" :
1698
1702
# If no or "auto" tool_choice/function_call
1699
- if function_call is None or (
1700
- isinstance (function_call , str ) and function_call == "auto"
1701
- ):
1703
+ if isinstance (function_call , str ) and function_call == "auto" :
1702
1704
stops = ["\n " , END_ASSISTANT_TOKEN ]
1703
1705
# If tool_choice/function_call is "none"
1704
1706
elif isinstance (function_call , str ) and function_call == "none" :
@@ -1747,70 +1749,67 @@ def create_completion(stop):
1747
1749
else :
1748
1750
function_bodies .append (completion_text .strip ())
1749
1751
else :
1750
- # Loop until all parallel function calls are generated
1751
- while True :
1752
- # If no or "auto" tool_choice/function_call
1753
- if function_call is None or (
1754
- isinstance (function_call , str ) and function_call == "auto"
1755
- ):
1756
- grammar = None
1757
- stops = CONTENT_TOKEN
1758
- # If tool_choice/function_call is "none"
1759
- elif isinstance (function_call , str ) and function_call == "none" :
1760
- prompt = (
1761
- prepare_messages_for_inference (messages , tokenizer , version , [], [])
1762
- + "all\n <|content|>"
1763
- )
1764
- stops = STOP_TOKEN
1765
- # If tool_choice/function_call is provided
1766
- elif isinstance (function_call , dict ):
1767
- prompt += f"{ function_call ['name' ]} \n { CONTENT_TOKEN } "
1768
- stops = STOP_TOKEN
1769
- function_call = function_call ["name" ]
1770
- function_calls .append (function_call )
1771
- grammar = get_grammar (function_call )
1772
- else :
1773
- prompt = prompt
1774
- stops = STOP_TOKEN
1775
-
1752
+ # If tool_choice/function_call is "none"
1753
+ if isinstance (function_call , str ) and function_call == "none" :
1754
+ prompt = (
1755
+ prepare_messages_for_inference (messages , tokenizer , version , [], [])
1756
+ + "all\n <|content|>"
1757
+ )
1758
+ stops = [STOP_TOKEN , FROM_TOKEN ]
1759
+ completion = create_completion (stop = stops )
1760
+ completion ["choices" ][0 ]["text" ] = completion ["choices" ][0 ]["text" ].strip ()
1761
+ return _convert_completion_to_chat (completion , stream = stream ) # type: ignore
1762
+ # If tool_choice/function_call is provided
1763
+ elif isinstance (function_call , dict ):
1764
+ prompt += f"{ function_call ['name' ]} \n { CONTENT_TOKEN } "
1765
+ function_call = function_call ["name" ]
1766
+ function_calls .append (function_call )
1767
+ grammar = get_grammar (function_call )
1768
+ stops = [STOP_TOKEN , FROM_TOKEN ]
1776
1769
completion = create_completion (stop = stops )
1777
1770
completion_text = completion ["choices" ][0 ]["text" ]
1778
-
1779
- # If the generation does not involve a function call
1780
- if prompt .endswith ("all\n <|content|>" ) and not completion_text .startswith (
1781
- "all"
1782
- ):
1783
- return _convert_completion_to_chat (completion , stream = stream ) # type: ignore
1784
- # Generate model response if the model decides not to call any function
1785
- elif prompt .endswith (RECIPIENT_TOKEN ) and completion_text .startswith ("all" ):
1786
- prompt += completion_text + CONTENT_TOKEN
1787
- completion = create_completion (stop = STOP_TOKEN )
1788
- return _convert_completion_to_chat (completion , stream = stream ) # type: ignore
1789
- # Generate parameters if model decides to call a function
1790
- elif prompt .endswith (RECIPIENT_TOKEN ):
1791
- function_calls .append (completion_text [:- 1 ])
1792
- grammar = get_grammar (function_calls [- 1 ])
1793
- completion = create_completion (stop = [STOP_TOKEN , "\n " ])
1794
- function_bodies .append (completion ["choices" ][0 ]["text" ].strip ())
1795
- prompt += f"{ function_calls [- 1 ]} \n { CONTENT_TOKEN } { function_bodies [- 1 ]} "
1771
+ function_bodies .append (completion_text .strip ())
1772
+ # If "auto" or no tool_choice/function_call
1773
+ elif isinstance (function_call , str ) and function_call == "auto" :
1774
+ while True :
1775
+ # Generate function name first
1796
1776
grammar = None
1797
-
1798
- # Try to generate the beginning of next turn
1799
- # If empty completion, break from loop
1800
- next_turn_completion_text = create_completion (
1801
- stop = [STOP_TOKEN , RECIPIENT_TOKEN ]
1802
- )["choices" ][0 ]["text" ]
1803
- if len (next_turn_completion_text ) > 0 :
1804
- prompt += f"\n { FROM_TOKEN } assistant\n { RECIPIENT_TOKEN } "
1777
+ stops = CONTENT_TOKEN
1778
+ completion = create_completion (stop = stops )
1779
+ completion_text = completion ["choices" ][0 ]["text" ]
1780
+ function_name = completion_text .strip ()
1781
+ if function_name == "all" :
1782
+ prompt += "all\n <|content|>"
1805
1783
else :
1806
- break
1807
- # Break from loop if tool_choice/function_call is provided as a dict
1808
- else :
1809
- function_bodies .append (completion_text .strip ())
1810
- break
1784
+ function_call = completion_text .strip ()
1785
+ prompt += f"{ function_call } \n <|content|>"
1786
+ function_calls .append (function_call )
1787
+ grammar = get_grammar (function_call )
1788
+ # Generate content
1789
+ stops = [RECIPIENT_TOKEN , STOP_TOKEN ]
1790
+ completion = create_completion (stop = stops )
1791
+ completion_text = completion ["choices" ][0 ]["text" ]
1792
+ if function_name == "all" :
1793
+ content += completion_text .removesuffix ("\n <|from|>assistant\n " ).removesuffix ("\n <|from|> assistant\n " )
1794
+ content = content .lstrip ()
1795
+ # Check whether the model wants to generate another turn
1796
+ if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text :
1797
+ cleaned_completion_text = completion_text .removesuffix ("\n <|from|>assistant\n " ).removesuffix ("\n <|from|> assistant\n " ).strip ()
1798
+ prompt += f"{ cleaned_completion_text } \n <|from|>assistant\n <|recipient|>"
1799
+ else :
1800
+ break
1801
+ else :
1802
+ function_bodies .append (completion_text .strip ())
1803
+ # Check whether the model wants to generate another turn
1804
+ prompt += completion_text .strip ()
1805
+ grammar = None
1806
+ completion = create_completion (stop = stops )
1807
+ if "<|from|> assistant" in completion ["choices" ][0 ]["text" ] or "<|from|>assistant" in completion ["choices" ][0 ]["text" ]:
1808
+ prompt += "\n <|from|>assistant\n <|recipient|>"
1809
+ else :
1810
+ break
1811
1811
1812
1812
assert "usage" in completion
1813
- assert len (function_calls ) > 0
1814
1813
assert len (function_calls ) == len (function_bodies )
1815
1814
1816
1815
tool_calls = []
@@ -1843,14 +1842,14 @@ def create_completion(stop):
1843
1842
"index" : 0 ,
1844
1843
"message" : {
1845
1844
"role" : "assistant" ,
1846
- "content" : None ,
1845
+ "content" : None if content == "" else content ,
1847
1846
"function_call" : {
1848
1847
"name" : tool_calls [0 ]["function" ]["name" ],
1849
1848
"arguments" : tool_calls [0 ]["function" ]["arguments" ],
1850
- },
1851
- "tool_calls" : tool_calls ,
1849
+ } if len ( tool_calls ) > 0 else None ,
1850
+ "tool_calls" : tool_calls if len ( tool_calls ) > 0 else None ,
1852
1851
},
1853
- "finish_reason" : "tool_calls" ,
1852
+ "finish_reason" : "tool_calls" if len ( tool_calls ) > 0 else "stop" ,
1854
1853
}
1855
1854
],
1856
1855
usage = completion ["usage" ],
0 commit comments