Fix sampling bug when logits_all=False · Zephyr800/llama-cpp-python@6f0b0b1 · GitHub
Commit 6f0b0b1

Fix sampling bug when logits_all=False

1 parent d9b38e3 commit 6f0b0b1

1 file changed: +4 -4

llama_cpp/llama.py

Lines changed: 4 additions & 4 deletions
@@ -1029,16 +1029,16 @@ def eval(self, tokens: Sequence[int]):
             )
             self._ctx.decode(self._batch)
             # Save tokens
-            self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
+            self.input_ids[n_past : n_past + n_tokens] = batch
             # Save logits
-            rows = n_tokens if self.context_params.logits_all else 1
+            rows = n_tokens
             cols = self._n_vocab
             offset = (
                 0 if self.context_params.logits_all else n_tokens - 1
             )  # NOTE: Only save the last token logits if logits_all is False
-            self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
+            self.scores[n_past + offset : n_past + n_tokens, :].reshape(
                 -1
-            )[:] = self._ctx.get_logits()[: rows * cols]
+            )[:] = self._ctx.get_logits()[offset * cols: rows * cols]
             # Update n_tokens
             self.n_tokens += n_tokens
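Why the old slice was off, in rough terms: eval copies each batch's logits into the persistent scores array. With logits_all=False only the last token's logits are kept (per the NOTE in the code), so the destination slice starts at row n_past + n_tokens - 1; before this commit, though, the source slice always read from the start of the flat buffer. The fix makes the read offset match the write offset (and switches the write indices from self.n_tokens to the loop-local n_past). The toy NumPy snippet below reproduces just that indexing arithmetic; the concrete values, buffer names, and the assumption that get_logits() exposes one row of n_vocab floats per batch token are illustrative, not taken from the library.

import numpy as np

# Toy stand-ins for the buffers in Llama.eval (illustrative values,
# not the library's real attributes).
n_vocab = 4                  # cols: vocabulary size
n_tokens = 3                 # tokens decoded in this batch
n_past = 2                   # tokens already evaluated
logits_all = False

cols = n_vocab
rows = n_tokens
offset = 0 if logits_all else n_tokens - 1  # keep only the last row

# Assumed layout of the flat logits buffer: row i holds the logits
# for batch token i, rows laid out back to back.
flat_logits = np.arange(n_tokens * n_vocab, dtype=np.float32)

scores = np.zeros((8, n_vocab), dtype=np.float32)

# Before the fix the source slice started at 0 (with rows = 1), so the
# first token's logits were written into the last token's row:
#     scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[:] = flat_logits[: 1 * cols]
# After the fix the read offset matches the write offset:
scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[:] = flat_logits[
    offset * cols : rows * cols
]

# The last decoded token's logits now land in the right row of scores.
assert (scores[n_past + n_tokens - 1] == flat_logits[-cols:]).all()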
