convert functionary-v1 chat handler to use hf autotokenizer · notwa/llama-cpp-python@43b4529 · GitHub
Commit 43b4529

convert functionary-v1 chat handler to use hf autotokenizer
1 parent bf9e824 commit 43b4529

File tree

1 file changed: +28 -58 lines changed

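The gist of the change: instead of building the functionary prompt with the hand-written `message_to_str` helper, the handler now loads the Hugging Face `AutoTokenizer` stored alongside the GGUF model and lets its chat template render the conversation, stopping on the model's own special tokens (`<|END_OF_ASSISTANT|>`, `<|END_OF_FUNCTION_CALL|>`) rather than `</s>` or a bare newline. A minimal sketch of that idea, assuming the functionary tokenizer files (`tokenizer_config.json` with a `chat_template`) sit in the same directory as the model file; the path below is hypothetical:

```python
# Minimal sketch of the approach this commit takes, not the handler itself:
# load the Hugging Face tokenizer that sits next to the GGUF model and let its
# chat template render the prompt, instead of concatenating strings by hand.
# The model path is hypothetical; the tokenizer files are assumed to live in
# the same directory as the .gguf file.
import os

from transformers import AutoTokenizer

model_path = "./models/functionary-7b-v1/functionary-7b-v1.q4_0.gguf"  # hypothetical
tokenizer = AutoTokenizer.from_pretrained(os.path.dirname(model_path))

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
]

# tokenize=False returns the rendered prompt string; the handler then appends
# "assistant:\n" so generation continues as the assistant's turn.
prompt = tokenizer.apply_chat_template(messages, tokenize=False) + "assistant:\n"
print(prompt)
```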

llama_cpp/llama_chat_format.py

Lines changed: 28 additions & 58 deletions
@@ -978,8 +978,8 @@ def format_saiga(
 
 # Tricky chat formats that require custom chat handlers
 
-@register_chat_completion_handler("functionary")
-def functionary_chat_handler(
+@register_chat_completion_handler("functionary-v1")
+def functionary_v1_chat_handler(
     llama: llama.Llama,
     messages: List[llama_types.ChatCompletionRequestMessage],
     functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
@@ -1008,6 +1008,12 @@ def functionary_chat_handler(
     **kwargs,  # type: ignore
 ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
     SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
+    END_SYSTEM_TOKEN = "<|END_OF_SYSTEM|>"
+    END_USER_TOKEN = "<|END_OF_USER|>"
+    END_ASSISTANT_TOKEN = "<|END_OF_ASSISTANT|>"
+    END_FUNCTION_RESULT_TOKEN = "<|END_OF_FUNCTION_RESULT|>"
+    START_FUNCTION_CALL_TOKEN = "<|START_OF_FUNCTION_CALL|>"
+    END_FUNCTION_CALL_TOKEN = "<|END_OF_FUNCTION_CALL|>"
 
     def generate_type_definition(
         param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
@@ -1079,22 +1085,23 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
             parameters = function.get("parameters", {})
             required_params = parameters.get("required", [])
 
-            schema += f" // {description}\n"
-            schema += f" type {function_name} = (_: {{\n"
+            schema += f"// {description}\n"
+            schema += f"type {function_name} = (_: {{\n"
 
             for param_name, param in parameters.get("properties", {}).items():
                 param_description = param.get("description", "")
                 param_type = generate_type_definition(param, 2, shared_definitions)
                 optional_indicator = "" if param_name in required_params else "?"
-                schema += f" // {param_description}\n"
-                schema += f" {param_name}{optional_indicator}: {param_type},\n"
-            schema += " }) => any;\n\n"
+                schema += f"// {param_description}\n"
+                schema += f"{param_name}{optional_indicator}: {param_type},\n"
+            schema += "}) => any;\n\n"
 
-        schema += "}} // namespace {}\n".format(namespace)
+        schema += "}} // namespace {}".format(namespace)
         return schema
 
     def prepare_messages_for_inference(
         messages: List[llama_types.ChatCompletionRequestMessage],
+        tokenizer: AutoTokenizer,
         functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
         tools: Optional[List[llama_types.ChatCompletionTool]] = None,
     ):
@@ -1105,8 +1112,7 @@ def prepare_messages_for_inference(
                     role="system", content=generate_schema_from_functions(functions)
                 )
             )
-
-        if tools is not None:
+        elif tools is not None:
             all_messages.append(
                 llama_types.ChatCompletionRequestSystemMessage(
                     role="system",
@@ -1136,49 +1142,8 @@ def prepare_messages_for_inference(
                     "name"
                 ] = f"functions.{message['function_call']['name']}"
             all_messages.append(message)
-
-        all_messages.append(
-            llama_types.ChatCompletionRequestAssistantMessage(
-                role="assistant", content=None
-            )
-        )
-
-        def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
-            if msg["role"] == "system":
-                return f"system:\n{msg['content']}\n"
-
-            elif msg["role"] == "function" and "name" in msg:
-                return f"function name={msg['name']}:\n{msg['content']}\n"
-            elif msg["role"] == "function" and "function_call" in msg:
-                return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
-            elif msg["role"] == "tool":
-                if msg["content"] is not None:
-                    return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
-                else:
-                    return f"function name={msg['tool_call_id']}\n"
-            elif msg["role"] == "user":
-                if msg["content"] is None:
-                    return "user:\n</s></s>\n"
-                else:
-                    return f"user:\n</s>{msg['content']}</s>\n"
-            elif msg["role"] == "assistant":
-                if msg["content"] is not None and "function_call" in msg:
-                    return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
-                elif "function_call" in msg:
-                    return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
-                elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
-                    for tool_call in msg[
-                        "tool_calls"
-                    ]:  # NOTE: probably doesn't work with the functionary model
-                        return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
-                elif msg["content"] is None:
-                    return "assistant"
-                else:
-                    return f"assistant:\n{msg['content']}\n"
-            else:
-                raise ValueError(f"Unsupported role: {msg['role']}")
-
-        return "".join([message_to_str(msg) for msg in all_messages])
+
+        return tokenizer.apply_chat_template(all_messages, tokenize=False) + "assistant:\n"
 
     if tools is not None:
         functions = [tool["function"] for tool in tools if tool["type"] == "function"]
@@ -1187,19 +1152,24 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         function_call = (
             tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
         )
+
+    from transformers import AutoTokenizer
+
+    tokenizer_path = os.path.dirname(llama.model_path)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
 
-    prompt = prepare_messages_for_inference(messages, functions, tools)
+    prompt = prepare_messages_for_inference(messages, tokenizer, functions, tools)
 
     if function_call is None and (functions is None or len(functions) == 0):
         completion_or_completion_chunks = llama.create_completion(
-            prompt=prompt + ":\n",
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
            typical_p=typical_p,
             stream=stream,
-            stop=["user:", "</s>"],
+            stop=["user:", END_ASSISTANT_TOKEN],
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
@@ -1217,9 +1187,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         if function_call is None or (
            isinstance(function_call, str) and function_call == "auto"
         ):
-            stop = "\n"
+            stop = [END_ASSISTANT_TOKEN, END_FUNCTION_CALL_TOKEN]
             completion: llama_types.Completion = llama.create_completion(
-                prompt=prompt, stop=stop, stream=False
+                prompt=prompt + ":\n", stop=stop, stream=False, max_tokens=max_tokens
             )  # type: ignore
             completion_text = completion["choices"][0]["text"]
             # strip " to=functions." and ending ":"
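Since the handler is now registered under `functionary-v1` rather than `functionary`, it would be selected by that chat format name. A hedged usage sketch, not taken from this commit; the model path and function definition are placeholders, and the functionary tokenizer files are assumed to sit next to the `.gguf` so `AutoTokenizer` can load them from the same directory:

```python
# Hypothetical caller-side usage of the renamed "functionary-v1" handler.
from llama_cpp import Llama

llm = Llama(
    model_path="./models/functionary-7b-v1/functionary-7b-v1.q4_0.gguf",  # hypothetical
    chat_format="functionary-v1",
    n_ctx=4096,
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    functions=[
        {
            "name": "get_weather",  # hypothetical function
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "City name"},
                },
                "required": ["city"],
            },
        }
    ],
    function_call="auto",
)
print(response["choices"][0]["message"])
```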

0 commit comments