Format · matthoffner/llama-cpp-python@4f2b5d0 · GitHub

Commit 4f2b5d0

Format
1 parent 34c505e commit 4f2b5d0

1 file changed: +9 -5 lines changed

llama_cpp/llama.py

Lines changed: 9 additions & 5 deletions
@@ -324,7 +324,7 @@ def __init__(
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
-        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore
+        self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc)  # type: ignore
         self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single)

         self.n_tokens = 0
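
The buffers touched in this hunk are allocated once in __init__ so that _sample can refill them in place on every step rather than allocating per token. A minimal sketch of that pattern, using the buffer names from the diff (the vocab size here is illustrative):

import numpy as np

n_vocab = 32000  # illustrative; the real value comes from the loaded model

# Preallocated once: the fixed token-id ramp and a zeroed probability buffer
# that _sample copies into the candidates array on each step.
candidates_data_id = np.arange(n_vocab, dtype=np.intc)
candidates_data_p = np.zeros(n_vocab, dtype=np.single)
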
@@ -445,8 +445,12 @@ def eval(self, tokens: Sequence[int]):
         # Save logits
         rows = n_tokens if self.params.logits_all else 1
         cols = self._n_vocab
-        offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
-        self.scores[self.n_tokens + offset: self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
+        offset = (
+            0 if self.params.logits_all else n_tokens - 1
+        )  # NOTE: Only save the last token logits if logits_all is False
+        self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
+            -1
+        )[:] = llama_cpp.llama_get_logits(self.ctx)[: rows * cols]
         # Update n_tokens
         self.n_tokens += n_tokens
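
The reformatted slice keeps the original behavior: when logits_all is False, llama_get_logits exposes only the last token's row, so offset skips to that position and exactly rows * cols values are copied. A rough sketch of the indexing, assuming a preallocated scores array of shape (n_ctx, n_vocab) with names as in the diff (sizes and the stand-in logits are illustrative):

import numpy as np

n_ctx, n_vocab = 8, 5                  # illustrative sizes
scores = np.zeros((n_ctx, n_vocab), dtype=np.single)

n_past = 2                             # tokens already evaluated (self.n_tokens)
n_tokens = 3                           # tokens in this eval batch
logits_all = False

rows = n_tokens if logits_all else 1
offset = 0 if logits_all else n_tokens - 1  # jump to the batch's last token
logits = np.ones(rows * n_vocab, dtype=np.single)  # stand-in for llama_get_logits

# Writes exactly `rows` rows starting at n_past + offset.
scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[:] = logits[: rows * n_vocab]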

@@ -491,7 +495,7 @@ def _sample(
         candidates_data = self._candidates_data
         candidates_data["id"][:] = self._candidates_data_id  # type: ignore
         candidates_data["logit"][:] = logits
-        candidates_data["p"][:] = self._candidates_data_p # type: ignore
+        candidates_data["p"][:] = self._candidates_data_p  # type: ignore
         candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
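
For context, candidates_data is a NumPy structured array whose fields mirror llama.cpp's llama_token_data struct, so the filled array can be handed to the C API through ctypes without a copy. A sketch of that layout, assuming the field types implied by the id/logit/p buffers in this diff:

import numpy as np

# Structured dtype mirroring llama_token_data {id, logit, p}; the exact
# field types are an assumption based on the buffers created in __init__.
candidates_dtype = np.dtype(
    [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
)

n_vocab = 5  # illustrative
candidates_data = np.empty(n_vocab, dtype=candidates_dtype)

# Refilled in place on every sampling step, as in _sample above.
candidates_data["id"][:] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"][:] = np.linspace(-1.0, 1.0, n_vocab, dtype=np.single)
candidates_data["p"][:] = np.zeros(n_vocab, dtype=np.single)
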
@@ -537,7 +541,7 @@ def _sample(
             mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
             llama_cpp.llama_sample_temperature(
                 ctx=self.ctx,
-                candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
             return llama_cpp.llama_sample_token_mirostat_v2(
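
The mirostat_mu = 2.0 * tau initialization in this hunk matches Mirostat v2, which starts its surprise target at twice tau and nudges it after every sampled token. A toy version of that update rule, independent of the llama_cpp bindings (names and values here are illustrative, not the library's API):

import math

tau, eta = 5.0, 0.1
mu = 2.0 * tau  # same initialization as the hunk above

def update_mu(mu: float, token_prob: float) -> float:
    surprise = -math.log2(token_prob)   # observed surprise of the sampled token
    return mu - eta * (surprise - tau)  # steer toward the target surprise tau

mu = update_mu(mu, token_prob=0.03)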
