@@ -978,8 +978,8 @@ def format_saiga(
 
 # Tricky chat formats that require custom chat handlers
 
-@register_chat_completion_handler("functionary")
-def functionary_chat_handler(
+@register_chat_completion_handler("functionary-v1")
+def functionary_v1_chat_handler(
     llama: llama.Llama,
     messages: List[llama_types.ChatCompletionRequestMessage],
     functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
@@ -1008,6 +1008,12 @@ def functionary_chat_handler(
     **kwargs,  # type: ignore
 ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
     SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
+    END_SYSTEM_TOKEN = "<|END_OF_SYSTEM|>"
+    END_USER_TOKEN = "<|END_OF_USER|>"
+    END_ASSISTANT_TOKEN = "<|END_OF_ASSISTANT|>"
+    END_FUNCTION_RESULT_TOKEN = "<|END_OF_FUNCTION_RESULT|>"
+    START_FUNCTION_CALL_TOKEN = "<|START_OF_FUNCTION_CALL|>"
+    END_FUNCTION_CALL_TOKEN = "<|END_OF_FUNCTION_CALL|>"
 
     def generate_type_definition(
         param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
@@ -1079,22 +1085,23 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
             parameters = function.get("parameters", {})
             required_params = parameters.get("required", [])
 
-            schema += f"  // {description}\n"
-            schema += f"  type {function_name} = (_: {{\n"
+            schema += f"// {description}\n"
+            schema += f"type {function_name} = (_: {{\n"
 
             for param_name, param in parameters.get("properties", {}).items():
                 param_description = param.get("description", "")
                 param_type = generate_type_definition(param, 2, shared_definitions)
                 optional_indicator = "" if param_name in required_params else "?"
-                schema += f"  // {param_description}\n"
-                schema += f"  {param_name}{optional_indicator}: {param_type},\n"
-            schema += "  }) => any;\n\n"
+                schema += f"// {param_description}\n"
+                schema += f"{param_name}{optional_indicator}: {param_type},\n"
+            schema += "}) => any;\n\n"
 
-        schema += "}} // namespace {}\n".format(namespace)
+        schema += "}} // namespace {}".format(namespace)
         return schema
 
     def prepare_messages_for_inference(
         messages: List[llama_types.ChatCompletionRequestMessage],
+        tokenizer: AutoTokenizer,
         functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
         tools: Optional[List[llama_types.ChatCompletionTool]] = None,
     ):
@@ -1105,8 +1112,7 @@ def prepare_messages_for_inference(
                     role="system", content=generate_schema_from_functions(functions)
                 )
             )
-
-        if tools is not None:
+        elif tools is not None:
             all_messages.append(
                 llama_types.ChatCompletionRequestSystemMessage(
                     role="system",
@@ -1136,49 +1142,8 @@ def prepare_messages_for_inference(
                     "name"
                 ] = f"functions.{message['function_call']['name']}"
             all_messages.append(message)
-
-        all_messages.append(
-            llama_types.ChatCompletionRequestAssistantMessage(
-                role="assistant", content=None
-            )
-        )
-
-        def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
-            if msg["role"] == "system":
-                return f"system:\n{msg['content']}\n"
-
-            elif msg["role"] == "function" and "name" in msg:
-                return f"function name={msg['name']}:\n{msg['content']}\n"
-            elif msg["role"] == "function" and "function_call" in msg:
-                return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
-            elif msg["role"] == "tool":
-                if msg["content"] is not None:
-                    return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
-                else:
-                    return f"function name={msg['tool_call_id']}\n"
-            elif msg["role"] == "user":
-                if msg["content"] is None:
-                    return "user:\n</s></s>\n"
-                else:
-                    return f"user:\n</s>{msg['content']}</s>\n"
-            elif msg["role"] == "assistant":
-                if msg["content"] is not None and "function_call" in msg:
-                    return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
-                elif "function_call" in msg:
-                    return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
-                elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
-                    for tool_call in msg[
-                        "tool_calls"
-                    ]:  # NOTE: probably doesn't work with the functionary model
-                        return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
-                elif msg["content"] is None:
-                    return "assistant"
-                else:
-                    return f"assistant:\n{msg['content']}\n"
-            else:
-                raise ValueError(f"Unsupported role: {msg['role']}")
-
-        return "".join([message_to_str(msg) for msg in all_messages])
+
+        return tokenizer.apply_chat_template(all_messages, tokenize=False) + "assistant:\n"
 
     if tools is not None:
         functions = [tool["function"] for tool in tools if tool["type"] == "function"]
@@ -1187,19 +1152,24 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         function_call = (
             tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
         )
+
+    from transformers import AutoTokenizer
+
+    tokenizer_path = os.path.dirname(llama.model_path)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
 
-    prompt = prepare_messages_for_inference(messages, functions, tools)
+    prompt = prepare_messages_for_inference(messages, tokenizer, functions, tools)
 
     if function_call is None and (functions is None or len(functions) == 0):
         completion_or_completion_chunks = llama.create_completion(
-            prompt=prompt + ":\n",
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
             stream=stream,
-            stop=["user:", "</s>"],
+            stop=["user:", END_ASSISTANT_TOKEN],
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
@@ -1217,9 +1187,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
     if function_call is None or (
         isinstance(function_call, str) and function_call == "auto"
     ):
-        stop = "\n"
+        stop = [END_ASSISTANT_TOKEN, END_FUNCTION_CALL_TOKEN]
         completion: llama_types.Completion = llama.create_completion(
-            prompt=prompt, stop=stop, stream=False
+            prompt=prompt + ":\n", stop=stop, stream=False, max_tokens=max_tokens
        )  # type: ignore
         completion_text = completion["choices"][0]["text"]
         # strip " to=functions." and ending ":"
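
For context on the main change above: prepare_messages_for_inference no longer hand-builds the prompt string with </s> markers; it renders the conversation through the model's Hugging Face chat template and appends "assistant:\n" as the generation cue, stopping afterwards on the new special tokens (e.g. END_ASSISTANT_TOKEN, i.e. <|END_OF_ASSISTANT|>) instead of "</s>". The sketch below is a minimal illustration of that flow, assuming transformers is installed and that tokenizer files (including a chat_template) sit next to the GGUF file; the model path and messages are hypothetical, not taken from the PR.

import os

from transformers import AutoTokenizer

# Hypothetical location of a functionary GGUF file; the handler derives the
# tokenizer directory from llama.model_path in the same way.
model_path = "/models/functionary-7b-v1/functionary-7b-v1.q4_0.gguf"
tokenizer = AutoTokenizer.from_pretrained(os.path.dirname(model_path))

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
]

# Render the prompt the way the new handler does: chat-template output plus an
# "assistant:\n" suffix that cues the model to start its reply.
prompt = tokenizer.apply_chat_template(messages, tokenize=False) + "assistant:\n"
print(prompt)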