import time
import math
import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
from collections import deque

from . import llama_cpp
@@ -20,6 +20,18 @@ class LlamaCache:
    pass


+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
class Llama:
    """High-level Python wrapper for a llama.cpp model."""
@@ -85,8 +97,8 @@ def __init__(
        self.last_n_tokens_size = last_n_tokens_size
        self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)

        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
        ### saving and restoring state, this allows us to continue a completion if the last
@@ -204,7 +216,10 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
            cols = int(n_vocab)
            rows = n_tokens
            logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+            logits = [
+                [logits_view[i * cols + j] for j in range(cols)]
+                for i in range(rows)
+            ]
            self.eval_logits.extend(logits)

    def sample(
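For reference, a minimal sketch of the row-major layout the reformatted comprehension relies on: `llama_get_logits` exposes a flat buffer of `rows * cols` floats, and entry `(i, j)` sits at offset `i * cols + j`. The buffer and sizes below are made up for illustration; in the real code `cols` is the vocab size and `rows` the number of tokens in the batch.

```python
# Toy stand-in for the flat logits buffer returned by llama_get_logits.
rows, cols = 2, 3
flat = [0.1, 0.2, 0.3, 1.1, 1.2, 1.3]

# Same reshaping as in eval(): one list of `cols` logits per evaluated token.
logits = [
    [flat[i * cols + j] for j in range(cols)]
    for i in range(rows)
]
assert logits == [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3]]
```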
@@ -828,6 +843,26 @@ def __setstate__(self, state):
            verbose=state["verbose"],
        )

+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
    @staticmethod
    def token_eos() -> llama_cpp.llama_token:
        """Return the end-of-sequence token."""