@@ -480,7 +480,7 @@ def detokenize(
480
480
Returns:
481
481
The detokenized string.
482
482
"""
483
- return self .tokenizer_ .detokenize (tokens , prev_tokens )
483
+ return self .tokenizer_ .detokenize (tokens , prev_tokens = prev_tokens )
484
484
485
485
def set_cache (self , cache : Optional [BaseLlamaCache ]):
486
486
"""Set the cache.
@@ -1016,13 +1016,13 @@ def logit_bias_processor(
1016
1016
grammar = grammar ,
1017
1017
):
1018
1018
if token == self ._token_eos :
1019
- text = self .detokenize (completion_tokens )
1019
+ text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1020
1020
finish_reason = "stop"
1021
1021
break
1022
1022
1023
1023
completion_tokens .append (token )
1024
1024
1025
- all_text = self .detokenize (completion_tokens )
1025
+ all_text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1026
1026
1027
1027
# Contains multi-byte UTF8
1028
1028
for k , char in enumerate (all_text [- 3 :]):
@@ -1046,7 +1046,7 @@ def logit_bias_processor(
1046
1046
1047
1047
if stream :
1048
1048
remaining_tokens = completion_tokens [returned_tokens :]
1049
- remaining_text = self .detokenize (remaining_tokens )
1049
+ remaining_text = self .detokenize (remaining_tokens , prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] )
1050
1050
remaining_length = len (remaining_text )
1051
1051
1052
1052
# We want to avoid yielding any characters from
@@ -1068,17 +1068,17 @@ def logit_bias_processor(
1068
1068
for token in remaining_tokens :
1069
1069
if token == self .token_bos ():
1070
1070
continue
1071
- token_end_position += len (self .detokenize ([token ]))
1071
+ token_end_position += len (self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ))
1072
1072
# Check if stop sequence is in the token
1073
1073
if token_end_position > (
1074
1074
remaining_length - first_stop_position
1075
1075
):
1076
1076
break
1077
- token_str = self .detokenize ([token ]).decode (
1077
+ token_str = self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ).decode (
1078
1078
"utf-8" , errors = "ignore"
1079
1079
)
1080
1080
text_offset = len (prompt ) + len (
1081
- self .detokenize (completion_tokens [:returned_tokens ]).decode (
1081
+ self .detokenize (completion_tokens [:returned_tokens ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ).decode (
1082
1082
"utf-8" , errors = "ignore"
1083
1083
)
1084
1084
)
@@ -1100,7 +1100,7 @@ def logit_bias_processor(
1100
1100
top_logprob .update ({token_str : current_logprobs [int (token )]})
1101
1101
logprobs_or_none = {
1102
1102
"tokens" : [
1103
- self .detokenize ([token ]).decode (
1103
+ self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ).decode (
1104
1104
"utf-8" , errors = "ignore"
1105
1105
)
1106
1106
],
@@ -1116,7 +1116,7 @@ def logit_bias_processor(
1116
1116
"model" : model_name ,
1117
1117
"choices" : [
1118
1118
{
1119
- "text" : self .detokenize ([token ]).decode (
1119
+ "text" : self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ).decode (
1120
1120
"utf-8" , errors = "ignore"
1121
1121
),
1122
1122
"index" : 0 ,
@@ -1130,7 +1130,7 @@ def logit_bias_processor(
1130
1130
decode_success = False
1131
1131
for i in range (1 , len (remaining_tokens ) + 1 ):
1132
1132
try :
1133
- bs = self .detokenize (remaining_tokens [:i ])
1133
+ bs = self .detokenize (remaining_tokens [:i ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] )
1134
1134
ts = bs .decode ("utf-8" )
1135
1135
decode_success = True
1136
1136
break
@@ -1165,22 +1165,22 @@ def logit_bias_processor(
1165
1165
}
1166
1166
1167
1167
if len (completion_tokens ) >= max_tokens :
1168
- text = self .detokenize (completion_tokens )
1168
+ text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1169
1169
finish_reason = "length"
1170
1170
break
1171
1171
1172
1172
if stopping_criteria is not None and stopping_criteria (
1173
1173
self ._input_ids , self ._scores [- 1 , :]
1174
1174
):
1175
- text = self .detokenize (completion_tokens )
1175
+ text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1176
1176
finish_reason = "stop"
1177
1177
1178
1178
if self .verbose :
1179
1179
self ._ctx .print_timings ()
1180
1180
1181
1181
if stream :
1182
1182
remaining_tokens = completion_tokens [returned_tokens :]
1183
- all_text = self .detokenize (remaining_tokens )
1183
+ all_text = self .detokenize (remaining_tokens , prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] )
1184
1184
any_stop = [s for s in stop_sequences if s in all_text ]
1185
1185
if len (any_stop ) > 0 :
1186
1186
end = min (all_text .index (stop ) for stop in any_stop )
@@ -1189,7 +1189,7 @@ def logit_bias_processor(
1189
1189
1190
1190
token_end_position = 0
1191
1191
for token in remaining_tokens :
1192
- token_end_position += len (self .detokenize ([token ]))
1192
+ token_end_position += len (self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ))
1193
1193
1194
1194
logprobs_or_none : Optional [CompletionLogprobs ] = None
1195
1195
if logprobs is not None :
@@ -1199,7 +1199,7 @@ def logit_bias_processor(
1199
1199
"utf-8" , errors = "ignore"
1200
1200
)
1201
1201
text_offset = len (prompt ) + len (
1202
- self .detokenize (completion_tokens [:returned_tokens ])
1202
+ self .detokenize (completion_tokens [:returned_tokens ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] )
1203
1203
)
1204
1204
token_offset = len (prompt_tokens ) + returned_tokens - 1
1205
1205
logits = self ._scores [token_offset , :]
@@ -1313,8 +1313,8 @@ def logit_bias_processor(
1313
1313
all_tokens = completion_tokens
1314
1314
1315
1315
all_token_strs = [
1316
- self .detokenize ([token ]).decode ("utf-8" , errors = "ignore" )
1317
- for token in all_tokens
1316
+ self .detokenize ([token ], prev_tokens = all_tokens [: i ] ).decode ("utf-8" , errors = "ignore" )
1317
+ for i , token in enumerate ( all_tokens )
1318
1318
]
1319
1319
all_logprobs = Llama .logits_to_logprobs (self ._scores )[token_offset :]
1320
1320
# TODO: may be able to change this loop to use np.take_along_dim
@@ -1339,7 +1339,7 @@ def logit_bias_processor(
1339
1339
)
1340
1340
token_logprobs .append (logprobs_token [int (token )])
1341
1341
top_logprob : Optional [Dict [str , float ]] = {
1342
- self .detokenize ([i ]).decode ("utf-8" , errors = "ignore" ): logprob
1342
+ self .detokenize ([i ], prev_tokens = all_tokens [: idx ] ).decode ("utf-8" , errors = "ignore" ): logprob
1343
1343
for logprob , i in sorted_logprobs [:logprobs ]
1344
1344
}
1345
1345
top_logprob .update ({token_str : logprobs_token [int (token )]})
@@ -1594,6 +1594,8 @@ def create_chat_completion(
1594
1594
logits_processor : Optional [LogitsProcessorList ] = None ,
1595
1595
grammar : Optional [LlamaGrammar ] = None ,
1596
1596
logit_bias : Optional [Dict [str , float ]] = None ,
1597
+ logprobs : Optional [bool ] = None ,
1598
+ top_logprobs : Optional [int ] = None ,
1597
1599
) -> Union [
1598
1600
CreateChatCompletionResponse , Iterator [CreateChatCompletionStreamResponse ]
1599
1601
]:
0 commit comments