8000 Replace logits_to_logprobs implementation with numpy equivalent to ll… · docmeth02/llama-cpp-python@ef22e47 · GitHub
[go: up one dir, main page]

Skip to content

Commit ef22e47

Browse files
authored
Replace logits_to_logprobs implementation with numpy equivalent to llama.cpp (abetlen#991)
See abetlen#990. This change makes the logits_to_logprobs function equivalent to the version in the llama.cpp repository. It uses numpy so it's much faster than the previous version.
1 parent ac35f68 commit ef22e47

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

llama_cpp/llama.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2280,10 +2280,14 @@ def token_nl(self) -> int:
22802280
return self._model.token_nl()
22812281

22822282
@staticmethod
2283-
def logits_to_logprobs(logits: List[float]) -> List[float]:
2284-
exps = [math.exp(float(x)) for x in logits]
2285-
sum_exps = sum(exps)
2286-
return [math.log(x / sum_exps) for x in exps]
2283+
def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
2284+
maximum = np.max(logits)
2285+
tmp = np.subtract(logits, maximum, dtype=np.single)
2286+
np.exp(tmp, out=tmp)
2287+
normalizer = 1.0 / np.sum(tmp)
2288+
np.multiply(normalizer, tmp, out=tmp)
2289+
np.log(tmp, out=tmp)
2290+
return tmp
22872291

22882292
@staticmethod
22892293
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):

0 commit comments

Comments
 (0)
0