@@ -2316,6 +2316,8 @@ def base_function_calling(
2316
2316
grammar : Optional [llama .LlamaGrammar ] = None ,
2317
2317
logprobs : Optional [bool ] = None ,
2318
2318
top_logprobs : Optional [int ] = None ,
2319
+ role_prefix : Optional [str ] = "" ,
2320
+ role_suffix : Optional [str ] = "" ,
2319
2321
** kwargs , # type: ignore
2320
2322
) -> Union [
2321
2323
llama_types .CreateChatCompletionResponse ,
@@ -2377,7 +2379,7 @@ def base_function_calling(
2377
2379
min_p = min_p ,
2378
2380
typical_p = typical_p ,
2379
2381
stream = stream ,
2380
- stop = stop ,
2382
+ stop = stop + [ "</done>" ] ,
2381
2383
max_tokens = max_tokens ,
2382
2384
presence_penalty = presence_penalty ,
2383
2385
frequency_penalty = frequency_penalty ,
@@ -2507,7 +2509,7 @@ def base_function_calling(
2507
2509
min_p = min_p ,
2508
2510
typical_p = typical_p ,
2509
2511
stream = stream ,
2510
- stop = ["</s >" ],
2512
+ stop = stop + ["</done>"],
2511
2513
logprobs = top_logprobs if logprobs else None ,
2512
2514
max_tokens = None ,
2513
2515
presence_penalty = presence_penalty ,
@@ -2532,7 +2534,7 @@ def base_function_calling(
2532
2534
completions : List [llama_types .CreateCompletionResponse ] = []
2533
2535
completions_tool_name : List [str ] = []
2534
2536
while tool is not None :
2535
- prompt += f"functions.{ tool_name } :\n "
2537
+ prompt += f"{ role_prefix } functions.{ tool_name } :{ role_suffix } "
2536
2538
try :
2537
2539
grammar = llama_grammar .LlamaGrammar .from_json_schema (
2538
2540
json .dumps (tool ["function" ]["parameters" ]), verbose = llama .verbose
@@ -2570,7 +2572,8 @@ def base_function_calling(
2570
2572
completion_or_chunks = cast (llama_types .CreateCompletionResponse , completion_or_chunks )
2571
2573
completions .append (completion_or_chunks )
2572
2574
completions_tool_name .append (tool_name )
2573
- prompt += completion_or_chunks ["choices" ][0 ]["text" ]
2575
+ out = completion_or_chunks ["choices" ][0 ]["text" ]
2576
+ prompt += f"{ role_prefix } { out } { role_suffix } "
2574
2577
print (prompt )
2575
2578
prompt += "\n "
2576
2579
response = llama .create_completion (
@@ -2858,3 +2861,84 @@ def vicuna_function_calling(
2858
2861
)
2859
2862
return base_function_calling (end_token = "</s>" ,
2860
2863
** locals ())
2864
+
2865
@register_chat_completion_handler("llama3-function-calling")
def llama3_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs,  # type: ignore
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    """Chat-completion handler for Llama 3 models with function calling.

    Builds a Llama-3-formatted prompt template (``<|begin_of_text|>``,
    ``<|start_header_id|>role<|end_header_id|>\\n\\n…<|eot_id|>``) that teaches
    the model to answer either with a plain ``message:`` or with one or more
    ``functions.<name>:`` calls terminated by ``</done>``, then delegates all
    sampling and tool-dispatch work to :func:`base_function_calling`.

    The signature mirrors the other ``*_function_calling`` handlers in this
    file; all locals (including ``function_calling_template``) are forwarded
    to ``base_function_calling`` via ``**locals()``.
    """
    # NOTE: the Llama 3 chat format requires a blank line ("\n\n") between
    # "<|end_header_id|>" and the turn content; every header below emits it
    # (the original template omitted it for assistant/user turns, producing
    # malformed prompts).
    function_calling_template = (
        "<|begin_of_text|>"
        # System turn: only rendered when tools are in play. It concatenates
        # any caller-provided system messages, then appends the tool schemas
        # and the calling protocol the model must follow.
        "{% if tool_calls %}"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        "You have access to the following functions to help you respond to users messages: \n"
        "{% for tool in tools %}"
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
        "\nYou can respond to users messages with either a single message or one or more function calls. Never both. Prioritize function calls over messages."
        "\nWhen we have a function response, bring it to the user."
        "\nTo respond with a message begin the message with 'message:'"
        '\nExample sending message: message: "Hello, how can I help you?"'
        "\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        # "</done>" is also appended to the stop list inside
        # base_function_calling, so generation halts at this marker.
        "\nWhen you are done with the function calls, end the message with </done>."
        '\nStart your output with either message: or functions. <|eot_id|>\n'
        "{% endif %}"
        # Conversation turns.
        "{% for message in messages %}"
        "{% if message.role == 'tool' %}"
        # Tool results are replayed to the model as assistant turns.
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Function response: {{ message.content | default('No response available') }}"
        "<|eot_id|>\n"
        "{% elif message.role == 'assistant' and message.function_call is defined %}"
        "<|start_header_id|>{{ message.role }}<|end_header_id|>\n\n"
        "Function called: {{ message.function_call.name | default('No name') }}\n"
        "Function argument: {{ message.function_call.arguments | default('No arguments') }}"
        "<|eot_id|>\n"
        "{% else %}"
        "<|start_header_id|>{{ message.role }}<|end_header_id|>\n\n"
        "{{ message.content }}"
        "<|eot_id|>\n"
        "{% endif %}"
        "{% endfor %}"
    )
    # role_prefix/role_suffix wrap each synthetic assistant continuation that
    # base_function_calling appends while chaining multiple tool calls.
    return base_function_calling(
        end_token="<|eot_id|>",
        role_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
        role_suffix="<|eot_id|>",
        **locals(),
    )
0 commit comments