fix: Don't store scores internally unless logits_all=True. Reduces memory requirements for large context. · hariag/llama-cpp-python@29afcfd · GitHub
Commit 29afcfd

Browse files
committed
fix: Don't store scores internally unless logits_all=True. Reduces memory requirements for large context. Closes abetlen#1542
1 parent 22cedad commit 29afcfd
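To make the memory saving concrete, here is a rough back-of-envelope sketch. The sizes below are illustrative assumptions (a Llama-style 32k-token vocabulary, a 32768-token context, and llama-cpp-python's default `n_batch` of 512), not values taken from the commit itself:

```python
import numpy as np

# Illustrative sizes (assumptions, not from the commit):
n_ctx = 32768     # a large context window
n_batch = 512     # llama-cpp-python's default batch size
n_vocab = 32000   # Llama-style vocabulary size
itemsize = np.dtype(np.single).itemsize  # 4 bytes per float32

# Before this commit: the scores buffer covered every context position.
before_bytes = n_ctx * n_vocab * itemsize
# After: unless logits_all=True, only n_batch rows are allocated.
after_bytes = n_batch * n_vocab * itemsize

print(f"before: {before_bytes / 2**20:.1f} MiB")  # 4000.0 MiB
print(f"after:  {after_bytes / 2**20:.1f} MiB")   # 62.5 MiB
```

With these assumed sizes, the float32 scores buffer shrinks from ~4 GiB to ~62 MiB, which is why the change matters for large-context models (see abetlen#1542).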

File tree

1 file changed (+9, -7 lines)


llama_cpp/llama.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -451,7 +451,7 @@ def free_lora_adapter():
         self.n_tokens = 0
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
         self.scores: npt.NDArray[np.single] = np.ndarray(
-            (n_ctx, self._n_vocab), dtype=np.single
+            (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
         )

         self._mirostat_mu = ctypes.c_float(
@@ -648,12 +648,14 @@ def eval(self, tokens: Sequence[int]):
             )
             self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
         else:
-            rows = 1
-            cols = self._n_vocab
-            logits = np.ctypeslib.as_array(
-                self._ctx.get_logits(), shape=(rows * cols,)
-            )
-            self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
+            # rows = 1
+            # cols = self._n_vocab
+            # logits = np.ctypeslib.as_array(
+            #     self._ctx.get_logits(), shape=(rows * cols,)
+            # )
+            # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
+            # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
+            pass
         # Update n_tokens
         self.n_tokens += n_tokens

```
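For reference, the row-wise write pattern that survives in the `logits_all=True` branch can be demonstrated with a toy NumPy sketch. The sizes are made up; `scores`, `n_past`, and `n_tokens` only mirror the names used in llama.py:

```python
import numpy as np

# Toy stand-ins for the buffers in Llama.eval() (sizes are assumptions).
n_ctx, n_vocab = 8, 5
scores = np.zeros((n_ctx, n_vocab), dtype=np.single)

# Pretend a batch of 3 tokens was just evaluated at context offset 2,
# producing one flat row-major logits vector per token.
n_past, n_tokens = 2, 3
logits = np.arange(n_tokens * n_vocab, dtype=np.single)

# Same write pattern as the diff: flatten the target rows, copy in place.
scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits

print(scores[n_past : n_past + n_tokens])
```

Because the row slice is C-contiguous, `reshape(-1)` yields a view, so the assignment fills rows `n_past` through `n_past + n_tokens - 1` without copying the buffer. This is the only path that still needs a full `(n_ctx, n_vocab)` buffer, which is why the allocation can shrink to `n_batch` rows whenever `logits_all` is off.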
0 commit comments

Comments
 (0)
0