Refactor Llama class and add tokenize / detokenize methods Closes #3 · coderonion/llama-cpp-python@1c823f6 · GitHub

Commit 1c823f6

Refactor Llama class and add tokenize / detokenize methods Closes abetlen#3
1 parent 6dbff76 · commit 1c823f6


llama_cpp/llama.py

Lines changed: 84 additions & 57 deletions
@@ -3,6 +3,7 @@
 import time
 import multiprocessing
 from typing import List, Optional
+from collections import deque
 
 from . import llama_cpp
 
@@ -46,9 +47,6 @@ def __init__(
         """
         self.model_path = model_path
 
-        self.last_n = 64
-        self.max_chunk_size = 32
-
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_parts = n_parts
@@ -59,9 +57,10 @@ def __init__(
         self.params.use_mlock = use_mlock
         self.params.embedding = embedding
 
-        self.n_threads = n_threads or multiprocessing.cpu_count()
+        self.last_n = 64
+        self.max_chunk_size = n_ctx
 
-        self.tokens = (llama_cpp.llama_token * self.params.n_ctx)()
+        self.n_threads = n_threads or multiprocessing.cpu_count()
 
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")
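Note on the changed default above: `max_chunk_size` now tracks `n_ctx` instead of the old hard-coded 32, so a prompt that fits in the context window is evaluated in a single `llama_eval` call rather than many small ones. A standalone sketch of the chunking arithmetic (the values are made-up assumptions, not part of this commit):

# Hypothetical values for illustration only.
n_ctx, max_chunk_size = 512, 512  # new behavior: chunk size == context size
n_prompt = 100                    # tokens in the prompt

chunks = [(i, min(n_prompt, i + max_chunk_size))
          for i in range(0, n_prompt, max_chunk_size)]
print(chunks)  # [(0, 100)] -- the whole prompt in one chunk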
@@ -70,6 +69,65 @@ def __init__(
             self.model_path.encode("utf-8"), self.params
         )
 
+    def tokenize(self, text: bytes) -> List[int]:
+        """Tokenize a string.
+
+        Args:
+            text: The utf-8 encoded string to tokenize.
+
+        Returns:
+            A list of tokens.
+        """
+        n_ctx = llama_cpp.llama_n_ctx(self.ctx)
+        tokens = (llama_cpp.llama_token * n_ctx)()
+        n_tokens = llama_cpp.llama_tokenize(
+            self.ctx,
+            text,
+            tokens,
+            n_ctx,
+            True,
+        )
+        if n_tokens < 0:
+            raise RuntimeError(f"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}")
+        return list(tokens[:n_tokens])
+
+    def detokenize(self, tokens: List[int]) -> bytes:
+        """Detokenize a list of tokens.
+
+        Args:
+            tokens: The list of tokens to detokenize.
+
+        Returns:
+            The detokenized string.
+        """
+        output = b""
+        for token in tokens:
+            output += llama_cpp.llama_token_to_str(self.ctx, token)
+        return output
+
+
+    def _eval(self, tokens: List[int], n_past):
+        rc = llama_cpp.llama_eval(
+            self.ctx,
+            (llama_cpp.llama_token * len(tokens))(*tokens),
+            len(tokens),
+            n_past,
+            self.n_threads,
+        )
+        if rc != 0:
+            raise RuntimeError(f"Failed to evaluate: {rc}")
+
+    def _sample(self, last_n_tokens, top_p, top_k, temp, repeat_penalty):
+        return llama_cpp.llama_sample_top_p_top_k(
+            self.ctx,
+            (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
+            len(last_n_tokens),
+            top_k=top_k,
+            top_p=top_p,
+            temp=temp,
+            repeat_penalty=repeat_penalty,
+        )
+
     def __call__(
         self,
         prompt: str,
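A minimal usage sketch of the two new public methods, assuming a local GGML model file; the path and the example token ids are illustrative, not part of this commit:

from llama_cpp import Llama

# Hypothetical model path; any GGML model file would do.
llama = Llama(model_path="./models/7B/ggml-model.bin")

tokens = llama.tokenize(b"Hello, world!")  # e.g. [1, 10994, ...] -- ids vary by model
print(llama.detokenize(tokens))            # b'Hello, world!' (up to tokenizer normalization)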
@@ -106,61 +164,38 @@ def __call__(
         """
         text = b""
         finish_reason = "length"
-        completion_tokens = 0
+        completion_tokens = []
+        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
 
-        if stop is not None:
-            stop = [s.encode("utf-8") for s in stop]
-
-        prompt_tokens = llama_cpp.llama_tokenize(
-            self.ctx,
-            prompt.encode("utf-8"),
-            self.tokens,
-            llama_cpp.llama_n_ctx(self.ctx),
-            True,
-        )
-        if prompt_tokens < 0:
-            raise RuntimeError(f"Failed to tokenize prompt: {prompt_tokens}")
+        prompt_tokens = self.tokenize(prompt.encode("utf-8"))
 
-        if prompt_tokens + max_tokens > self.params.n_ctx:
+        if len(prompt_tokens) + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
                 f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
         # Process prompt in chunks to avoid running out of memory
-        for i in range(0, prompt_tokens, self.max_chunk_size):
-            chunk = self.tokens[i : min(prompt_tokens, i + self.max_chunk_size)]
-            rc = llama_cpp.llama_eval(
-                self.ctx,
-                (llama_cpp.llama_token * len(chunk))(*chunk),
-                len(chunk),
-                max(0, i - 1),
-                self.n_threads,
-            )
-            if rc != 0:
-                raise RuntimeError(f"Failed to evaluate prompt: {rc}")
+        for i in range(0, len(prompt_tokens), self.max_chunk_size):
+            chunk = prompt_tokens[i : min(len(prompt_tokens), i + self.max_chunk_size)]
+            self._eval(chunk, n_past=i)
 
-        for i in range(max_tokens):
-            tokens_seen = prompt_tokens + completion_tokens
-            last_n_tokens = [0] * max(0, self.last_n - tokens_seen) + [
-                self.tokens[j]
-                for j in range(max(tokens_seen - self.last_n, 0), tokens_seen)
-            ]
+        if stop is not None:
+            stop = [s.encode("utf-8") for s in stop]
 
-            token = llama_cpp.llama_sample_top_p_top_k(
-                self.ctx,
-                (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens),
-                len(last_n_tokens),
-                top_k=top_k,
+        for i in range(max_tokens):
+            token = self._sample(
+                last_n_tokens,
                 top_p=top_p,
+                top_k=top_k,
                 temp=temperature,
-                repeat_penalty=repeat_penalty,
+                repeat_penalty=repeat_penalty
             )
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            text += llama_cpp.llama_token_to_str(self.ctx, token)
-            self.tokens[prompt_tokens + i] = token
-            completion_tokens += 1
+            text += self.detokenize([token])
+            last_n_tokens.append(token)
+            completion_tokens.append(token)
 
             any_stop = [s for s in stop if s in text]
             if len(any_stop) > 0:
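The repeat-penalty window also changes shape in this hunk: instead of rebuilding a list from `self.tokens` with index arithmetic on every iteration, a `deque` with `maxlen` keeps only the most recent `last_n` tokens automatically (and the chunk offset is now passed as `n_past=i` rather than the old `max(0, i - 1)`). A standalone sketch of the deque behavior, with made-up token ids:

from collections import deque

last_n = 4
last_n_tokens = deque([0] * last_n, maxlen=last_n)  # zero-padded, fixed length
for token in [11, 22, 33, 44, 55]:
    last_n_tokens.append(token)                     # oldest entry falls off

print(list(last_n_tokens))  # [22, 33, 44, 55]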
@@ -169,15 +204,7 @@ def __call__(
                 finish_reason = "stop"
                 break
 
-            rc = llama_cpp.llama_eval(
-                self.ctx,
-                (llama_cpp.llama_token * 1)(self.tokens[prompt_tokens + i]),
-                1,
-                prompt_tokens + completion_tokens,
-                self.n_threads,
-            )
-            if rc != 0:
-                raise RuntimeError(f"Failed to evaluate next token: {rc}")
+            self._eval([token], len(prompt_tokens) + len(completion_tokens))
 
         text = text.decode("utf-8")
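Taken together, the refactor reduces generation to a sample/eval feedback cycle. A condensed sketch of that cycle (method names follow the diff; the single-call prompt eval, loop bound, and sampling parameter values are assumptions, not library API):

from collections import deque

def generate(llama, prompt_tokens, max_tokens=16):
    # Condensed restatement of the loop above, without EOS/stop handling.
    completion = []
    last_n_tokens = deque([0] * llama.last_n, maxlen=llama.last_n)
    llama._eval(prompt_tokens, n_past=0)  # prime the KV cache with the prompt
    for _ in range(max_tokens):
        token = llama._sample(last_n_tokens, top_p=0.95, top_k=40,
                              temp=0.8, repeat_penalty=1.1)
        completion.append(token)
        last_n_tokens.append(token)
        # n_past = everything the model has already seen
        llama._eval([token], len(prompt_tokens) + len(completion))
    return completion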

@@ -206,9 +233,9 @@ def __call__(
                 }
             ],
             "usage": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
+                "prompt_tokens": len(prompt_tokens),
+                "completion_tokens": len(completion_tokens),
+                "total_tokens": len(prompt_tokens) + len(completion_tokens),
             },
         }
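With both sides of the conversation held as token lists, usage accounting becomes plain `len()` calls. An illustrative fragment of the returned dict (field values are made up; the other response fields are elided):

response = {
    # ...id, object, created, model, choices elided...
    "usage": {
        "prompt_tokens": 5,       # len(prompt_tokens)
        "completion_tokens": 7,   # len(completion_tokens)
        "total_tokens": 12,       # 5 + 7
    },
}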
