8000 Add save/load state api for Llama class · lidanger/llama-cpp-python@197cf80 · GitHub

Commit 197cf80
Add save/load state api for Llama class
1 parent c4c332f

File tree: 1 file changed

llama_cpp/llama.py (39 additions, 4 deletions)
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque

 from . import llama_cpp
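The import change above is what enables the annotation fix in __init__ below: before Python 3.9, the built-in collections.deque cannot be subscripted in a type hint, so typing.Deque is needed instead. A minimal standalone sketch of the difference (not part of the commit):

    from collections import deque
    from typing import Deque, List

    eval_logits: Deque[List[float]] = deque(maxlen=4)   # works on Python 3.7+
    # eval_logits: deque[List[float]] = deque(maxlen=4) # TypeError before Python 3.9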
@@ -20,6 +20,18 @@ class LlamaCache:
     pass


+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
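LlamaState is a passive container: it holds copies of the token and logit deques plus the opaque ctypes buffer returned by llama.cpp. Because deque.copy() yields an independent deque, mutating the live model after a snapshot leaves the saved state untouched (the copy is shallow, which suffices here since eval() appends fresh lists rather than mutating existing ones). A standalone illustration:

    from collections import deque

    live = deque([101, 102, 103], maxlen=8)   # stand-in for self.eval_tokens
    snapshot = live.copy()                    # what save_state() stores
    live.append(104)                          # model keeps evaluating
    assert list(snapshot) == [101, 102, 103]  # snapshot is unaffected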

@@ -85,8 +97,8 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)

         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
@@ -204,7 +216,10 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         cols = int(n_vocab)
         rows = n_tokens
         logits_view = llama_cpp.llama_get_logits(self.ctx)
-        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        logits = [
+            [logits_view[i * cols + j] for j in range(cols)]
+            for i in range(rows)
+        ]
         self.eval_logits.extend(logits)

     def sample(
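The reformatted comprehension above copies logits out of the flat buffer returned by llama_cpp.llama_get_logits, which stores n_tokens rows of n_vocab floats in row-major order. A toy illustration of the same indexing (values are made up):

    flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]  # stand-in for logits_view
    rows, cols = 2, 3                      # n_tokens, n_vocab
    logits = [
        [flat[i * cols + j] for j in range(cols)]
        for i in range(rows)
    ]
    assert logits == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]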
@@ -828,6 +843,26 @@ def __setstate__(self, state):
             verbose=state["verbose"],
         )

+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""

0 commit comments