@@ -780,6 +780,79 @@ def logit_bias_processor(
 
         if seed is not None:
             self._ctx.set_rng_seed(seed)
+
+        def _completion_stream_response(text: str, logprobs_or_none: Optional[CompletionLogprobs] = None, finish_reason: Optional[Literal["stop", "length"]] = None) -> CreateCompletionStreamResponse:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": 0,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+
+        def _completion_response(text: str, finish_reason: Literal["stop", "length"], logprobs_or_none: Optional[CompletionLogprobs] = None) -> CreateCompletionResponse:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": 0,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": len(prompt_tokens),
+                    "completion_tokens": len(completion_tokens),
+                    "total_tokens": len(prompt_tokens) + len(completion_tokens),
+                },
+            }
+
+        def _logprobs_or_none(all_tokens: List[int], all_token_strs: List[str], all_logprobs: List[List[float]], text_offset: int) -> CompletionLogprobs:
+            tokens: List[str] = []
+            text_offsets: List[int] = []
+            token_logprobs: List[Optional[float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            for token, token_str, token_logprob in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                if token == self.token_bos():
+                    continue
+
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(token_logprob, range(len(token_logprob))), reverse=True
+                    )
+                )
+                top_logprob = {
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: token_logprob[int(token)]})
+
+                tokens.append(token_str)
+                token_logprobs.append(token_logprob[int(token)])
+                top_logprobs.append(top_logprob)
+
+            return {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }
 
         finish_reason = "length"
         multibyte_fix = 0
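
Note that the three helpers above are nested inside `_create_completion`, so `completion_id`, `created`, `model_name`, `prompt_tokens`, and `completion_tokens` are captured from the enclosing scope rather than passed at every call site. A minimal sketch of that closure pattern, using hypothetical names and only standard Python:

    def make_builder(completion_id: str, created: int, model_name: str):
        # The nested function captures the three arguments of this call.
        def build(text: str) -> dict:
            return {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "choices": [
                    {"text": text, "index": 0, "logprobs": None, "finish_reason": None}
                ],
            }
        return build

    build = make_builder("cmpl-123", 1700000000, "example-model")
    assert build("hi")["choices"][0]["text"] == "hi"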
@@ -868,10 +941,10 @@ def logit_bias_processor(
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
                         logits = self._scores[token_offset - 1, :].tolist()
-                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        token_logprob = Llama.logits_to_logprobs(logits)
                         sorted_logprobs = list(
                             sorted(
-                                zip(current_logprobs, range(len(current_logprobs))),
+                                zip(token_logprob, range(len(token_logprob))),
                                 reverse=True,
                             )
                         )
@@ -881,34 +954,22 @@ def logit_bias_processor(
                             ): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
                         }
-                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        top_logprob.update({token_str: token_logprob[int(token)]})
                         logprobs_or_none = {
                             "tokens": [
                                 self.detokenize([token]).decode(
                                     "utf-8", errors="ignore"
                                 )
                             ],
                             "text_offset": [text_offset],
-                            "token_logprobs": [current_logprobs[int(token)]],
+                            "token_logprobs": [token_logprob[int(token)]],
                             "top_logprobs": [top_logprob],
                         }
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize([token]).decode(
-                                        "utf-8", errors="ignore"
-                                    ),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield _completion_stream_response(
+                            self.detokenize([token]).decode("utf-8", errors="ignore"),
+                            logprobs_or_none,
+                        )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
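
In this per-token streaming path, the chunk's `logprobs` field holds single-element lists describing only the token being emitted; concatenated across chunks they have the same shape as the non-streaming `CompletionLogprobs` payload. An illustrative value (all numbers invented):

    chunk_logprobs = {
        "tokens": [" Paris"],              # the one token in this chunk
        "text_offset": [17],               # character offset of the token
        "token_logprobs": [-0.12],         # logprob of the sampled token
        "top_logprobs": [                  # top alternatives at this step
            {" Paris": -0.12, " London": -2.51}
        ],
    }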
@@ -933,20 +994,7 @@ def logit_bias_processor(
                         remaining_tokens = remaining_tokens[i:]
                         returned_tokens += i
 
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield _completion_stream_response(text=ts)
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
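
The `else:` branch this hunk simplifies exists because a token boundary can fall inside a multibyte UTF-8 character, so the generator only yields once the shortest prefix of the remaining tokens decodes cleanly. A standalone sketch of that idea under hypothetical names (`detok` stands in for `self.detokenize`):

    def drain_decodable(remaining: list, detok):
        # Try progressively longer token prefixes until the bytes form valid UTF-8.
        for i in range(1, len(remaining) + 1):
            try:
                return detok(remaining[:i]).decode("utf-8"), remaining[i:]
            except UnicodeDecodeError:
                continue
        return None, remaining  # incomplete sequence; wait for more tokens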
@@ -986,25 +1034,22 @@ def logit_bias_processor(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self._scores[token_offset, :].tolist()
-                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    token_logprob = Llama.logits_to_logprobs(self._scores[token_offset, :].tolist())
                     sorted_logprobs = list(
                         sorted(
-                            zip(current_logprobs, range(len(current_logprobs))),
+                            zip(token_logprob, range(len(token_logprob))),
                             reverse=True,
                         )
                     )
                     top_logprob = {
                         self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                         for logprob, i in sorted_logprobs[:logprobs]
                     }
-                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    top_logprob.update({token_str: token_logprob[int(token)]})
                     logprobs_or_none = {
-                        "tokens": [
-                            self.detokenize([token]).decode("utf-8", errors="ignore")
-                        ],
+                        "tokens": [token_str],
                         "text_offset": [text_offset],
-                        "token_logprobs": [current_logprobs[int(token)]],
+                        "token_logprobs": [token_logprob[int(token)]],
                         "top_logprobs": [top_logprob],
                     }
 
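
`Llama.logits_to_logprobs`, used on both sides of this hunk, computes a log-softmax over the vocabulary logits. For reference, a numerically stable pure-Python equivalent (the function name is ours, not the library's):

    import math
    from typing import List

    def logits_to_logprobs_ref(logits: List[float]) -> List[float]:
        # Subtract the max before exponentiating to avoid overflow.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        return [x - log_z for x in logits]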
@@ -1013,54 +1058,17 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield _completion_stream_response(
+                        text=last_text[: len(last_text) - (token_end_position - end)].decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                yield _completion_stream_response(
+                    text=self.detokenize([token]).decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                )
+            yield _completion_stream_response(
+                text=self.detokenize(completion_tokens[returned_tokens:]).decode("utf-8", errors="ignore"), finish_reason=finish_reason
+            )
 
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1076,85 +1084,36 @@ def logit_bias_processor(
         text_str = text.decode("utf-8", errors="ignore")
 
         if echo:
+            assert isinstance(prompt, str)
             text_str = prompt + text_str
 
         if suffix is not None:
             text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            # Remove the leading BOS token when echoing the prompt
+            all_tokens = prompt_tokens[1:] + completion_tokens if echo else completion_tokens
             text_offset = 0 if echo else len(prompt)
             token_offset = 0 if echo else len(prompt_tokens[1:])
-            text_offsets: List[int] = []
-            token_logprobs: List[Optional[float]] = []
-            tokens: List[str] = []
-            top_logprobs: List[Optional[Dict[str, float]]] = []
-
-            if echo:
-                # Remove leading BOS token
-                all_tokens = prompt_tokens[1:] + completion_tokens
-            else:
-                all_tokens = completion_tokens
-
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
             all_logprobs = [
                 Llama.logits_to_logprobs(row.tolist()) for row in self._scores
            ][token_offset:]
-            for token, token_str, logprobs_token in zip(
-                all_tokens, all_token_strs, all_logprobs
-            ):
-                if token == self.token_bos():
-                    continue
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(logprobs_token[int(token)])
-                top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
+            logprobs_or_none = _logprobs_or_none(
+                all_tokens, all_token_strs, all_logprobs, text_offset
+            )
             # Weird idiosyncrasy of the OpenAI API where
             # token_logprobs and top_logprobs are null for
             # the first token.
             if echo and len(all_tokens) > 0:
-                token_logprobs[0] = None
-                top_logprobs[0] = None
-            logprobs_or_none = {
-                "tokens": tokens,
-                "text_offset": text_offsets,
-                "token_logprobs": token_logprobs,
-                "top_logprobs": top_logprobs,
-            }
+                logprobs_or_none["token_logprobs"][0] = None
+                logprobs_or_none["top_logprobs"][0] = None
 
-        yield {
-            "id": completion_id,
-            "object": "text_completion",
-            "created": created,
-            "model": model_name,
-            "choices": [
-                {
-                    "text": text_str,
-                    "index": 0,
-                    "logprobs": logprobs_or_none,
-                    "finish_reason": finish_reason,
-                }
-            ],
-            "usage": {
-                "prompt_tokens": len(prompt_tokens),
-                "completion_tokens": len(completion_tokens),
-                "total_tokens": len(prompt_tokens) + len(completion_tokens),
-            },
-        }
+        yield _completion_response(text_str, finish_reason, logprobs_or_none)
 
     def create_completion(
         self,
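
For context, the generator refactored above is what backs streaming `create_completion` calls in llama-cpp-python; a consumer sees one response dict per chunk, with `finish_reason` set only on the final chunk. A rough usage sketch (model path and prompt are placeholders):

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # placeholder path
    stream = llm.create_completion("Q: Name the planets. A:", max_tokens=32, stream=True)
    for chunk in stream:
        choice = chunk["choices"][0]
        print(choice["text"], end="", flush=True)
        if choice["finish_reason"] is not None:  # "stop" or "length" on the last chunk
            print(f"\n[{choice['finish_reason']}]")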