Refactor _create_completion · qeleb/llama-cpp-python@40f2293 · GitHub

Commit 40f2293

Refactor _create_completion
1 parent 7ae9a3e commit 40f2293

File tree

1 file changed (+105, -146)

llama_cpp/llama.py

Lines changed: 105 additions & 146 deletions
@@ -780,6 +780,79 @@ def logit_bias_processor(
 
         if seed is not None:
             self._ctx.set_rng_seed(seed)
+
+        def _completion_stream_response(text: str, logprobs_or_none: Optional[CompletionLogprobs] = None, finish_reason: Optional[Literal["stop", "length"]] = None) -> CreateCompletionStreamResponse:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": 0,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+
+        def _completion_response(text: str, finish_reason: Literal["stop", "length"], logprobs_or_none: Optional[CompletionLogprobs] = None) -> CreateCompletionResponse:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": 0,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": len(prompt_tokens),
+                    "completion_tokens": len(completion_tokens),
+                    "total_tokens": len(prompt_tokens) + len(completion_tokens),
+                },
+            }
+
+        def _logprobs_or_none(all_tokens: List[int], all_token_strs: List[str], all_logprobs: List[List[float]], text_offset: int) -> CompletionLogprobs:
+            tokens: List[str] = []
+            text_offsets: List[int] = []
+            token_logprobs: List[Optional[float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            for token, token_str, token_logprob in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                if token == self.token_bos():
+                    continue
+
+                text_offset += len(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(token_logprob, range(len(token_logprob))), reverse=True
+                    )
+                )
+                top_logprob = {
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: token_logprob[int(token)]})
+
+                tokens.append(token_str)
+                text_offsets.append(text_offset)
+                token_logprobs.append(token_logprob[int(token)])
+                top_logprobs.append(top_logprob)
+
+            return {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }
 
         finish_reason = "length"
         multibyte_fix = 0
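The three helpers above are nested inside _create_completion and close over its locals (completion_id, created, model_name, prompt_tokens, completion_tokens, logprobs) rather than taking them as parameters. A minimal, self-contained sketch of the OpenAI-style chunk that _completion_stream_response assembles, illustrative only and not part of this commit; the function name and explicit parameters below are hypothetical stand-ins for those closed-over locals:

import time
import uuid
from typing import Any, Dict, Optional


def make_stream_chunk(
    text: str,
    completion_id: str,
    created: int,
    model_name: str,
    finish_reason: Optional[str] = None,
) -> Dict[str, Any]:
    # Same "text_completion" chunk shape as _completion_stream_response,
    # with the values it would close over passed in explicitly.
    return {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "text": text,
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,  # None for all but the final chunk
            }
        ],
    }


if __name__ == "__main__":
    chunk = make_stream_chunk(
        "Hello",
        completion_id=f"cmpl-{uuid.uuid4()}",
        created=int(time.time()),
        model_name="example-model",
    )
    print(chunk["choices"][0]["text"])  # -> Hello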
@@ -868,10 +941,10 @@ def logit_bias_processor(
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
                         logits = self._scores[token_offset - 1, :].tolist()
-                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        token_logprob = Llama.logits_to_logprobs(logits)
                         sorted_logprobs = list(
                             sorted(
-                                zip(current_logprobs, range(len(current_logprobs))),
+                                zip(token_logprob, range(len(token_logprob))),
                                 reverse=True,
                             )
                         )
@@ -881,34 +954,22 @@ def logit_bias_processor(
                             ): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
                         }
-                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        top_logprob.update({token_str: token_logprob[int(token)]})
                         logprobs_or_none = {
                             "tokens": [
                                 self.detokenize([token]).decode(
                                     "utf-8", errors="ignore"
                                 )
                             ],
                             "text_offset": [text_offset],
-                            "token_logprobs": [current_logprobs[int(token)]],
+                            "token_logprobs": [token_logprob[int(token)]],
                             "top_logprobs": [top_logprob],
                         }
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize([token]).decode(
-                                        "utf-8", errors="ignore"
-                                    ),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield _completion_stream_response(
+                            self.detokenize([token]).decode("utf-8", errors="ignore"),
+                            logprobs_or_none,
+                        )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
@@ -933,20 +994,7 @@ def logit_bias_processor(
                         remaining_tokens = remaining_tokens[i:]
                         returned_tokens += i
 
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield _completion_stream_response(text=ts)
 
         if len(completion_tokens) >= max_tokens:
             text = self.detokenize(completion_tokens)
@@ -986,25 +1034,22 @@ def logit_bias_processor(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self._scores[token_offset, :].tolist()
-                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    token_logprob = Llama.logits_to_logprobs(self._scores[token_offset, :].tolist())
                     sorted_logprobs = list(
                         sorted(
-                            zip(current_logprobs, range(len(current_logprobs))),
+                            zip(token_logprob, range(len(token_logprob))),
                             reverse=True,
                         )
                     )
                     top_logprob = {
                         self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                         for logprob, i in sorted_logprobs[:logprobs]
                     }
-                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    top_logprob.update({token_str: token_logprob[int(token)]})
                     logprobs_or_none = {
-                        "tokens": [
-                            self.detokenize([token]).decode("utf-8", errors="ignore")
-                        ],
+                        "tokens": [token_str],
                         "text_offset": [text_offset],
-                        "token_logprobs": [current_logprobs[int(token)]],
+                        "token_logprobs": [token_logprob[int(token)]],
                         "top_logprobs": [top_logprob],
                     }
 
@@ -1013,54 +1058,17 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield _completion_stream_response(
+                        text=last_text[: len(last_text) - (token_end_position - end)].decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                yield _completion_stream_response(
+                    text=self.detokenize([token]).decode("utf-8", errors="ignore"), logprobs_or_none=logprobs_or_none
+                )
+            yield _completion_stream_response(
+                text=self.detokenize(completion_tokens[returned_tokens:]).decode("utf-8", errors="ignore"), finish_reason=finish_reason
+            )
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1076,85 +1084,36 @@ def logit_bias_processor(
         text_str = text.decode("utf-8", errors="ignore")
 
         if echo:
+            assert isinstance(prompt, str)
             text_str = prompt + text_str
 
         if suffix is not None:
             text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            # Remove leading BOS token
+            all_tokens = prompt_tokens[1:] + completion_tokens if echo else completion_tokens
             text_offset = 0 if echo else len(prompt)
             token_offset = 0 if echo else len(prompt_tokens[1:])
-            text_offsets: List[int] = []
-            token_logprobs: List[Optional[float]] = []
-            tokens: List[str] = []
-            top_logprobs: List[Optional[Dict[str, float]]] = []
-
-            if echo:
-                # Remove leading BOS token
-                all_tokens = prompt_tokens[1:] + completion_tokens
-            else:
-                all_tokens = completion_tokens
-
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
             all_logprobs = [
                 Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
-            for token, token_str, logprobs_token in zip(
-                all_tokens, all_token_strs, all_logprobs
-            ):
-                if token == self.token_bos():
-                    continue
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(logprobs_token[int(token)])
-                top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
+            logprobs_or_none = _logprobs_or_none(
+                all_tokens, all_token_strs, all_logprobs, text_offset
+            )
             # Weird idosincracy of the OpenAI API where
             # token_logprobs and top_logprobs are null for
             # the first token.
             if echo and len(all_tokens) > 0:
-                token_logprobs[0] = None
-                top_logprobs[0] = None
-            logprobs_or_none = {
-                "tokens": tokens,
-                "text_offset": text_offsets,
-                "token_logprobs": token_logprobs,
-                "top_logprobs": top_logprobs,
-            }
+                logprobs_or_none["token_logprobs"][0] = None
+                logprobs_or_none["top_logprobs"][0] = None
 
-        yield {
-            "id": completion_id,
-            "object": "text_completion",
-            "created": created,
-            "model": model_name,
-            "choices": [
-                {
-                    "text": text_str,
-                    "index": 0,
-                    "logprobs": logprobs_or_none,
-                    "finish_reason": finish_reason,
-                }
-            ],
-            "usage": {
-                "prompt_tokens": len(prompt_tokens),
-                "completion_tokens": len(completion_tokens),
-                "total_tokens": len(prompt_tokens) + len(completion_tokens),
-            },
-        }
+        yield _completion_response(text_str, finish_reason, logprobs_or_none)
 
     def create_completion(
         self,
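For orientation, a rough sketch of how a caller consumes the streaming path these helpers feed, not part of this commit; the model path and prompt below are placeholders:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path to a local GGUF model

parts = []
for chunk in llm.create_completion("Q: Name a colour. A:", max_tokens=16, stream=True):
    choice = chunk["choices"][0]
    parts.append(choice["text"])
    # finish_reason stays None until the final chunk ("stop" or "length")
    if choice["finish_reason"] is not None:
        break

print("".join(parts))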

0 commit comments
