@@ -532,6 +532,7 @@ def format_phind(
532
532
_prompt = _format_add_colon_single (_system_message , _messages , _sep )
533
533
return ChatFormatterResponse (prompt = _prompt )
534
534
535
+
535
536
@register_chat_format ("intel" )
536
537
def format_intel (
537
538
messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -588,6 +589,7 @@ def format_mistrallite(
588
589
_prompt = _format_no_colon_single (system_message , _messages , _sep )
589
590
return ChatFormatterResponse (prompt = _prompt )
590
591
592
+
591
593
@register_chat_format ("chatml" )
592
594
def format_chatml (
593
595
messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -604,6 +606,7 @@ def format_chatml(
604
606
_prompt = _format_chatml (system_message , _messages , _sep )
605
607
return ChatFormatterResponse (prompt = _prompt , stop = _sep )
606
608
609
+
607
610
@register_chat_format ("openchat" )
608
611
def format_openchat (
609
612
messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -612,7 +615,9 @@ def format_openchat(
612
615
system_template = "{system_message}<|end_of_turn|>"
613
616
system_message = _get_system_message (messages )
614
617
system_message = system_template .format (system_message = system_message )
615
- _roles = dict (user = "GPT4 Correct User: " , assistant = "<|end_of_turn|>GPT4 Correct Assistant: " )
618
+ _roles = dict (
619
+ user = "GPT4 Correct User: " , assistant = "<|end_of_turn|>GPT4 Correct Assistant: "
620
+ )
616
621
_sep = "<|end_of_turn|>"
617
622
_messages = _map_roles (messages , _roles )
618
623
_messages .append ((_roles ["assistant" ], None ))
@@ -651,46 +656,60 @@ def functionary_chat_handler(
651
656
) -> Union [llama_types .ChatCompletion , Iterator [llama_types .ChatCompletionChunk ]]:
652
657
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
653
658
654
- def generate_type_definition (param : Dict [str , llama_types .JsonType ], indent_level : int , shared_defs ) -> str :
655
- indent = ' ' * indent_level
656
- if '$ref' in param :
659
+ def generate_type_definition (
660
+ param : Dict [str , llama_types .JsonType ], indent_level : int , shared_defs
661
+ ) -> str :
662
+ indent = " " * indent_level
663
+ if "$ref" in param :
657
664
# Reference to a shared definition
658
- ref_name = param ['$ref' ].split ('/' )[- 1 ] # Extract the type name from the reference
665
+ ref_name = param ["$ref" ].split ("/" )[
666
+ - 1
667
+ ] # Extract the type name from the reference
659
668
return ref_name
660
- elif param .get (' type' ) == ' array' :
661
- items = param .get (' items' , {})
669
+ elif param .get (" type" ) == " array" :
670
+ items = param .get (" items" , {})
662
671
item_type = generate_type_definition (items , indent_level + 1 , shared_defs )
663
672
return f"Array<{ item_type } >"
664
- elif param .get (' type' ) == ' object' :
665
- properties = param .get (' properties' , {})
673
+ elif param .get (" type" ) == " object" :
674
+ properties = param .get (" properties" , {})
666
675
nested_schema = "{\n "
667
676
for nested_param_name , nested_param in properties .items ():
668
- nested_param_type = generate_type_definition (nested_param , indent_level + 1 , shared_defs )
669
- nested_schema += f"{ indent } { nested_param_name } : { nested_param_type } ,\n "
677
+ nested_param_type = generate_type_definition (
678
+ nested_param , indent_level + 1 , shared_defs
679
+ )
680
+ nested_schema += (
681
+ f"{ indent } { nested_param_name } : { nested_param_type } ,\n "
682
+ )
670
683
nested_schema += indent + "}"
671
684
return nested_schema
672
- elif ' enum' in param :
685
+ elif " enum" in param :
673
686
# Enum type
674
- return " | " .join ([f'"{ enum_value } "' for enum_value in param [' enum' ]])
687
+ return " | " .join ([f'"{ enum_value } "' for enum_value in param [" enum" ]])
675
688
else :
676
689
# Simple type
677
- return param .get (' type' , ' any' )
690
+ return param .get (" type" , " any" )
678
691
679
692
def generate_shared_definitions (shared_defs , indent_level : int ) -> str :
680
- indent = ' ' * indent_level
693
+ indent = " " * indent_level
681
694
shared_definitions = ""
682
695
for def_name , def_properties in shared_defs .items ():
683
696
shared_definitions += f"{ indent } type { def_name } = "
684
- if def_properties .get ('type' ) == 'object' :
685
- shared_definitions += generate_type_definition (def_properties , indent_level , shared_defs )
686
- elif 'enum' in def_properties :
697
+ if def_properties .get ("type" ) == "object" :
698
+ shared_definitions += generate_type_definition (
699
+ def_properties , indent_level , shared_defs
700
+ )
701
+ elif "enum" in def_properties :
687
702
# Enum type
688
- shared_definitions += " | " .join ([f'"{ enum_value } "' for enum_value in def_properties ['enum' ]])
703
+ shared_definitions += " | " .join (
704
+ [f'"{ enum_value } "' for enum_value in def_properties ["enum" ]]
705
+ )
689
706
shared_definitions += ";\n "
690
707
return shared_definitions
691
708
692
709
def generate_schema_from_functions (functions , namespace = "functions" ) -> str :
693
- schema = "// Supported function definitions that should be called when necessary.\n "
710
+ schema = (
711
+ "// Supported function definitions that should be called when necessary.\n "
712
+ )
694
713
schema += f"namespace { namespace } {{\n \n "
695
714
696
715
# Generate shared definitions
@@ -706,10 +725,10 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
706
725
description = function .get ("description" , "" )
707
726
parameters = function .get ("parameters" , {})
708
727
required_params = parameters .get ("required" , [])
709
-
728
+
710
729
schema += f" // { description } \n "
711
730
schema += f" type { function_name } = (_: {{\n "
712
-
731
+
713
732
for param_name , param in parameters .get ("properties" , {}).items ():
714
733
param_description = param .get ("description" , "" )
715
734
param_type = generate_type_definition (param , 2 , shared_definitions )
@@ -733,13 +752,18 @@ def prepare_messages_for_inference(
733
752
role = "system" , content = generate_schema_from_functions (functions )
734
753
)
735
754
)
736
-
755
+
737
756
if tools is not None :
738
757
all_messages .append (
739
758
llama_types .ChatCompletionRequestSystemMessage (
740
- role = "system" , content = generate_schema_from_functions (
741
- [tool ["function" ] for tool in tools if tool ["type" ] == "function" ]
742
- )
759
+ role = "system" ,
760
+ content = generate_schema_from_functions (
761
+ [
762
+ tool ["function" ]
763
+ for tool in tools
764
+ if tool ["type" ] == "function"
765
+ ]
766
+ ),
743
767
)
744
768
)
745
769
@@ -790,7 +814,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
790
814
elif "function_call" in msg :
791
815
return f"assistant to={ msg ['function_call' ]['name' ]} :\n { msg ['function_call' ]['arguments' ]} </s>\n "
792
816
elif "tool_calls" in msg and len (msg ["tool_calls" ]) > 0 :
793
- for tool_call in msg ["tool_calls" ]: # NOTE: probably doesn't work with the functionary model
817
+ for tool_call in msg [
818
+ "tool_calls"
819
+ ]: # NOTE: probably doesn't work with the functionary model
794
820
return f"assistant to={ tool_call ['id' ]} :\n { tool_call ['function' ]['arguments' ]} </s>\n "
795
821
elif msg ["content" ] is None :
796
822
return "assistant"
@@ -800,12 +826,14 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
800
826
raise ValueError (f"Unsupported role: { msg ['role' ]} " )
801
827
802
828
return "" .join ([message_to_str (msg ) for msg in all_messages ])
803
-
829
+
804
830
if tools is not None :
805
831
functions = [tool ["function" ] for tool in tools if tool ["type" ] == "function" ]
806
-
832
+
807
833
if tool_choice is not None :
808
- function_call = tool_choice if isinstance (tool_choice , str ) else tool_choice ["function" ]
834
+ function_call = (
835
+ tool_choice if isinstance (tool_choice , str ) else tool_choice ["function" ]
836
+ )
809
837
810
838
prompt = prepare_messages_for_inference (messages , functions , tools )
811
839
@@ -861,19 +889,27 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
861
889
if tool ["type" ] == "function" and tool ["function" ]["name" ] == function_call :
862
890
function_body = tool ["function" ]["parameters" ]
863
891
break
864
-
892
+
865
893
if function_body is not None :
866
894
try :
867
895
with suppress_stdout_stderr (disable = llama .verbose ):
868
- grammar_text = llama_grammar .json_schema_to_gbnf (json .dumps (function_body ))
869
- grammar = llama_grammar .LlamaGrammar .from_string (llama_grammar .json_schema_to_gbnf (json .dumps (function_body )))
896
+ grammar_text = llama_grammar .json_schema_to_gbnf (
897
+ json .dumps (function_body )
898
+ )
899
+ grammar = llama_grammar .LlamaGrammar .from_string (
900
+ llama_grammar .json_schema_to_gbnf (json .dumps (function_body ))
901
+ )
870
902
print (grammar_text )
871
903
except Exception as e :
872
904
if llama .verbose :
873
- print ("Failed to parse function body as JSON schema, falling back to default grammar" )
905
+ print (
906
+ "Failed to parse function body as JSON schema, falling back to default grammar"
907
+ )
874
908
print (e )
875
909
with suppress_stdout_stderr (disable = llama .verbose ):
876
- grammar = llama_grammar .LlamaGrammar .from_string (llama_grammar .JSON_GBNF )
910
+ grammar = llama_grammar .LlamaGrammar .from_string (
911
+ llama_grammar .JSON_GBNF
912
+ )
877
913
else :
878
914
with suppress_stdout_stderr (disable = llama .verbose ):
879
915
grammar = llama_grammar .LlamaGrammar .from_string (llama_grammar .JSON_GBNF )
@@ -929,9 +965,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
929
965
"function" : {
930
966
"name" : function_call ,
931
967
"arguments" : completion ["choices" ][0 ]["text" ],
932
- }
968
+ },
933
969
}
934
- ]
970
+ ],
935
971
},
936
972
"finish_reason" : "tool_calls" ,
937
973
}
0 commit comments