Fix logprobs for completions and implement for streaming logprobs. · HardSoft2023/llama-cpp-python@17d4271 · GitHub

Commit 17d4271

Fix logprobs for completions and implement for streaming logprobs.
1 parent a634a24 commit 17d4271

File tree: 1 file changed, +103 −22 lines changed


llama_cpp/llama.py

Lines changed: 103 additions & 22 deletions
@@ -710,22 +710,56 @@ def _create_completion(
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break
 
-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
@@ -738,7 +772,7 @@ def _create_completion(
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
-                            "logprobs": None,
+                            "logprobs": logprobs_or_none,
                             "finish_reason": None,
                         }
                     ],
@@ -766,13 +800,48 @@ def _create_completion(
             else:
                 end = len(all_text)
 
-            offset = 0
+            token_end_position = 0
             for token in remaining_tokens:
-                offset += len(self.detokenize([token]))
-                if offset >= end:
+                token_end_position += len(self.detokenize([token]))
+
+                logprobs_or_none: Optional[CompletionLogprobs] = None
+                if logprobs is not None:
+                    token_str = self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    )
+                    text_offset = len(prompt) + len(
+                        self.detokenize(completion_tokens[:returned_tokens])
+                    )
+                    token_offset = len(prompt_tokens) + returned_tokens - 1
+                    logits = self.eval_logits[token_offset]
+                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    sorted_logprobs = list(
+                        sorted(
+                            zip(current_logprobs, range(len(current_logprobs))),
+                            reverse=True,
+                        )
+                    )
+                    top_logprob = {
+                        self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            "utf-8", errors="ignore"
+                        ): logprob
+                        for logprob, i in sorted_logprobs[:logprobs]
+                    }
+                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    logprobs_or_none = {
+                        "tokens": [
+                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                        ],
+                        "text_offset": [text_offset],
+                        "token_logprobs": [sorted_logprobs[int(token)][0]],
+                        "top_logprobs": [top_logprob],
+                    }
+
+                if token_end_position >= end:
                     last_text = self.detokenize([token])
-                    if offset == end - 1:
+                    if token_end_position == end - 1:
                         break
+                returned_tokens += 1
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -781,10 +850,10 @@ def _create_completion(
                 "choices": [
                     {
                         "text": last_text[
-                            : len(last_text) - (offset - end)
+                            : len(last_text) - (token_end_position - end)
                         ].decode("utf-8", errors="ignore"),
                         "index": 0,
-                        "logprobs": None,
+                        "logprobs": logprobs_or_none,
                         "finish_reason": finish_reason,
                     }
                 ],
@@ -802,7 +871,7 @@ def _create_completion(
                             "utf-8", errors="ignore"
                         ),
                         "index": 0,
-                        "logprobs": None,
+                        "logprobs": logprobs_or_none,
                         "finish_reason": finish_reason
                         if returned_tokens == len(completion_tokens)
                         else None,
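With the hunks above, each streamed chunk's `logprobs` field carries one-element lists in the OpenAI-style `CompletionLogprobs` shape built in the loop. An illustrative value for a single token, with made-up numbers and `logprobs=2` assumed:

    logprobs_or_none = {
        "tokens": [" world"],       # detokenized text of this one token
        "text_offset": [11],        # character offset into prompt + completion
        "token_logprobs": [-0.42],  # log-probability entry for this token
        "top_logprobs": [{" world": -0.42, " there": -1.87}],  # top-2 candidates
    }

The hunks below apply the matching fixes to the non-streaming path.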
@@ -821,21 +890,27 @@ def _create_completion(
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens
 
-            all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
@@ -848,14 +923,20 @@ def _create_completion(
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idiosyncrasy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,

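Both paths lean on `Llama.logits_to_logprobs` to turn a row of `eval_logits` into log-probabilities. A minimal sketch of the equivalent computation (a numerically stable log-softmax; an illustration, not necessarily the library's exact implementation):

    import math
    from typing import List

    def logits_to_logprobs(logits: List[float]) -> List[float]:
        # Subtract the max logit before exponentiating so exp() cannot overflow,
        # then apply log-softmax: logprob_i = logit_i - logsumexp(logits).
        max_logit = max(logits)
        shifted = [x - max_logit for x in logits]
        log_sum_exp = math.log(sum(math.exp(x) for x in shifted))
        return [x - log_sum_exp for x in shifted]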
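After this commit, a streaming completion can request logprobs just like a non-streaming one; previously the streamed chunks always reported `"logprobs": None`. A usage sketch against the high-level API (the model path and prompt are placeholders, not part of the commit):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

    # Ask for the top-2 log-probabilities per token while streaming.
    for chunk in llm("Q: Name the planets. A: ", max_tokens=16,
                     logprobs=2, stream=True):
        choice = chunk["choices"][0]
        if choice["logprobs"] is not None:
            print(choice["text"], choice["logprobs"]["token_logprobs"])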
0 commit comments
