Add cache implementation using llama state · lidanger/llama-cpp-python@cbe95bb · GitHub

Commit cbe95bb

Add cache implementation using llama state
1 parent 2c359a2 commit cbe95bb

File tree

1 file changed: +26 -38 lines changed


llama_cpp/llama.py

Lines changed: 26 additions & 38 deletions
@@ -12,12 +12,22 @@
 
 
 class LlamaCache:
-    """Cache for a llama.cpp model.
+    """Cache for a llama.cpp model."""
 
-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
 
-    pass
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value
 
 
 class LlamaState:
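
Note: the new LlamaCache is a thin wrapper over a dict keyed by the prompt's token sequence (converted to a tuple so it is hashable), and __setitem__ clears the dict before writing, so at most one entry is ever kept. A minimal self-contained sketch of those semantics, using int tokens and a plain string as a stand-in for LlamaState (both stand-ins are assumptions for illustration):

from typing import Dict, Optional, Sequence

class TokenPrefixCache:
    """Stand-in mirroring the semantics of the LlamaCache added in this commit."""

    def __init__(self):
        self.cache_state: Dict[Sequence[int], str] = dict()

    def __getitem__(self, key: Sequence[int]) -> Optional[str]:
        # Lists and deques are not hashable, so keys are stored as tuples.
        return self.cache_state.get(tuple(key), None)

    def __contains__(self, key: Sequence[int]) -> bool:
        return tuple(key) in self.cache_state

    def __setitem__(self, key: Sequence[int], value: str):
        self.cache_state = dict()  # Only the most recent entry survives.
        self.cache_state[tuple(key)] = value

cache = TokenPrefixCache()
cache[[1, 2, 3]] = "state-a"
cache[[4, 5, 6]] = "state-b"   # evicts the entry for [1, 2, 3]
assert [1, 2, 3] not in cache
assert cache[[4, 5, 6]] == "state-b"
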
@@ -100,13 +110,7 @@ def __init__(
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
@@ -182,7 +186,7 @@ def set_cache(self, cache: Optional[LlamaCache]):
         Args:
             cache: The cache to set.
         """
-        self._cache = cache
+        self.cache = cache
 
     def reset(self):
         """Reset the model state."""
@@ -287,18 +291,17 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
-        ### HACK
+
         if (
             reset
-            and self._cache
             and len(self.eval_tokens) > 0
             and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
-        ###
+
         if reset:
             self.reset()
         while True:
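
Note: the generate() change keeps the prefix-reuse path but no longer gates it on the cache: if the tokens already evaluated (eval_tokens) are a prefix of the incoming sequence, reset is skipped and only the unevaluated suffix is fed to the model. A standalone sketch of that check (the helper name is illustrative; the deque-to-list conversion is added here so the comparison works on plain Python sequences):

from collections import deque
from typing import Deque, List, Tuple

def split_cached_prefix(eval_tokens: Deque[int], tokens: List[int]) -> Tuple[bool, List[int]]:
    """Return (reset, remaining): reset=False and only the suffix when the
    previously evaluated tokens are a prefix of the new token sequence."""
    if len(eval_tokens) > 0 and list(eval_tokens) == tokens[: len(eval_tokens)]:
        return False, tokens[len(eval_tokens):]
    return True, tokens

# Three tokens were evaluated earlier; five arrive now -> keep state, eval two more.
reset, remaining = split_cached_prefix(deque([1, 2, 3]), [1, 2, 3, 4, 5])
assert reset is False and remaining == [4, 5]
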
@@ -415,20 +418,10 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
+        if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
         for token in self.generate(
@@ -437,12 +430,16 @@ def _create_completion(
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
-            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
+
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
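
Note: taken together, the two completion-path changes form a check/restore step before generation and a snapshot step right after the prompt has been evaluated (i.e. before any completion token is appended). A condensed sketch of that flow; self stands for a Llama instance, save_state/load_state and the cache mapping come from the diff, and the helper names are illustrative only:

def restore_from_cache(self, prompt_tokens):
    # Before generation: on an exact prompt match, restore the saved model state
    # so the prompt does not have to be re-evaluated.
    if self.cache and prompt_tokens in self.cache:
        self.load_state(self.cache[prompt_tokens])

def snapshot_after_prompt(self, prompt_tokens, completion_tokens):
    # Right after the prompt has been evaluated (no completion tokens yet),
    # save the model state keyed by the prompt's tokens for later reuse.
    if self.cache and len(completion_tokens) == 0:
        if prompt_tokens not in self.cache:
            self.cache[prompt_tokens] = self.save_state()
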
@@ -467,9 +464,6 @@ def _create_completion(
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -491,9 +485,6 @@ def _create_completion(
                 break
 
         if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -510,9 +501,6 @@ def _create_completion(
             }
             return
 
-        ### HACK
-        self._completion_bytes.append(text)
-        ###
         text_str = text.decode("utf-8")
 
         if echo:
