notwa/llama-cpp-python · commit e16f06e

fix: revert _create_completions.

1 parent dfc1b17 commit e16f06e

File tree

1 file changed
+12 -8 lines changed

llama_cpp/llama.py

Lines changed: 12 additions & 8 deletions
@@ -948,8 +948,7 @@ def logit_bias_processor(
 
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                prev_tokens = completion_tokens[:returned_tokens]
-                remaining_text = self.detokenize(completion_tokens, prev_tokens)
+                remaining_text = self.detokenize(remaining_tokens)
                 remaining_length = len(remaining_text)
 
                 # We want to avoid yielding any characters from
@@ -971,13 +970,13 @@ def logit_bias_processor(
                     for token in remaining_tokens:
                         if token == self.token_bos():
                             continue
-                        token_end_position += len(remaining_text)
+                        token_end_position += len(self.detokenize([token]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
                             remaining_length - first_stop_position
                         ):
                             break
-                        token_str = remaining_text.decode(
+                        token_str = self.detokenize([token]).decode(
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
@@ -1002,7 +1001,11 @@ def logit_bias_processor(
                             }
                             top_logprob.update({token_str: current_logprobs[int(token)]})
                             logprobs_or_none = {
-                                "tokens": [token_str],
+                                "tokens": [
+                                    self.detokenize([token]).decode(
+                                        "utf-8", errors="ignore"
+                                    )
+                                ],
                                 "text_offset": [text_offset],
                                 "token_logprobs": [current_logprobs[int(token)]],
                                 "top_logprobs": [top_logprob],
@@ -1015,7 +1018,9 @@ def logit_bias_processor(
                             "model": model_name,
                             "choices": [
                                 {
-                                    "text": token_str,
+                                    "text": self.detokenize([token]).decode(
+                                        "utf-8", errors="ignore"
+                                    ),
                                     "index": 0,
                                     "logprobs": logprobs_or_none,
                                     "finish_reason": None,
@@ -1027,7 +1032,7 @@ def logit_bias_processor(
                     decode_success = False
                     for i in range(1, len(remaining_tokens) + 1):
                         try:
-                            bs = remaining_text
+                            bs = self.detokenize(remaining_tokens[:i])
                             ts = bs.decode("utf-8")
                             decode_success = True
                             break
@@ -1063,7 +1068,6 @@ def logit_bias_processor(
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
-
                 finish_reason = "length"
                 break
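For readers skimming the diff: the revert goes back to detokenizing each streamed token on its own, plus a fallback loop (the decode_success loop in the @@ -1027 hunk) that widens the token slice whenever the accumulated bytes stop mid-way through a multi-byte UTF-8 character. Below is a minimal sketch of that pattern, not the library's code: FAKE_VOCAB, decode_prefix, and the UnicodeDecodeError handler are assumptions for illustration, and detokenize stands in for Llama.detokenize.

from typing import List

# Hypothetical token-id -> raw-bytes table; token 2 ends in the middle of a
# three-byte UTF-8 character (the euro sign) and token 3 completes it.
FAKE_VOCAB = {
    1: b"Hi",
    2: b" \xe2\x82",
    3: b"\xac!",
}

def detokenize(tokens: List[int]) -> bytes:
    # Stand-in for Llama.detokenize: concatenate each token's raw bytes.
    return b"".join(FAKE_VOCAB[t] for t in tokens)

def decode_prefix(remaining_tokens: List[int]) -> str:
    # Same shape as the decode_success loop in the diff: grow the slice
    # until the bytes form valid UTF-8, then return that decoded prefix.
    for i in range(1, len(remaining_tokens) + 1):
        bs = detokenize(remaining_tokens[:i])
        try:
            return bs.decode("utf-8")
        except UnicodeDecodeError:
            continue  # bytes split a character; pull in one more token
    return ""  # nothing decodable yet; wait for more tokens

print(decode_prefix([1]))     # -> "Hi"
print(decode_prefix([2, 3]))  # -> " €!" (token 2 alone is not valid UTF-8)

The per-token self.detokenize([token]).decode("utf-8", errors="ignore") calls earlier in the diff rest on the same observation: a single token's bytes may not be printable on their own, so incomplete sequences are dropped from the streamed chunk rather than raising.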

0 commit comments