Fix logits_to_logprobs for 2-D and 3-D logits (#1002) · qeleb/llama-cpp-python@5a89446 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a89446

Browse files
authored
Fix logits_to_logprobs for 2-D and 3-D logits (abetlen#1002)
* Fix logits_to_logprobs for 2-D and 3-D logits * Set dtype to single * Test size
1 parent 534b1ea commit 5a89446

File tree

3 files changed

+44
-10
lines changed

3 files changed

+44
-10
lines changed

llama_cpp/llama.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def __init__(
771771
**kwargs, # type: ignore
772772
):
773773
"""Load a llama.cpp model from `model_path`.
774-
774+
775775
Examples:
776776
Basic usage
777777
@@ -2280,14 +2280,22 @@ def token_nl(self) -> int:
22802280
return self._model.token_nl()
22812281

22822282
@staticmethod
2283-
def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
2284-
maximum = np.max(logits)
2285-
tmp = np.subtract(logits, maximum, dtype=np.single)
2286-
np.exp(tmp, out=tmp)
2287-
normalizer = 1.0 / np.sum(tmp)
2288-
np.multiply(normalizer, tmp, out=tmp)
2289-
np.log(tmp, out=tmp)
2290-
return tmp
2283+
def logits_to_logprobs(
2284+
logits: Union[List, npt.NDArray[np.single]], axis: int = -1
2285+
) -> npt.NDArray[np.single]:
2286+
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
2287+
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
2288+
if logits_maxs.ndim > 0:
2289+
logits_maxs[~np.isfinite(logits_maxs)] = 0
2290+
elif not np.isfinite(logits_maxs):
2291+
logits_maxs = 0
2292+
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
2293+
exp = np.exp(subtract_maxs)
2294+
# Suppress warnings about log of zero
2295+
with np.errstate(divide='ignore'):
2296+
summed = np.sum(exp, axis=axis, keepdims=True)
2297+
out = np.log(summed)
2298+
return subtract_maxs - out
22912299

22922300
@staticmethod
22932301
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ server = [
3333
"fastapi>=0.100.0",
3434
"pydantic-settings>=2.0.1",
3535
"sse-starlette>=1.6.1",
36-
"starlette-context>=0.3.6,<0.4"
36+
"starlette-context>=0.3.6,<0.4",
3737
]
3838
test = [
3939
"pytest>=7.4.0",
4040
"httpx>=0.24.1",
41+
"scipy>=1.10",
4142
]
4243
dev = [
4344
"black>=23.3.0",

tests/test_llama.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import ctypes
22

3+
import numpy as np
34
import pytest
5+
from scipy.special import log_softmax
46

57
import llama_cpp
68

@@ -264,5 +266,28 @@ def test_llama_server():
264266
}
265267

266268

269+
@pytest.mark.parametrize(
    "size_and_axis",
    [
        ((32_000,), -1),  # last token's next-token logits
        ((10, 32_000), -1),  # many tokens' next-token logits, or batch of last tokens
        ((4, 10, 32_000), -1),  # batch of texts
    ],
)
@pytest.mark.parametrize("convert_to_list", [True, False])
def test_logits_to_logprobs(size_and_axis, convert_to_list: bool, atol: float = 1e-7):
    """Check logits_to_logprobs against scipy's log_softmax for 1-D/2-D/3-D input."""
    shape, ax = size_and_axis
    raw = (-np.random.uniform(low=0, high=60, size=shape)).astype(np.single)
    if convert_to_list:
        # Currently, logits are converted from arrays to lists. This may change soon
        sample = raw.tolist()
    else:
        sample = raw
    computed = llama_cpp.Llama.logits_to_logprobs(sample, axis=ax)
    expected = log_softmax(sample, axis=ax)
    assert computed.dtype == np.single
    assert computed.shape == shape
    assert np.allclose(computed, expected, atol=atol)
290+
291+
267292
def test_llama_cpp_version():
268293
assert llama_cpp.__version__

0 commit comments

Comments (0)