8000 Add experimental cache · GuoqiangJia/llama-cpp-python@92c0771 · GitHub

Commit 92c0771

Add experimental cache
1 parent a6372a7 commit 92c0771

2 files changed: 69 additions & 5 deletions

llama_cpp/llama.py

Lines changed: 65 additions & 4 deletions
@@ -11,6 +11,15 @@
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
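LlamaCache is deliberately empty: per its docstring it only tells the Llama class to skip reprocessing bytes it has already seen, and setting one on a model instance is what enables the prefix-continuation paths added below. A minimal usage sketch, assuming the public Llama constructor and a create_completion wrapper around the _create_completion generator shown in this diff; the model path and prompts are illustrative:

from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/ggml-model.bin")  # hypothetical model path
llm.set_cache(LlamaCache())  # opt in to the experimental prefix continuation

# First call evaluates the full prompt.
out = llm.create_completion("Q: Name the planets in the solar system. A:", max_tokens=32)

# A follow-up prompt that begins with the previous prompt plus its completion
# can reuse the model state instead of re-evaluating that shared prefix.
followup = (
    "Q: Name the planets in the solar system. A:"
    + out["choices"][0]["text"]
    + "\nQ: Which of them is the largest? A:"
)
out2 = llm.create_completion(followup, max_tokens=16)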
@@ -82,6 +91,14 @@ def __init__(
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
+
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
         if not os.path.exists(model_path):
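The bookkeeping behind that comment is purely byte-level: _completion_bytes accumulates the prompt plus everything generated so far, and a later call only reuses the model state when those bytes are literally a prefix of the new prompt. A standalone sketch of that check with illustrative values (and, as the comment warns, it ignores any stop tokens the model has already processed):

# Stand-ins for self._completion_bytes and the next prompt passed in.
completion_bytes = [b"Q: What is 2+2? A:", b" 4"]
new_prompt = "Q: What is 2+2? A: 4\nQ: And 3+3? A:".encode("utf-8")

cached = b"".join(completion_bytes)
if new_prompt.startswith(cached):
    remaining = new_prompt[len(cached):]  # only this suffix needs evaluation
else:
    remaining = new_prompt                # no reuse; start from a fresh state

print(remaining)  # b"\nQ: And 3+3? A:"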
@@ -135,6 +152,14 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
@@ -245,6 +270,17 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
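generate applies the same idea at the token level: when the tokens already evaluated (self.tokens) form a prefix of the tokens being requested, the reset is skipped so only the unseen tail has to be fed through the model. A small standalone sketch of that prefix test, with plain ints standing in for llama_cpp.llama_token values:

evaluated = [1, 15043, 3186]           # stands in for self.tokens
requested = [1, 15043, 3186, 29991]    # stands in for the tokens argument

cache_hit = len(evaluated) > 0 and evaluated == requested[: len(evaluated)]
to_eval = requested[len(evaluated):] if cache_hit else requested

print(cache_hit, to_eval)  # True [29991]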
@@ -361,13 +397,29 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
         finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
@@ -397,6 +449,9 @@ def _create_completion(
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ def _create_completion(
                 break
 
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ def _create_completion(
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
@@ -493,7 +554,7 @@ def _create_completion(
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
                     "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,

llama_cpp/server/__main__.py

Lines changed: 4 additions & 1 deletion
@@ -35,6 +35,7 @@ class Settings(BaseSettings):
     embedding: bool = True
     last_n_tokens_size: int = 64
     logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature
 
 
 app = FastAPI(
@@ -60,6 +61,9 @@ class Settings(BaseSettings):
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
 llama_lock = Lock()
 
 
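On the server side the new flag simply routes Settings.cache into the same set_cache call. Since Settings is a pydantic BaseSettings, the field can presumably also be supplied through the environment; a hedged sketch of launching the bundled server with the cache enabled (the MODEL variable and the exact env-var handling are assumptions, not shown in this diff):

import os
import subprocess
import sys

# CACHE maps onto the new cache field; MODEL is an assumed, illustrative setting.
env = dict(os.environ, MODEL="./models/ggml-model.bin", CACHE="true")
subprocess.run([sys.executable, "-m", "llama_cpp.server"], env=env, check=True)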
@@ -68,7 +72,6 @@ def get_llama():
         yield llama
 
 
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
