Refactor internal state for Llama class · lidanger/llama-cpp-python@86f8e5a · GitHub

Commit 86f8e5a

Refactor internal state for Llama class
1 parent 02cf881 commit 86f8e5a

1 file changed: 23 additions, 40 deletions

llama_cpp/llama.py

Lines changed: 23 additions & 40 deletions
@@ -84,16 +84,9 @@ def __init__(
         self.params.embedding = embedding
 
         self.last_n_tokens_size = last_n_tokens_size
-        self.last_n_tokens_data = deque(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
-            maxlen=self.last_n_tokens_size,
-        )
-        self.tokens_consumed = 0
-        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
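
The refactor folds `tokens`, `tokens_consumed`, `n_tokens`, `n_past`, and `all_logits` into two deques bounded by the context size, so the length of `eval_tokens` itself plays the role of the old counters. A minimal sketch of the bounded-deque behaviour this relies on (the values and `n_ctx` here are illustrative, not from the commit):

```python
# Sketch only: a deque with maxlen=n_ctx keeps at most the last n_ctx items.
from collections import deque

n_ctx = 4
eval_tokens = deque(maxlen=n_ctx)

eval_tokens.extend([1, 2, 3])
print(len(eval_tokens))      # 3 -- stands in for the old tokens_consumed counter
eval_tokens.extend([4, 5])   # pushing past maxlen silently drops the oldest item
print(list(eval_tokens))     # [2, 3, 4, 5]
```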
@@ -181,14 +174,8 @@ def set_cache(self, cache: Optional[LlamaCache]):
 
     def reset(self):
         """Reset the model state."""
-        self.last_n_tokens_data.extend(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
-        )
-        self.tokens_consumed = 0
-        self.tokens.clear()
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits.clear()
+        self.eval_tokens.clear()
+        self.eval_logits.clear()
 
     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -200,32 +187,25 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
-            self.n_tokens = len(batch)
+            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(self.n_tokens),
-                n_past=llama_cpp.c_int(self.n_past),
+                n_tokens=llama_cpp.c_int(n_tokens),
+                n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
-            self.tokens.extend(batch)
-            self.last_n_tokens_data.extend(batch)
-            self.tokens_consumed += len(batch)
+            self.eval_tokens.extend(batch)
             if self.params.logits_all:
-                self.all_logits.extend(self._logits())
-
-    def _logits(self) -> List[List[float]]:
-        """Return the logits from the last call to llama_eval."""
-        assert self.ctx is not None
-        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-        cols = int(n_vocab)
-        rows = self.n_tokens if self.params.logits_all else 1
-        logits_view = llama_cpp.llama_get_logits(self.ctx)
-        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
-        return logits
+                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+                cols = int(n_vocab)
+                rows = n_tokens
+                logits_view = llama_cpp.llama_get_logits(self.ctx)
+                logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+                self.eval_logits.extend(logits)
 
     def sample(
         self,
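
When `logits_all` is set, the former `_logits` helper is inlined into `eval`: the flat buffer returned by `llama_get_logits` is treated as `n_tokens` rows of `n_vocab` columns, row-major, and the rows are appended to `eval_logits`. A hedged sketch of just that reshaping step, using a plain list in place of the ctypes pointer (sizes and values are made up):

```python
# Sketch only: slice a flat, row-major logits buffer into one row per token.
from collections import deque

n_vocab = 4                                            # cols (tiny for readability)
n_tokens = 3                                           # rows: tokens in the batch
flat = [float(i) for i in range(n_tokens * n_vocab)]   # stand-in for llama_get_logits(ctx)

rows = [[flat[i * n_vocab + j] for j in range(n_vocab)] for i in range(n_tokens)]
# rows == [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0]]

eval_logits = deque(maxlen=8)                          # mirrors self.eval_logits (maxlen=n_ctx)
eval_logits.extend(rows)                               # one logits row per evaluated token
```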
@@ -246,10 +226,13 @@ def sample(
             The sampled token.
         """
         assert self.ctx is not None
+        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
+            0, self.last_n_tokens_size - len(self.eval_tokens)
+        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
         return llama_cpp.llama_sample_top_p_top_k(
             ctx=self.ctx,
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
-                *self.last_n_tokens_data
+                *last_n_tokens_data
             ),
             last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
             top_k=llama_cpp.c_int(top_k),
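
`sample` no longer keeps a persistent `last_n_tokens_data` deque; it rebuilds the repeat-penalty window on each call from `eval_tokens`, left-padding with token 0 when fewer than `last_n_tokens_size` tokens have been evaluated. A small sketch of that padding logic with integer stand-ins for `llama_token`:

```python
# Sketch only: derive the fixed-size last_n window from the evaluated tokens.
from collections import deque

last_n_tokens_size = 6
eval_tokens = deque([11, 12, 13, 14], maxlen=16)

last_n_tokens_data = [0] * max(
    0, last_n_tokens_size - len(eval_tokens)
) + list(eval_tokens)[-last_n_tokens_size:]

print(last_n_tokens_data)   # [0, 0, 11, 12, 13, 14] -- always last_n_tokens_size long
```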
@@ -293,13 +276,13 @@ def generate(
         if (
             reset
             and self._cache
-            and len(self.tokens) > 0
-            and self.tokens == tokens[: len(self.tokens)]
+            and len(self.eval_tokens) > 0
+            and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
             reset = False
-            tokens = tokens[len(self.tokens) :]
+            tokens = tokens[len(self.eval_tokens) :]
         ###
         if reset:
             self.reset()
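
The cache check in `generate` now compares the requested tokens against `eval_tokens`: if the new sequence starts with what has already been evaluated, only the remaining suffix is fed to `eval`. A sketch of that prefix test with toy values; note that a `deque` never compares equal to a `list` in Python, so the sketch converts with `list()` before the element-wise comparison:

```python
# Sketch only: prefix-reuse check analogous to the generate() cache-hit path.
from collections import deque

eval_tokens = deque([1, 2, 3], maxlen=16)   # tokens already evaluated
tokens = [1, 2, 3, 4, 5]                    # tokens requested next

# deque([1, 2, 3]) == [1, 2, 3] is False, so compare via list():
if len(eval_tokens) > 0 and list(eval_tokens) == tokens[: len(eval_tokens)]:
    tokens = tokens[len(eval_tokens):]      # only the new suffix still needs eval()
print(tokens)                               # [4, 5]
```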
@@ -537,7 +520,7 @@ def _create_completion(
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
+                for row in self.eval_logits
             ]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
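
`_create_completion` now reads the per-token logits rows from `eval_logits` when building log-probabilities. The body of `Llama.logit_to_logprob` is not part of this diff, so purely for orientation, here is the standard log-softmax way of turning one row of logits into normalized log-probabilities; this is an assumption about intent, not the library's actual helper:

```python
# Sketch only: log-softmax over one logits row. NOT necessarily what
# Llama.logit_to_logprob does -- its implementation is not shown in this diff.
import math

def log_softmax(row):
    m = max(row)                                   # subtract the max for stability
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

row = [2.0, 0.5, -1.0]        # one entry of eval_logits (one evaluated token)
print(log_softmax(row))       # exp() of these values sums to 1.0
```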

0 commit comments