@@ -299,6 +299,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -310,7 +312,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
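The n_past clamp keeps past plus batch within the context window; it now measures the past with len(self._input_ids) instead of the token deque. A toy check of the arithmetic, with made-up sizes:

# Toy check of the n_past clamp; all sizes here are made up.
n_ctx = 512
tokens_seen = 500  # stand-in for len(self._input_ids)
batch_len = 32     # stand-in for len(batch)

n_past = min(n_ctx - batch_len, tokens_seen)
assert n_past == 480 and n_past + batch_len <= n_ctx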
@@ -356,6 +358,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -368,7 +371,7 @@ def _sample(

         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
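The processor is still fed plain Python lists, now built from the numpy caches. A hypothetical processor in the shape this call site expects (ban_token is illustrative, not part of the library):

import numpy as np

# Hypothetical processor matching the call shape above:
# (input_ids: list[int], logits: list[float]) -> list[float].
def ban_token(banned_id):
    def processor(input_ids, logits):
        out = list(logits)
        out[banned_id] = float("-inf")  # make this token unsampleable
        return out
    return processor

logits = np.zeros(8, dtype=np.single)
new_logits = np.array(ban_token(3)([1, 2], logits.tolist()), dtype=np.single)
assert new_logits[3] == -np.inf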
@@ -498,8 +501,8 @@ def sample(
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
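The rewritten window logic is unchanged in behavior: pad on the left with token id 0 until the repeat-penalty window is exactly last_n_tokens_size long. A small sketch with made-up sizes:

import numpy as np

# Sketch of the left-padded repeat-penalty window; sizes are made up.
last_n_tokens_size = 8
input_ids = np.array([5, 6, 7], dtype=np.intc)  # stand-in for self._input_ids

window = [0] * max(0, last_n_tokens_size - len(input_ids)) + input_ids[
    -last_n_tokens_size:
].tolist()
assert window == [0, 0, 0, 0, 0, 5, 6, 7]
assert len(window) == last_n_tokens_size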
@@ -557,9 +560,9 @@ def generate(
         """
         assert self.ctx is not None

-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -569,6 +572,8 @@ def generate(
                 print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                 try:
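Truncating both arrays to longest_prefix keeps _input_ids and _scores aligned with the tokens the llama.cpp context still holds, mirroring what the eval_tokens pops do for the deque. A self-contained sketch of the prefix-match path:

import numpy as np

# Self-contained sketch of the prefix-match path with made-up token ids.
cached_ids = np.array([1, 2, 3, 4], dtype=np.intc)
tokens = [1, 2, 3, 9, 10]  # new prompt; its last token is never matched

longest_prefix = 0
for a, b in zip(cached_ids, tokens[:-1]):
    if a == b:
        longest_prefix += 1
    else:
        break

tokens = tokens[longest_prefix:]          # only the tail is re-evaluated
cached_ids = cached_ids[:longest_prefix]  # cache keeps just the shared part
assert longest_prefix == 3 and tokens == [9, 10]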
@@ -595,7 +600,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
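stopping_criteria still receives (list[int], list[float]), now materialized from the arrays. A hypothetical criterion with that signature (stop_after is illustrative only):

# Hypothetical criterion matching the call shape above:
# (input_ids: list[int], logits: list[float]) -> bool.
def stop_after(max_tokens):
    def criteria(input_ids, logits):
        return len(input_ids) >= max_tokens
    return criteria

criteria = stop_after(4)
assert criteria([1, 2, 3, 4], [0.0]) is True
assert criteria([1, 2], [0.0]) is False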
@@ -820,7 +825,7 @@ def _create_completion(
                             self.detokenize(completion_tokens[:returned_tokens])
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
-                        logits = self.eval_logits[token_offset - 1]
+                        logits = self._scores[token_offset - 1, :].tolist()
                         current_logprobs = Llama.logits_to_logprobs(logits)
                         sorted_logprobs = list(
                             sorted(
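For reference, Llama.logits_to_logprobs turns one row of logits into log-probabilities; a minimal log-softmax stand-in, assuming that is its behavior:

import math

# Minimal log-softmax stand-in, assuming Llama.logits_to_logprobs behaves
# this way: shift by the max for stability, normalize, take logs.
def logits_to_logprobs(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

logprobs = logits_to_logprobs([1.0, 2.0, 3.0])
assert abs(sum(math.exp(lp) for lp in logprobs) - 1.0) < 1e-9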
@@ -869,7 +874,7 @@ def _create_completion(
                 break

         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -899,7 +904,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -1001,8 +1006,7 @@ def _create_completion(
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
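Iterating self._scores yields one (n_vocab,) row per evaluated token, and the [token_offset:] slice drops the prompt rows. A toy version of the row-wise pass, with a 4-token vocabulary:

import numpy as np

# Toy version of the row-wise logprob pass; vocab size of 4 is made up.
scores = np.zeros((5, 4), dtype=np.single)  # stand-in for self._scores
token_offset = 2

rows = [row.tolist() for row in scores][token_offset:]
assert len(rows) == 3     # completion-token rows only
assert len(rows[0]) == 4  # one score per vocab entry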