File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -324,6 +324,8 @@ def __init__(
324
324
self ._candidates = candidates
325
325
self ._token_nl = Llama .token_nl ()
326
326
self ._token_eos = Llama .token_eos ()
327
+ self ._candidates_data_id = np .arange (self ._n_vocab , dtype = np .intc ) # type: ignore
328
+ self ._candidates_data_p = np .zeros (self ._n_vocab , dtype = np .single )
327
329
328
330
self .n_tokens = 0
329
331
self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
@@ -487,9 +489,9 @@ def _sample(
487
489
nl_logit = logits [self ._token_nl ]
488
490
candidates = self ._candidates
489
491
candidates_data = self ._candidates_data
490
- candidates_data ["id" ][:] = np . arange ( n_vocab , dtype = np . intc ) # type: ignore
492
+ candidates_data ["id" ][:] = self . _candidates_data_id # type: ignore
491
493
candidates_data ["logit" ][:] = logits
492
- candidates_data ["p" ][:] = np . zeros ( n_vocab , dtype = np . single )
494
+ candidates_data ["p" ][:] = self . _candidates_data_p # type: ignore
493
495
candidates .data = candidates_data .ctypes .data_as (llama_cpp .llama_token_data_p )
494
496
candidates .sorted = llama_cpp .c_bool (False )
495
497
candidates .size = llama_cpp .c_size_t (n_vocab )
You can’t perform that action at this time.
0 commit comments