@@ -141,7 +141,7 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
141
141
if _key is None :
142
142
raise KeyError ("Key not found" )
143
143
value : "LlamaState" = self .cache .pop (_key ) # type: ignore
144
- # NOTE: This puts an integer as key in cache, which breaks,
144
+ # NOTE: This puts an integer as key in cache, which breaks,
145
145
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
146
146
# self.cache.push(_key, side="front") # type: ignore
147
147
return value
@@ -166,17 +166,15 @@ def __setitem__(self, key: Sequence[int], value: "LlamaState"):
166
166
class LlamaState :
167
167
def __init__ (
168
168
self ,
169
- eval_tokens : Deque [int ],
170
- eval_logits : Deque [List [float ]],
171
169
input_ids : npt .NDArray [np .intc ],
172
170
scores : npt .NDArray [np .single ],
171
+ n_tokens : int ,
173
172
llama_state : bytes ,
174
173
llama_state_size : int ,
175
174
):
176
- self .eval_tokens = eval_tokens
177
- self .eval_logits = eval_logits
178
175
self .input_ids = input_ids
179
176
self .scores = scores
177
+ self .n_tokens = n_tokens
180
178
self .llama_state = llama_state
181
179
self .llama_state_size = llama_state_size
182
180
@@ -267,8 +265,6 @@ def __init__(
267
265
268
266
self .last_n_tokens_size = last_n_tokens_size
269
267
self .n_batch = min (n_ctx , n_batch )
270
- self .eval_tokens : Deque [int ] = deque (maxlen = n_ctx )
271
- self .eval_logits : Deque [List [float ]] = deque (maxlen = n_ctx if logits_all else 1 )
272
268
273
269
self .cache : Optional [BaseLlamaCache ] = None
274
270
@@ -329,8 +325,30 @@ def __init__(
329
325
self ._token_nl = Llama .token_nl ()
330
326
self ._token_eos = Llama .token_eos ()
331
327
332
- self ._input_ids = np .array ([], dtype = np .intc )
333
- self ._scores : npt .NDArray [np .single ] = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
328
+ self .n_tokens = 0
329
+ self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
330
+ self .scores : npt .NDArray [np .single ] = np .ndarray (
331
+ (n_ctx , self ._n_vocab ), dtype = np .single
332
+ )
333
+
334
+ @property
335
+ def _input_ids (self ) -> npt .NDArray [np .intc ]:
336
+ return self .input_ids [: self .n_tokens ]
337
+
338
+ @property
339
+ def _scores (self ) -> npt .NDArray [np .single ]:
340
+ return self .scores [: self .n_tokens , :]
341
+
342
+ @property
343
+ def eval_tokens (self ) -> Deque [int ]:
344
+ return deque (self .input_ids [: self .n_tokens ].tolist (), maxlen = self ._n_ctx )
345
+
346
+ @property
347
+ def eval_logits (self ) -> Deque [List [float ]]:
348
+ return deque (
349
+ self .scores [: self .n_tokens , :].tolist (),
350
+ maxlen = self ._n_ctx if self .params .logits_all else 1 ,
351
+ )
334
352
335
353
def tokenize (self , text : bytes , add_bos : bool = True ) -> List [int ]:
336
354
"""Tokenize a string.
@@ -397,10 +415,7 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
397
415
398
416
def reset (self ):
399
417
"""Reset the model state."""
400
- self .eval_tokens .clear ()
401
- self .eval_logits .clear ()
402
- self ._input_ids = np .array ([], dtype = np .intc )
403
- self ._scores = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
418
+ self .n_tokens = 0
404
419
405
420
def eval (self , tokens : Sequence [int ]):
406
421
"""Evaluate a list of tokens.
@@ -410,7 +425,6 @@ def eval(self, tokens: Sequence[int]):
410
425
"""
411
426
assert self .ctx is not None
412
427
n_ctx = self ._n_ctx
413
- scores : List [npt .NDArray [np .single ]] = []
414
428
for i in range (0 , len (tokens ), self .n_batch ):
415
429
batch = tokens [i : min (len (tokens ), i + self .n_batch )]
416
430
n_past = min (n_ctx - len (batch ), len (self ._input_ids ))
@@ -425,19 +439,16 @@ def eval(self, tokens: Sequence[int]):
425
439
if return_code != 0 :
426
440
raise RuntimeError (f"llama_eval returned { return_code } " )
427
441
# Save tokens
428
- self .eval_tokens .extend (batch )
429
- self ._input_ids : npt .NDArray [np .intc ] = np .concatenate (
430
- (self ._input_ids , np .array (batch , dtype = np .intc )), axis = 0
431
- )
442
+ self .input_ids [self .n_tokens : self .n_tokens + n_tokens ] = batch
432
443
# Save logits
433
444
rows = n_tokens if self .params .logits_all else 1
434
445
n_vocab = self ._n_vocab
435
446
cols = n_vocab
436
447
logits_view = llama_cpp .llama_get_logits (self .ctx )
437
448
logits = [logits_view [i * cols : (i + 1 ) * cols ] for i in range (rows )]
438
- self .eval_logits . extend ( logits )
439
- scores . append ( np . array ( logits , dtype = np . single ))
440
- self ._scores = np . concatenate ( scores )
449
+ self .scores [ self . n_tokens : self . n_tokens + n_tokens , :] = logits
450
+ # Update n_tokens
451
+ self .n_tokens += n_tokens
441
452
442
453
def _sample (
443
454
self ,
@@ -457,8 +468,7 @@ def _sample(
457
468
logits_processor : Optional [LogitsProcessorList ] = None ,
458
469
):
459
470
assert self .ctx is not None
460
- assert len (self .eval_logits ) > 0
461
- assert self ._scores .shape [0 ] > 0
471
+ assert self .n_tokens > 0
462
472
n_vocab = self ._n_vocab
463
473
n_ctx = self ._n_ctx
464
474
top_k = llama_cpp .c_int (n_vocab ) if top_k .value <= 0 else top_k
@@ -475,7 +485,6 @@ def _sample(
475
485
dtype = np .single ,
476
486
)
477
487
self ._scores [- 1 , :] = logits
478
- self .eval_logits [- 1 ] = logits .tolist ()
479
488
480
489
nl_logit = logits [self ._token_nl ]
481
490
candidates = self ._candidates
@@ -672,14 +681,7 @@ def generate(
672
681
print ("Llama.generate: prefix-match hit" , file = sys .stderr )
673
682
reset = False
674
683
tokens = tokens [longest_prefix :]
675
- self ._input_ids = self ._input_ids [:longest_prefix ]
676
- self ._scores = self ._scores [:longest_prefix , :]
677
- for _ in range (len (self .eval_tokens ) - longest_prefix ):
678
- self .eval_tokens .pop ()
679
- try :
680
- self .eval_logits .pop ()
681
- except IndexError :
682
- pass
684
+ self .n_tokens = longest_prefix
683
685
684
686
if reset :
685
687
self .reset ()
@@ -819,7 +821,9 @@ def _create_completion(
819
821
llama_cpp .llama_reset_timings (self .ctx )
820
822
821
823
if len (prompt_tokens ) > self ._n_ctx :
822
- raise ValueError (f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { self ._n_ctx } " )
824
+ raise ValueError (
825
+ f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { self ._n_ctx } "
826
+ )
823
827
824
828
# Truncate max_tokens if requested tokens would exceed the context window
825
829
max_tokens = (
@@ -1513,22 +1517,20 @@ def save_state(self) -> LlamaState:
1513
1517
file = sys .stderr ,
1514
1518
)
1515
1519
return LlamaState (
1516
- eval_tokens = self .eval_tokens .copy (),
1517
- eval_logits = self .eval_logits .copy (),
1518
- scores = self ._scores .copy (),
1519
- input_ids = self ._input_ids .copy (),
1520
+ scores = self .scores .copy (),
1521
+ input_ids = self .input_ids .copy (),
1522
+ n_tokens = self .n_tokens ,
1520
1523
llama_state = bytes (llama_state_compact ),
1521
1524
llama_state_size = n_bytes ,
1522
1525
)
1523
1526
1524
1527
def load_state (self , state : LlamaState ) -> None :
1525
1528
assert self .ctx is not None
1526
- self .eval_tokens = state .eval_tokens .copy ()
1527
- self .eval_logits = state .eval_logits .copy ()
1528
- self ._scores = state .scores .copy ()
1529
- self ._input_ids = state .input_ids .copy ()
1529
+ self .scores = state .scores .copy ()
1530
+ self .input_ids = state .input_ids .copy ()
1531
+ self .n_tokens = state .n_tokens
1530
1532
state_size = state .llama_state_size
1531
- LLamaStateArrayType = ( llama_cpp .c_uint8 * state_size )
1533
+ LLamaStateArrayType = llama_cpp .c_uint8 * state_size
1532
1534
llama_state = LLamaStateArrayType .from_buffer_copy (state .llama_state )
1533
1535
1534
1536
if llama_cpp .llama_set_state_data (self .ctx , llama_state ) != state_size :
0 commit comments