@@ -63,6 +63,11 @@ def __init__(
         self.params.embedding = embedding

         self.last_n_tokens_size = last_n_tokens_size
+        self.last_n_tokens_data = deque(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
+            maxlen=self.last_n_tokens_size,
+        )
+        self.tokens_consumed = 0
         self.n_batch = n_batch

         self.n_threads = n_threads or multiprocessing.cpu_count()
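Seeding the deque with `last_n_tokens_size` zero tokens and capping it with `maxlen` turns it into an always-full sliding window: every `extend` silently evicts the oldest entries, which is also why `reset()` below can clear it just by extending with zeros. A standalone sketch of that behavior (the window size of 4 is illustrative, not the library default):

    from collections import deque

    # A maxlen-bounded deque drops its oldest items once full, giving
    # exactly the sliding window the repeat penalty needs.
    window = deque([0] * 4, maxlen=4)
    window.extend([11, 12])      # deque([0, 0, 11, 12])
    window.extend([13, 14, 15])  # deque([12, 13, 14, 15]) -- oldest entries evicted
    print(list(window))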
@@ -115,6 +120,67 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output

+    def reset(self):
+        """Reset the model state."""
+        self.last_n_tokens_data.extend(
+            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
+        )
+        self.tokens_consumed = 0
+
+    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+        """Evaluate a list of tokens.
+
+        Args:
+            tokens: The list of tokens to evaluate.
+        """
+        assert self.ctx is not None
+        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        for i in range(0, len(tokens), self.n_batch):
+            batch = tokens[i : min(len(tokens), i + self.n_batch)]
+            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            return_code = llama_cpp.llama_eval(
+                ctx=self.ctx,
+                tokens=(llama_cpp.llama_token * len(batch))(*batch),
+                n_tokens=llama_cpp.c_int(len(batch)),
+                n_past=llama_cpp.c_int(n_past),
+                n_threads=llama_cpp.c_int(self.n_threads),
+            )
+            if int(return_code) != 0:
+                raise RuntimeError(f"llama_eval returned {return_code}")
+            self.last_n_tokens_data.extend(batch)
+            self.tokens_consumed += len(batch)
+
+    def sample(
+        self,
+        top_k: int,
+        top_p: float,
+        temp: float,
+        repeat_penalty: float,
+    ):
+        """Sample a token from the model.
+
+        Args:
+            top_k: The top-k sampling parameter.
+            top_p: The top-p sampling parameter.
+            temp: The temperature parameter.
+            repeat_penalty: The repeat penalty parameter.
+
+        Returns:
+            The sampled token.
+        """
+        assert self.ctx is not None
+        return llama_cpp.llama_sample_top_p_top_k(
+            ctx=self.ctx,
+            last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
+                *self.last_n_tokens_data
+            ),
+            last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
+            top_k=llama_cpp.c_int(top_k),
+            top_p=llama_cpp.c_float(top_p),
+            temp=llama_cpp.c_float(temp),
+            repeat_penalty=llama_cpp.c_float(repeat_penalty),
+        )
+
     def generate(
         self,
         tokens: Sequence[llama_cpp.llama_token],
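Extracting `reset`, `eval`, and `sample` out of `generate` exposes a low-level loop that callers can drive themselves. A sketch of how the three methods are meant to compose; the model path is copied from the docstring example, and the sampling values and token count are illustrative assumptions:

    llama = Llama("models/ggml-7b.bin")
    tokens = llama.tokenize(b"Q: Name the planets in the solar system. A:")

    llama.reset()       # flush the last_n_tokens window, tokens_consumed = 0
    llama.eval(tokens)  # run the prompt through llama_eval in n_batch-sized chunks

    sampled = []
    for _ in range(8):  # fixed count for brevity; no stop-token handling here
        token = llama.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
        sampled.append(token)
        llama.eval([token])  # feed the new token back so tokens_consumed stays current
    print(llama.detokenize(sampled))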
@@ -125,7 +191,7 @@ def generate(
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
     ]:
-        """Generate tokens.
+        """Create a generator of tokens from a prompt.

         Examples:
             >>> llama = Llama("models/ggml-7b.bin")
@@ -149,37 +215,14 @@ def generate(
             top_p = 0.0
             top_k = 1
         assert self.ctx is not None
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
-        n_tokens = 0
-        last_n_tokens = deque(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
-            maxlen=self.last_n_tokens_size,
-        )
+        self.reset()
         while True:
-            for i in range(0, len(tokens), self.n_batch):
-                batch = tokens[i : min(len(tokens), i + self.n_batch)]
-                n_past = min(n_ctx - len(batch), n_tokens)
-                return_code = llama_cpp.llama_eval(
-                    ctx=self.ctx,
-                    tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                    n_tokens=llama_cpp.c_int(len(batch)),
-                    n_past=llama_cpp.c_int(n_past),
-                    n_threads=llama_cpp.c_int(self.n_threads),
-                )
-                if int(return_code) != 0:
-                    raise RuntimeError(f"llama_eval returned {return_code}")
-                last_n_tokens.extend(batch)
-                n_tokens += len(batch)
-            token = llama_cpp.llama_sample_top_p_top_k(
-                ctx=self.ctx,
-                last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
-                    *last_n_tokens
-                ),
-                last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
-                top_k=llama_cpp.c_int(top_k),
-                top_p=llama_cpp.c_float(top_p),
-                temp=llama_cpp.c_float(temp),
-                repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            self.eval(tokens)
+            token = self.sample(
+                top_k=top_k,
+                top_p=top_p,
+                temp=temp,
+                repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
             tokens = [token]
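With its body reduced to `eval` plus `sample`, `generate` remains a plain Python generator, and the `tokens_or_none = yield token` line still lets callers push replacement tokens in via `send()`. A minimal consumption sketch, reusing the docstring's model path; the sampling values and the 16-token cap are illustrative:

    import itertools

    llama = Llama("models/ggml-7b.bin")
    prompt = llama.tokenize(b"Hello, my name is")
    generator = llama.generate(prompt, top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)

    # Take a fixed number of tokens; a real caller would also stop on the EOS token.
    sampled = list(itertools.islice(generator, 16))
    print(llama.detokenize(sampled).decode("utf-8", errors="ignore"))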
@@ -197,7 +240,8 @@ def create_embedding(self, input: str) -> Embedding:
        """
        assert self.ctx is not None
        tokens = self.tokenize(input.encode("utf-8"))
-        next(self.generate(tokens, top_k=0, top_p=0.0, temp=1.0, repeat_penalty=1.0))
+        self.reset()
+        self.eval(tokens)
        n_tokens = len(tokens)
        embedding = llama_cpp.llama_get_embeddings(self.ctx)[
            : llama_cpp.llama_n_embd(self.ctx)
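Previously the prompt was evaluated by pulling one token from `generate` and discarding the sampled result; calling `reset()` and `eval(tokens)` states that intent directly and skips the wasted sampling pass. The embedding vector is then read straight from `llama_get_embeddings`, sliced to `llama_n_embd(ctx)` floats. Caller-side usage is unchanged, roughly as below; the model path is illustrative, and `embedding=True` mirrors the `embedding` parameter stored in `__init__` above:

    # Sketch only: the embedding flag must be set at load time for
    # llama_get_embeddings to return meaningful data.
    llama = Llama("models/ggml-7b.bin", embedding=True)
    result = llama.create_embedding("Hello, world!")  # returns the Embedding type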