fix: LlamaHFTokenizer now receives pre_tokens · coderonion/llama-cpp-python@47bad30 · GitHub

Commit 47bad30
fix: LlamaHFTokenizer now receives pre_tokens
1 parent ded5d62 commit 47bad30
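
Why this matters: Hugging Face tokenizers are context-sensitive when decoding, so detokenizing a streamed token in isolation can drop a leading space or split a multi-byte character. This commit threads prev_tokens through Llama.detokenize so the HF tokenizer can decode with full context and return only the new suffix. A minimal sketch of the underlying issue, assuming the transformers package; the model name and example strings are only illustrations, not part of the commit:

# Sketch: why decoding a token without its preceding context loses information.
from transformers import AutoTokenizer

hf = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative model

prev = hf.encode("Hello", add_special_tokens=False)
new = hf.encode(" world", add_special_tokens=False)

# Decoding the new tokens alone can lose the leading space:
print(repr(hf.decode(new)))          # e.g. 'world'

# Decoding with the previous context and stripping the prefix keeps it:
full = hf.decode(prev + new)
prefix = hf.decode(prev)
print(repr(full[len(prefix):]))      # e.g. ' world'

The same prefix-diff idea is what the LlamaHFTokenizer change in this commit implements.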

File tree: 2 files changed, +33 -23 lines changed

llama_cpp/llama.py

Lines changed: 20 additions & 18 deletions
@@ -480,7 +480,7 @@ def detokenize(
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens)
+        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)

     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
@@ -1016,13 +1016,13 @@ def logit_bias_processor(
             grammar=grammar,
         ):
             if token == self._token_eos:
-                text = self.detokenize(completion_tokens)
+                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "stop"
                 break

             completion_tokens.append(token)

-            all_text = self.detokenize(completion_tokens)
+            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)

             # Contains multi-byte UTF8
             for k, char in enumerate(all_text[-3:]):
@@ -1046,7 +1046,7 @@ def logit_bias_processor(

             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(remaining_tokens)
+                remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                 remaining_length = len(remaining_text)

                 # We want to avoid yielding any characters from
@@ -1068,17 +1068,17 @@ def logit_bias_processor(
                     for token in remaining_tokens:
                         if token == self.token_bos():
                             continue
-                        token_end_position += len(self.detokenize([token]))
+                        token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
                             remaining_length - first_stop_position
                         ):
                             break
-                        token_str = self.detokenize([token]).decode(
+                        token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
-                            self.detokenize(completion_tokens[:returned_tokens]).decode(
+                            self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                 "utf-8", errors="ignore"
                             )
                         )
@@ -1100,7 +1100,7 @@ def logit_bias_processor(
                         top_logprob.update({token_str: current_logprobs[int(token)]})
                         logprobs_or_none = {
                             "tokens": [
-                                self.detokenize([token]).decode(
+                                self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                     "utf-8", errors="ignore"
                                 )
                             ],
@@ -1116,7 +1116,7 @@ def logit_bias_processor(
                             "model": model_name,
                             "choices": [
                                 {
-                                    "text": self.detokenize([token]).decode(
+                                    "text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                         "utf-8", errors="ignore"
                                     ),
                                     "index": 0,
@@ -1130,7 +1130,7 @@ def logit_bias_processor(
                         decode_success = False
                         for i in range(1, len(remaining_tokens) + 1):
                             try:
-                                bs = self.detokenize(remaining_tokens[:i])
+                                bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                                 ts = bs.decode("utf-8")
                                 decode_success = True
                                 break
@@ -1165,22 +1165,22 @@ def logit_bias_processor(
                     }

             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens)
+                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "length"
                 break

         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens)
+            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
             finish_reason = "stop"

         if self.verbose:
             self._ctx.print_timings()

         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
-            all_text = self.detokenize(remaining_tokens)
+            all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
             any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 end = min(all_text.index(stop) for stop in any_stop)
@@ -1189,7 +1189,7 @@ def logit_bias_processor(

             token_end_position = 0
             for token in remaining_tokens:
-                token_end_position += len(self.detokenize([token]))
+                token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))

                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -1199,7 +1199,7 @@ def logit_bias_processor(
                         "utf-8", errors="ignore"
                     )
                     text_offset = len(prompt) + len(
-                        self.detokenize(completion_tokens[:returned_tokens])
+                        self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
                     logits = self._scores[token_offset, :]
@@ -1313,8 +1313,8 @@ def logit_bias_processor(
                 all_tokens = completion_tokens

             all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore")
-                for token in all_tokens
+                self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
+                for i, token in enumerate(all_tokens)
             ]
             all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
             # TODO: may be able to change this loop to use np.take_along_dim
@@ -1339,7 +1339,7 @@ def logit_bias_processor(
                 )
                 token_logprobs.append(logprobs_token[int(token)])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1594,6 +1594,8 @@ def create_chat_completion(
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
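
The last hunk also adds logprobs and top_logprobs keyword arguments to create_chat_completion. A hypothetical call using them; only the signature change is visible in this diff, so how the values are forwarded downstream is an assumption:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # illustrative path
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,     # new keyword in this commit
    top_logprobs=5,    # new keyword in this commit
)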

llama_cpp/llama_tokenizer.py

Lines changed: 13 additions & 5 deletions
@@ -16,12 +16,23 @@ class BaseLlamaTokenizer(abc.ABC):
     def tokenize(
         self, text: bytes, add_bos: bool = True, special: bool = True
     ) -> List[int]:
+        """Tokenize the text into tokens.
+
+        Args:
+            text: The text to tokenize.
+            add_bos: Whether to add a beginning of sequence token.
+            special: Whether to tokenize text literally or as special tokens."""
         raise NotImplementedError

     @abc.abstractmethod
     def detokenize(
         self, tokens: List[int], prev_tokens: Optional[List[int]] = None
     ) -> bytes:
+        """Detokenize the tokens into text.
+
+        Args:
+            tokens: The tokens to detokenize.
+            prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens."""
         raise NotImplementedError


@@ -37,10 +48,7 @@ def tokenize(
     def detokenize(
         self, tokens: List[int], prev_tokens: Optional[List[int]] = None
     ) -> bytes:
-        if prev_tokens is not None:
-            return self._model.detokenize(tokens[len(prev_tokens) :])
-        else:
-            return self._model.detokenize(tokens)
+        return self._model.detokenize(tokens)

     def encode(
         self, text: str, add_bos: bool = True, special: bool = True
@@ -72,7 +80,7 @@ def detokenize(
         self, tokens: List[int], prev_tokens: Optional[List[int]] = None
     ) -> bytes:
         if prev_tokens is not None:
-            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+            text = self.hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
             prev_text = self.hf_tokenizer.decode(prev_tokens).encode(
                 "utf-8", errors="ignore"
             )
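
Taken together with the surrounding method (not fully shown in the hunk), LlamaHFTokenizer.detokenize now decodes prev_tokens + tokens and strips the decoded prefix. A standalone sketch of that prefix-diff pattern, with a hypothetical helper name and the lines outside the hunk reconstructed by assumption:

from typing import List, Optional

def detokenize_incremental(
    hf_tokenizer, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
    # Decode with full context, then strip the already-decoded prefix so
    # context-sensitive spacing and merges land in the right chunk.
    if prev_tokens is not None:
        text = hf_tokenizer.decode(prev_tokens + tokens).encode("utf-8", errors="ignore")
        prev_text = hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
        return text[len(prev_text):]
    return hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")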
