Replace eval_logits and eval_tokens with numpy arrays · zinccat/llama-cpp-python@fe331ec · GitHub
[go: up one dir, main page]

Skip to content

Commit fe331ec

Browse files
committed
Replace eval_logits and eval_tokens with numpy arrays
1 parent efb763b commit fe331ec

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

llama_cpp/llama.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ def reset(self):
299299
"""Reset the model state."""
300300
self.eval_tokens.clear()
301301
self.eval_logits.clear()
302+
self._input_ids = np.array([], dtype=np.intc)
303+
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
302304

303305
def eval(self, tokens: Sequence[int]):
304306
"""Evaluate a list of tokens.
@@ -310,7 +312,7 @@ def eval(self, tokens: Sequence[int]):
310312
n_ctx = self._n_ctx
311313
for i in range(0, len(tokens), self.n_batch):
312314
batch = tokens[i : min(len(tokens), i + self.n_batch)]
313-
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
315+
n_past = min(n_ctx - len(batch), len(self._input_ids))
314316
n_tokens = len(batch)
315317
return_code = llama_cpp.llama_eval(
316318
ctx=self.ctx,
@@ -356,6 +358,7 @@ def _sample(
356358
):
357359
assert self.ctx is not None
358360
assert len(self.eval_logits) > 0
361+
assert self._scores.shape[0] > 0
359362
n_vocab = self._n_vocab
360363
n_ctx = self._n_ctx
361364
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -368,7 +371,7 @@ def _sample(
368371

369372
if logits_processor is not None:
370373
logits = np.array(
371-
logits_processor(list(self.eval_tokens), logits.tolist()),
374+
logits_processor(self._input_ids.tolist(), logits.tolist()),
372375
dtype=np.single,
373376
)
374377
self._scores[-1, :] = logits
@@ -498,8 +501,8 @@ def sample(
498501
"""
499502
assert self.ctx is not None
500503
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
501-
0, self.last_n_tokens_size - len(self.eval_tokens)
502-
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
504+
0, self.last_n_tokens_size - len(self._input_ids)
505+
) + self._input_ids[-self.last_n_tokens_size :].tolist()
503506
return self._sample(
504507
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
505508
*last_n_tokens_data
@@ -557,9 +560,9 @@ def generate(
557560
"""
558561
assert self.ctx is not None
559562

560-
if reset and len(self.eval_tokens) > 0:
563+
if reset and len(self._input_ids) > 0:
561564
longest_prefix = 0
562-
for a, b in zip(self.eval_tokens, tokens[:-1]):
565+
for a, b in zip(self._input_ids, tokens[:-1]):
563566
if a == b:
564567
longest_prefix += 1
565568
else:
@@ -569,6 +572,8 @@ def generate(
569572
print("Llama.generate: prefix-match hit", file=sys.stderr)
570573
reset = False
571574
tokens = tokens[longest_prefix:]
575+
self._input_ids = self._input_ids[:longest_prefix]
576+
self._scores = self._scores[:longest_prefix, :]
572577
for _ in range(len(self.eval_tokens) - longest_prefix):
573578
self.eval_tokens.pop()
574579
try:
@@ -595,7 +600,7 @@ def generate(
595600
logits_processor=logits_processor,
596601
)
597602
if stopping_criteria is not None and stopping_criteria(
598-
list(self.eval_tokens), self.eval_logits[-1]
603+
self._input_ids.tolist(), self._scores[-1, :].tolist()
599604
):
600605
return
601606
tokens_or_none = yield token
@@ -820,7 +825,7 @@ def _create_completion(
820825
self.detokenize(completion_tokens[:returned_tokens])
821826
)
822827
token_offset = len(prompt_tokens) + returned_tokens
823-
logits = self.eval_logits[token_offset - 1]
828+
logits = self._scores[token_offset - 1, :].tolist()
824829
current_logprobs = Llama.logits_to_logprobs(logits)
825830
sorted_logprobs = list(
826831
sorted(
@@ -869,7 +874,7 @@ def _create_completion(
869874
break
870875

871876
if stopping_criteria is not None and stopping_criteria(
872-
list(self.eval_tokens), self.eval_logits[-1]
877+
self._input_ids.tolist(), self._scores[-1, :].tolist()
873878
):
874879
text = self.detokenize(completion_tokens)
875880
finish_reason = "stop"
@@ -899,7 +904,7 @@ def _create_completion(
899904
self.detokenize(completion_tokens[:returned_tokens])
900905
)
901906
token_offset = len(prompt_tokens) + returned_tokens - 1
902-
logits = self.eval_logits[token_offset]
907+
logits = self._scores[token_offset, :].tolist()
903908
current_logprobs = Llama.logits_to_logprobs(logits)
904909
sorted_logprobs = list(
905910
sorted(
@@ -1001,8 +1006,7 @@ def _create_completion(
10011006
for token in all_tokens
10021007
]
10031008
all_logprobs = [
1004-
Llama.logits_to_logprobs(list(map(float, row)))
1005-
for row in self.eval_logits
1009+
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
10061010
][token_offset:]
10071011
for token, token_str, logprobs_token in zip(
10081012
all_tokens, all_token_strs, all_logprobs

0 commit comments

Comments (0)