class LlamaCache:
-    """Cache for a llama.cpp model.
+    """Cache for a llama.cpp model."""

-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()

-    pass
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value


class LlamaState:
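For context, a minimal usage sketch of the new token-keyed cache, assuming `LlamaCache` is importable from the package root; the model path and prompt below are placeholders, not part of this commit:

# Hedged sketch: model path and prompt are placeholders.
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/7B/ggml-model.bin")
llm.set_cache(LlamaCache())

# The first call evaluates the prompt and stores a LlamaState keyed by the prompt's tokens.
print(llm("Q: Name the planets in the solar system. A:", max_tokens=32)["choices"][0]["text"])

# Repeating the same prompt finds its token sequence in the cache, so the saved
# state is restored via load_state() instead of re-evaluating the prompt.
print(llm("Q: Name the planets in the solar system. A:", max_tokens=32)["choices"][0]["text"])

Because __setitem__ clears cache_state before inserting, only the most recent prompt's state is retained.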
@@ -100,13 +110,7 @@ def __init__(
        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)

-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None

        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
@@ -182,7 +186,7 @@ def set_cache(self, cache: Optional[LlamaCache]):
        Args:
            cache: The cache to set.
        """
-        self._cache = cache
+        self.cache = cache

    def reset(self):
        """Reset the model state."""
@@ -287,18 +291,17 @@ def generate(
            The generated tokens.
        """
        assert self.ctx is not None
-        ### HACK
+
        if (
            reset
-            and self._cache
            and len(self.eval_tokens) > 0
            and self.eval_tokens == tokens[: len(self.eval_tokens)]
        ):
            if self.verbose:
                print("generate cache hit", file=sys.stderr)
            reset = False
            tokens = tokens[len(self.eval_tokens) :]
-        ###
+
        if reset:
            self.reset()
        while True:
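The reuse condition in this hunk only skips tokens that form a prefix of the already-evaluated sequence. A standalone sketch of that check (variable names and token values are illustrative; the deque is converted to a list for the comparison here):

# Illustrative prefix-reuse check, mirroring the condition in generate() above.
from collections import deque

eval_tokens = deque([1, 15043, 3186], maxlen=512)  # tokens already evaluated by the model
tokens = [1, 15043, 3186, 29991]                   # tokens for the new request

if len(eval_tokens) > 0 and list(eval_tokens) == tokens[: len(eval_tokens)]:
    # The evaluated tokens are a prefix of the request: only feed the new suffix.
    tokens = tokens[len(eval_tokens):]

print(tokens)  # [29991]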
@@ -415,20 +418,10 @@ def _create_completion(
                "logprobs is not supported for models created with logits_all=False"
            )

-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
+        if self.cache and prompt_tokens in self.cache:
            if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])

        finish_reason = "length"
        for token in self.generate(
@@ -437,12 +430,16 @@ def _create_completion(
            top_p=top_p,
            temp=temperature,
            repeat_penalty=repeat_penalty,
-            reset=reset,
        ):
            if token == llama_cpp.llama_token_eos():
                text = self.detokenize(completion_tokens)
                finish_reason = "stop"
                break
+
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
            completion_tokens.append(token)

            all_text = self.detokenize(completion_tokens)
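Taken together with the lookup above, the caching flow in _create_completion reduces to the condensed sketch below (abridged: stop handling, streaming and sampling details omitted): on a hit the saved state is restored before generation, and on a miss the state is snapshotted once the first new token arrives, i.e. right after the prompt has been evaluated.

# Condensed, abridged sketch of the cache flow added in this commit.
if self.cache and prompt_tokens in self.cache:
    # Hit: restore the model state that was saved for this exact prompt.
    self.load_state(self.cache[prompt_tokens])

for token in self.generate(
    prompt_tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repeat_penalty
):
    if self.cache and len(completion_tokens) == 0 and prompt_tokens not in self.cache:
        # Miss: save state after the prompt is evaluated but before any completion
        # tokens are recorded, so the entry maps prompt tokens -> post-prompt state.
        self.cache[prompt_tokens] = self.save_state()
    completion_tokens.append(token)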
@@ -467,9 +464,6 @@ def _create_completion(
                            break
                text = all_text[: len(all_text) - longest]
                returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                yield {
                    "id": completion_id,
                    "object": "text_completion",
@@ -491,9 +485,6 @@ def _create_completion(
                break

        if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
            yield {
                "id": completion_id,
                "object": "text_completion",
@@ -510,9 +501,6 @@ def _create_completion(
            }
            return

-        ### HACK
-        self._completion_bytes.append(text)
-        ###
        text_str = text.decode("utf-8")

        if echo: