8000 Fix logprob calculation. Fixes #134 · Stonelinks/llama-cpp-python@b6747f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit b6747f7

Browse files
committed
Fix logprob calculation. Fixes abetlen#134
1 parent c088a2b commit b6747f7

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

llama_cpp/llama.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def _create_completion(
638638
for token in all_tokens
639639
]
640640
all_logprobs = [
641-
[Llama.logit_to_logprob(logit) for logit in row]
641+
Llama._logits_to_logprobs(row)
642642
for row in self.eval_logits
643643
]
644644
for token, token_str, logprobs_token in zip(
@@ -980,5 +980,7 @@ def token_bos() -> llama_cpp.llama_token:
980980
return llama_cpp.llama_token_bos()
981981

982982
@staticmethod
983-
def logit_to_logprob(x: float) -> float:
984-
return math.log(1.0 + math.exp(x))
983+
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
984+
exps = [math.exp(float(x)) for x in logits]
985+
sum_exps = sum(exps)
986+
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]

0 commit comments

Comments
 (0)
0