@@ -324,7 +324,7 @@ def __init__(
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
-        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore
+        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)

         self.n_tokens = 0
@@ -445,8 +445,12 @@ def eval(self, tokens: Sequence[int]):
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             cols = self._n_vocab
-            offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
-            self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
+            offset = (
+                0 if self.params.logits_all else n_tokens - 1
+            )  # NOTE: Only save the last token logits if logits_all is False
+            self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
+                -1
+            )[:] = llama_cpp.llama_get_logits(self.ctx)[: rows * cols]
             # Update n_tokens
             self.n_tokens += n_tokens

@@ -491,7 +495,7 @@ def _sample(
         candidates_data = self._candidates_data
         candidates_data["id"][:] = self._candidates_data_id  # type: ignore
         candidates_data["logit"][:] = logits
-        candidates_data["p"][:] = self._candidates_data_p # type: ignore
+        candidates_data["p"][:] = self._candidates_data_p  # type: ignore
         candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
@@ -537,7 +541,7 @@ def _sample(
             mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
             llama_cpp.llama_sample_temperature(
                 ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
             return llama_cpp.llama_sample_token_mirostat_v2(
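
Note on the eval hunk above: the change only re-wraps the existing logits-saving lines to fit the line-length limit; the behavior is unchanged. As a minimal standalone sketch of what that slicing does, using plain NumPy with made-up sizes and a random array standing in for the buffer returned by llama_get_logits (none of these values come from the diff itself):

import numpy as np

# Hypothetical sizes, for illustration only.
n_vocab = 8            # cols: one logit per vocabulary entry
n_tokens = 4           # tokens evaluated in this batch
n_past = 0             # plays the role of self.n_tokens before this batch
logits_all = False     # mirrors self.params.logits_all

scores = np.zeros((16, n_vocab), dtype=np.single)  # preallocated score buffer

# The context exposes either n_tokens rows of logits (logits_all) or just the last row.
rows = n_tokens if logits_all else 1
cols = n_vocab
flat_logits = np.random.rand(rows * cols).astype(np.single)  # stand-in for the C buffer

# When logits_all is False, offset = n_tokens - 1, so only the final row is written.
offset = 0 if logits_all else n_tokens - 1
scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[:] = flat_logits[: rows * cols]

With logits_all False the target slice is a single row (scores[3:4, :]); flattened it has n_vocab elements, which matches the rows * cols values copied from the buffer. With logits_all True the slice covers all n_tokens rows.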