Support smaller state sizes · Stonelinks/llama-cpp-python@43f2907 · GitHub


Commit 43f2907

Support smaller state sizes
1 parent 1d47cce commit 43f2907

1 file changed: +10 −4

llama_cpp/llama.py

Lines changed: 10 additions & 4 deletions
@@ -53,12 +53,14 @@ class LlamaState:
     def __init__(
         self,
         eval_tokens: Deque[llama_cpp.llama_token],
-        eval_logits: Deque[List[float]],
+        eval_logits: Deque[List[llama_cpp.c_float]],
         llama_state,
+        llama_state_size: llama_cpp.c_size_t,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
         self.llama_state = llama_state
+        self.llama_state_size = llama_state_size
 
 
 class Llama:
@@ -950,19 +952,23 @@ def save_state(self) -> LlamaState:
         assert self.ctx is not None
         state_size = llama_cpp.llama_get_state_size(self.ctx)
         llama_state = (llama_cpp.c_uint8 * int(state_size))()
-        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+        n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
+        if int(n_bytes) > int(state_size):
             raise RuntimeError("Failed to copy llama state data")
+        llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
+        llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
-            llama_state=llama_state,
+            llama_state=llama_state_compact,
+            llama_state_size=n_bytes,
         )
 
     def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
-        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

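The change hinges on llama.cpp reporting an upper bound from llama_get_state_size while llama_copy_state_data returns the number of bytes actually written; the commit copies that written prefix into a right-sized buffer and records its length so load_state can validate against the real size rather than the upper bound. Below is a self-contained sketch of the same ctypes compaction pattern; the buffer sizes and contents are illustrative, not taken from llama.cpp:

    import ctypes

    # Oversized scratch buffer, analogous to allocating
    # llama_get_state_size() bytes up front.
    capacity = 16
    scratch = (ctypes.c_uint8 * capacity)(*range(capacity))

    # Pretend only the first 6 bytes were actually written,
    # analogous to llama_copy_state_data()'s return value.
    n_bytes = 6

    # Compact: allocate a right-sized array and copy the prefix across.
    compact = (ctypes.c_uint8 * n_bytes)()
    ctypes.memmove(compact, scratch, n_bytes)

    assert bytes(compact) == bytes(scratch)[:n_bytes]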
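To see how the compacted state is exercised end to end, here is a minimal save/load round-trip sketch against the package's high-level API; the model path and prompts are placeholders, not part of this commit:

    from llama_cpp import Llama

    llama = Llama(model_path="./models/ggml-model.bin")  # placeholder path

    llama("The quick brown fox", max_tokens=8)  # populates the KV cache
    state = llama.save_state()   # keeps only the bytes actually written

    llama("An unrelated prompt", max_tokens=8)
    llama.load_state(state)      # restores the earlier, smaller state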