@@ -84,16 +84,9 @@ def __init__(
         self.params.embedding = embedding
 
         self.last_n_tokens_size = last_n_tokens_size
-        self.last_n_tokens_data = deque(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size,
-            maxlen=self.last_n_tokens_size,
-        )
-        self.tokens_consumed = 0
-        self.tokens: List[llama_cpp.llama_token] = []
         self.n_batch = min(n_ctx, n_batch)
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
+        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
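The replacement state is just two bounded deques tied to the context window: once more than `n_ctx` entries have been evaluated, the oldest token (and its logits row) is discarded automatically. A standalone sketch of that behavior with toy values, not taken from the patch itself:

```python
from collections import deque

n_ctx = 4  # toy context size, not a real model value
eval_tokens: deque[int] = deque(maxlen=n_ctx)

for tok in [11, 22, 33, 44, 55]:
    eval_tokens.append(tok)

# maxlen=4, so the first token was dropped automatically.
print(list(eval_tokens))  # [22, 33, 44, 55]
print(len(eval_tokens))   # 4
```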
@@ -181,14 +174,8 @@ def set_cache(self, cache: Optional[LlamaCache]):
 
     def reset(self):
         """Reset the model state."""
-        self.last_n_tokens_data.extend(
-            [llama_cpp.llama_token(0)] * self.last_n_tokens_size
-        )
-        self.tokens_consumed = 0
-        self.tokens.clear()
-        self.n_tokens = 0
-        self.n_past = 0
-        self.all_logits.clear()
+        self.eval_tokens.clear()
+        self.eval_logits.clear()
 
     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -200,32 +187,25 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
-            self.n_tokens = len(batch)
+            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(self.n_tokens),
-                n_past=llama_cpp.c_int(self.n_past),
+                n_tokens=llama_cpp.c_int(n_tokens),
+                n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
-            self.tokens.extend(batch)
-            self.last_n_tokens_data.extend(batch)
-            self.tokens_consumed += len(batch)
+            self.eval_tokens.extend(batch)
             if self.params.logits_all:
-                self.all_logits.extend(self._logits())
-
-    def _logits(self) -> List[List[float]]:
-        """Return the logits from the last call to llama_eval."""
-        assert self.ctx is not None
-        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-        cols = int(n_vocab)
-        rows = self.n_tokens if self.params.logits_all else 1
-        logits_view = llama_cpp.llama_get_logits(self.ctx)
-        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
-        return logits
+                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+                cols = int(n_vocab)
+                rows = n_tokens
+                logits_view = llama_cpp.llama_get_logits(self.ctx)
+                logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+                self.eval_logits.extend(logits)
 
     def sample(
         self,
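The logits buffer returned by `llama_get_logits` is flat and row-major, `n_tokens * n_vocab` floats, and the nested comprehension slices it into one row per evaluated token before appending to `eval_logits`. A minimal sketch of that indexing with a plain Python list standing in for the C buffer (sizes are made up):

```python
# Pretend buffer: rows * cols floats laid out row-major, like llama_get_logits.
rows, cols = 3, 5  # e.g. n_tokens=3, n_vocab=5 (toy numbers)
logits_view = [float(i) for i in range(rows * cols)]

# Same indexing as the patch: element (i, j) lives at offset i * cols + j.
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]

assert logits[1][2] == logits_view[1 * cols + 2]
assert len(logits) == rows and all(len(row) == cols for row in logits)
```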
@@ -246,10 +226,13 @@ def sample(
             The sampled token.
         """
         assert self.ctx is not None
+        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
+            0, self.last_n_tokens_size - len(self.eval_tokens)
+        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
         return llama_cpp.llama_sample_top_p_top_k(
             ctx=self.ctx,
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
-                *self.last_n_tokens_data
+                *last_n_tokens_data
            ),
            last_n_tokens_size=llama_cpp.c_int(self.last_n_tokens_size),
            top_k=llama_cpp.c_int(top_k),
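`sample()` now rebuilds the repeat-penalty window on demand: it takes the last `last_n_tokens_size` entries of `eval_tokens` and left-pads with token 0 until the window is full, so the array handed to `llama_sample_top_p_top_k` always has exactly `last_n_tokens_size` entries. A pure-Python sketch of the padding arithmetic with plain ints and toy values:

```python
from collections import deque

last_n_tokens_size = 6                         # toy window size
eval_tokens = deque([101, 102, 103], maxlen=16)  # only 3 tokens evaluated so far

# Same shape as the patch: pad on the left with 0, then append the tail.
last_n_tokens_data = [0] * max(
    0, last_n_tokens_size - len(eval_tokens)
) + list(eval_tokens)[-last_n_tokens_size:]

print(last_n_tokens_data)  # [0, 0, 0, 101, 102, 103]
assert len(last_n_tokens_data) == last_n_tokens_size
```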
@@ -293,13 +276,13 @@ def generate(
         if (
             reset
             and self._cache
-            and len(self.tokens) > 0
-            and self.tokens == tokens[: len(self.tokens)]
+            and len(self.eval_tokens) > 0
+            and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
             reset = False
-            tokens = tokens[len(self.tokens) :]
+            tokens = tokens[len(self.eval_tokens) :]
         ###
         if reset:
             self.reset()
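The cache-hit branch checks whether everything evaluated so far is a prefix of the new prompt; if so, only the unseen suffix is fed back through `eval`. A standalone sketch of that prefix check, with the deque converted to a list before the comparison and toy token values:

```python
from collections import deque

eval_tokens = deque([1, 2, 3, 4], maxlen=16)  # tokens already evaluated
tokens = [1, 2, 3, 4, 5, 6]                   # new prompt, toy values

# Cache hit: the evaluated tokens are a prefix of the new prompt.
if len(eval_tokens) > 0 and list(eval_tokens) == tokens[: len(eval_tokens)]:
    tokens = tokens[len(eval_tokens):]        # only evaluate the new suffix

print(tokens)  # [5, 6]
```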
@@ -537,7 +520,7 @@ def _create_completion(
            ]
            all_logprobs = [
                [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
+                for row in self.eval_logits
            ]
            for token, token_str, logprobs_token in zip(
                all_tokens, all_token_strs, all_logprobs