perf: avoid allocating new buffers during sampling · MobinX/llama-cpp-python@11eae75 · GitHub

Commit 11eae75

perf: avoid allocating new buffers during sampling

1 parent 7887376 commit 11eae75

File tree

1 file changed: +4 −2 lines changed


llama_cpp/llama.py

Lines changed: 4 additions & 2 deletions
@@ -324,6 +324,8 @@ def __init__(
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
+        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
+        self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)

         self.n_tokens = 0
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
@@ -487,9 +489,9 @@ def _sample(
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         candidates_data = self._candidates_data
-        candidates_data["id"][:] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["id"][:] = self._candidates_data_id  # type: ignore
         candidates_data["logit"][:] = logits
-        candidates_data["p"][:] = np.zeros(n_vocab, dtype=np.single)
+        candidates_data["p"][:] = self._candidates_data_p  # type: ignore
         candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
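For context, here is a minimal self-contained sketch of the pattern this commit applies: build the constant token-id array and the zero probability array once in the constructor, then copy into preallocated buffers on every sampling call instead of allocating fresh arrays. The `FakeSampler` class and the small `n_vocab` value are illustrative assumptions, not part of llama-cpp-python's actual API.

```python
import numpy as np

class FakeSampler:
    """Illustrative stand-in for the sampler state, not the library's class."""

    def __init__(self, n_vocab: int):
        self.n_vocab = n_vocab
        # Structured array mirroring the (id, logit, p) layout of llama_token_data.
        self._candidates_data = np.zeros(
            n_vocab, dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)]
        )
        # Allocated once and reused on every call: this is the point of the commit.
        self._candidates_data_id = np.arange(n_vocab, dtype=np.intc)
        self._candidates_data_p = np.zeros(n_vocab, dtype=np.single)

    def prepare(self, logits: np.ndarray) -> np.ndarray:
        data = self._candidates_data
        data["id"][:] = self._candidates_data_id   # copy into existing buffer, no new allocation
        data["logit"][:] = logits                  # current logits for this step
        data["p"][:] = self._candidates_data_p     # reset probabilities to zero
        return data

# Usage sketch with a toy vocabulary size.
sampler = FakeSampler(n_vocab=8)
print(sampler.prepare(np.random.randn(8).astype(np.single)))
```

Reusing the same buffers avoids creating two temporary arrays of length `n_vocab` on every sampling step, which adds up in the per-token hot path of generation.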

0 commit comments
