@@ -53,12 +53,14 @@ class LlamaState:
53
53
def __init__ (
54
54
self ,
55
55
eval_tokens : Deque [llama_cpp .llama_token ],
56
- eval_logits : Deque [List [float ]],
56
+ eval_logits : Deque [List [llama_cpp . c_float ]],
57
57
llama_state ,
58
+ llama_state_size : llama_cpp .c_size_t ,
58
59
):
59
60
self .eval_tokens = eval_tokens
60
61
self .eval_logits = eval_logits
61
62
self .llama_state = llama_state
63
+ self .llama_state_size = llama_state_size
62
64
63
65
64
66
class Llama :
@@ -950,19 +952,23 @@ def save_state(self) -> LlamaState:
950
952
assert self .ctx is not None
951
953
state_size = llama_cpp .llama_get_state_size (self .ctx )
952
954
llama_state = (llama_cpp .c_uint8 * int (state_size ))()
953
- if llama_cpp .llama_copy_state_data (self .ctx , llama_state ) != state_size :
955
+ n_bytes = llama_cpp .llama_copy_state_data (self .ctx , llama_state )
956
+ if int (n_bytes ) > int (state_size ):
954
957
raise RuntimeError ("Failed to copy llama state data" )
958
+ llama_state_compact = (llama_cpp .c_uint8 * int (n_bytes ))()
959
+ llama_cpp .ctypes .memmove (llama_state_compact , llama_state , int (n_bytes ))
955
960
return LlamaState (
956
961
eval_tokens = self .eval_tokens .copy (),
957
962
eval_logits = self .eval_logits .copy (),
958
- llama_state = llama_state ,
963
+ llama_state = llama_state_compact ,
964
+ llama_state_size = n_bytes ,
959
965
)
960
966
961
967
def load_state (self , state : LlamaState ) -> None :
962
968
assert self .ctx is not None
963
969
self .eval_tokens = state .eval_tokens .copy ()
964
970
self .eval_logits = state .eval_logits .copy ()
965
- state_size = llama_cpp . llama_get_state_size ( self . ctx )
971
+ state_size = state . llama_state_size
966
972
if llama_cpp .llama_set_state_data (self .ctx , state .llama_state ) != state_size :
967
973
raise RuntimeError ("Failed to set llama state data" )
968
974
0 commit comments