8000 perf: Don't convert logprobs arrays to lists (#1021) · shawnx11/llama-cpp-python@6b2e0e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6b2e0e0

Browse files
authored
perf: Don't convert logprobs arrays to lists (abetlen#1021)
1 parent 62944df commit 6b2e0e0

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

llama_cpp/llama.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,7 +1552,7 @@ def logit_bias_processor(
15521552
self.detokenize(completion_tokens[:returned_tokens])
15531553
)
15541554
token_offset = len(prompt_tokens) + returned_tokens
1555-
logits = self._scores[token_offset - 1, :].tolist()
1555+
logits = self._scores[token_offset - 1, :]
15561556
current_logprobs = Llama.logits_to_logprobs(logits)
15571557
sorted_logprobs = list(
15581558
sorted(
@@ -1671,7 +1671,7 @@ def logit_bias_processor(
16711671
self.detokenize(completion_tokens[:returned_tokens])
16721672
)
16731673
token_offset = len(prompt_tokens) + returned_tokens - 1
1674-
logits = self._scores[token_offset, :].tolist()
1674+
logits = self._scores[token_offset, :]
16751675
current_logprobs = Llama.logits_to_logprobs(logits)
16761676
sorted_logprobs = list(
16771677
sorted(
@@ -1785,9 +1785,8 @@ def logit_bias_processor(
17851785
self.detokenize([token]).decode("utf-8", errors="ignore")
17861786
for token in all_tokens
17871787
]
1788-
all_logprobs = [
1789-
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
1790-
][token_offset:]
1788+
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
1789+
# TODO: may be able to change this loop to use np.take_along_dim
17911790
for token, token_str, logprobs_token in zip(
17921791
all_tokens, all_token_strs, all_logprobs
17931792
):
@@ -2282,7 +2281,7 @@ def token_nl(self) -> int:
22822281

22832282
@staticmethod
22842283
def logits_to_logprobs(
2285-
logits: Union[List, npt.NDArray[np.single]], axis: int = -1
2284+
logits: Union[npt.NDArray[np.single], List], axis: int = -1
22862285
) -> npt.NDArray[np.single]:
22872286
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
22882287
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
@@ -2293,7 +2292,7 @@ def logits_to_logprobs(
22932292
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
22942293
exp = np.exp(subtract_maxs)
22952294
# Suppress warnings about log of zero
2296-
with np.errstate(divide='ignore'):
2295+
with np.errstate(divide="ignore"):
22972296
summed = np.sum(exp, axis=axis, keepdims=True)
22982297
out = np.log(summed)
22992298
return subtract_maxs - out

0 commit comments

Comments
 (0)
0