8000 llama3 · themrzmaster/llama-cpp-python@4daadb7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4daadb7

Browse files
committed
llama3
1 parent 0080f68 commit 4daadb7

File tree

1 file changed

+88
-4
lines changed

1 file changed

+88
-4
lines changed

llama_cpp/llama_chat_format.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,6 +2316,8 @@ def base_function_calling(
23162316
grammar: Optional[llama.LlamaGrammar] = None,
23172317
logprobs: Optional[bool] = None,
23182318
top_logprobs: Optional[int] = None,
2319+
role_prefix: Optional[str] = "",
2320+
role_suffix: Optional[str] = "",
23192321
**kwargs, # type: ignore
23202322
) -> Union[
23212323
llama_types.CreateChatCompletionResponse,
@@ -2377,7 +2379,7 @@ def base_function_calling(
23772379
min_p=min_p,
23782380
typical_p=typical_p,
23792381
stream=stream,
2380-
stop=stop,
2382+
stop=stop + ["</done>"],
23812383
max_tokens=max_tokens,
23822384
presence_penalty=presence_penalty,
23832385
frequency_penalty=frequency_penalty,
@@ -2507,7 +2509,7 @@ def base_function_calling(
25072509
min_p=min_p,
25082510
typical_p=typical_p,
25092511
stream=stream,
2510-
stop=["</s>"],
2512+
stop=stop+ ["</done>"],
25112513
logprobs=top_logprobs if logprobs else None,
25122514
max_tokens=None,
25132515
presence_penalty=presence_penalty,
@@ -2532,7 +2534,7 @@ def base_function_calling(
25322534
completions: List[llama_types.CreateCompletionResponse] = []
25332535
completions_tool_name: List[str] = []
25342536
while tool is not None:
2535-
prompt += f"functions.{tool_name}:\n"
2537+
prompt += f"{role_prefix}functions.{tool_name}:{role_suffix}"
25362538
try:
25372539
grammar = llama_grammar.LlamaGrammar.from_json_schema(
25382540
json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
@@ -2570,7 +2572,8 @@ def base_function_calling(
25702572
completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
25712573
completions.append(completion_or_chunks)
25722574
completions_tool_name.append(tool_name)
2573-
prompt += completion_or_chunks["choices"][0]["text"]
2575+
out = completion_or_chunks["choices"][0]["text"]
2576+
prompt += f"{role_prefix}{out}{role_suffix}"
25742577
print(prompt)
25752578
prompt += "\n"
25762579
response = llama.create_completion(
@@ -2858,3 +2861,84 @@ def vicuna_function_calling(
28582861
)
28592862
return base_function_calling(end_token="</s>",
28602863
**locals())
2864+
2865+
@register_chat_completion_handler("llama3-function-calling")
def llama3_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs,  # type: ignore
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    """Chat handler implementing function calling for Llama 3 models.

    Builds a Llama 3 prompt template (``<|start_header_id|>`` /
    ``<|eot_id|>`` headers) that instructs the model to answer either with
    a plain ``message:`` reply or with one or more ``functions.<name>:``
    calls terminated by ``</done>``, then delegates all generation work to
    ``base_function_calling``.

    Parameters mirror ``llama.create_chat_completion``; they are forwarded
    verbatim via ``**locals()``.  Returns either a complete chat-completion
    response or a stream of chunks, depending on ``stream``.
    """
    # Normalize `stop` before forwarding: base_function_calling concatenates
    # `stop + ["</done>"]`, which would raise TypeError for a bare string,
    # and sharing a mutable default list across calls is unsafe.  Only the
    # existing name is rebound here — no new locals may be introduced,
    # because `**locals()` below forwards every local as a keyword argument.
    if stop is None:
        stop = []
    elif isinstance(stop, str):
        stop = [stop]
    else:
        stop = list(stop)
    function_calling_template = (
        "<|begin_of_text|>"
        "{% if tool_calls %}"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        "You have access to the following functions to help you respond to users messages: \n"
        "{% for tool in tools %}"
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
        "\nYou can respond to users messages with either a single message or one or more function calls. Never both. Prioritize function calls over messages."
        "\n When we have a function response, bring it to the user."
        "\nTo respond with a message begin the message with 'message:'"
        '\n Example sending message: message: "Hello, how can I help you?"'
        "\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nWhen you are done with the function calls, end the message with </done>."
        '\nStart your output with either message: or functions. <|eot_id|>\n'
        "{% endif %}"
        "{% for message in messages %}"
        "{% if message.role == 'tool' %}"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "Function response: {{ message.content | default('No response available') }}"
        "<|eot_id|>\n"
        "{% elif message.role == 'assistant' and message.function_call is defined%}"
        "<|start_header_id|>{{ message.role }}<|end_header_id|>"
        "Function called: {{ message.function_call.name | default('No name') }}\n"
        "Function argument: {{ message.function_call.arguments | default('No arguments') }}"
        "<|eot_id|>\n"
        "{% else %}"
        "<|start_header_id|>{{ message.role }}<|end_header_id|>"
        "{{ message.content }}"
        "<|eot_id|>\n"
        "{% endif %}"
        "{% endfor %}"
    )
    # Forward everything (including the template above) to the shared
    # implementation; Llama 3 uses <|eot_id|> as its end-of-turn token and
    # assistant-role headers as the per-turn prefix/suffix.
    return base_function_calling(
        end_token="<|eot_id|>",
        role_prefix="<|start_header_id|>assistant<|end_header_id|>",
        role_suffix="<|eot_id|>",
        **locals(),
    )

0 commit comments

Comments
 (0)
0