fix: apply https://github.com/abetlen/llama-cpp-python/pull/1552 · jeffmaury/llama-cpp-python@3670f1e · GitHub

Commit 3670f1e

fix: apply abetlen#1552
Signed-off-by: Jeff MAURY <jmaury@redhat.com>
1 parent 7ecdd94 commit 3670f1e

File tree

5 files changed: +481 -345 lines changed
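In short, this commit (only the llama_cpp/llama.py diff is shown below) adds a stream_options parameter to the completion APIs (as in OpenAI's streaming API), a _create_chunk helper for building streaming chunks, and, when stream_options asks for include_usage, a trailing chunk that carries token usage counts. A minimal caller-side sketch, assuming a local GGUF model path (placeholder, not part of the commit) and a plain dict for stream_options:

# Sketch only: exercises the stream_options plumbing added by this commit.
# "model.gguf" is a placeholder path, not part of the commit.
from llama_cpp import Llama

llm = Llama(model_path="model.gguf")

for chunk in llm.create_completion(
    "Q: Name the planets in the solar system? A: ",
    max_tokens=32,
    stream=True,
    stream_options={"include_usage": True},  # request a final usage chunk
):
    print(chunk["choices"][0]["text"], end="", flush=True)
    if chunk.get("usage"):  # populated only on the trailing usage chunk
        print("\n", chunk["usage"])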

llama_cpp/llama.py

Lines changed: 142 additions & 78 deletions
@@ -1115,6 +1115,56 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output

+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Union[Optional[CompletionLogprobs], None],
+        include_usage: bool,
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Union[Dict[str, Any], None] = None,
+    ) -> CreateChatCompletionStreamResponse:
+        """
+        Create chunks for streaming API, depending on whether usage is requested or
+        not they need (or don't need) an additional field
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
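The _create_chunk helper above centralizes the chunk dictionaries that were previously built inline at each yield site; its two branches differ only in whether the chunk carries a "usage" key. An illustrative sketch of the two shapes (placeholder values, not part of the commit):

# Illustrative only: the two chunk shapes produced by _create_chunk above.
base = {
    "id": "cmpl-xxx",            # placeholder id
    "object": "text_completion",
    "created": 1710000000,       # placeholder timestamp
    "model": "model.gguf",       # placeholder model name
    "choices": [
        {"text": "Hello", "index": 0, "logprobs": None, "finish_reason": None}
    ],
}
chunk_without_usage = dict(base)            # include_usage=False branch
chunk_with_usage = dict(base, usage=None)   # include_usage=True branch; "usage" stays
                                            # None until the final accounting chunk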
@@ -1132,6 +1182,7 @@ def _create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1362,6 +1413,11 @@ def logit_bias_processor(
                 break

             if stream:
+                if stream_options and "include_usage" in stream_options:
+                    include_usage = stream_options["include_usage"]
+                else:
+                    include_usage = False
+
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(
                     remaining_tokens,
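The guard above treats stream_options as an optional mapping, so usage reporting stays off unless the caller explicitly opts in. A quick sketch of the fallback behaviour (illustrative, not from the commit):

# Illustrative only: how the include_usage guard above resolves different inputs.
def resolve_include_usage(stream_options):
    if stream_options and "include_usage" in stream_options:
        return stream_options["include_usage"]
    return False

assert resolve_include_usage(None) is False
assert resolve_include_usage({}) is False
assert resolve_include_usage({"include_usage": False}) is False
assert resolve_include_usage({"include_usage": True}) is True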
@@ -1441,24 +1497,23 @@ def logit_bias_processor(
                             "top_logprobs": [top_logprob],
                         }
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize(
-                                        [token],
-                                        prev_tokens=prompt_tokens
-                                        + completion_tokens[:returned_tokens],
-                                    ).decode("utf-8", errors="ignore"),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = (
+                            self.detokenize(
+                                [token],
+                                prev_tokens=prompt_tokens
+                                + completion_tokens[:returned_tokens],
+                            ).decode("utf-8", errors="ignore"),
+                        )
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            finish_reason=None,
+                            index=0,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                        )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
@@ -1487,20 +1542,16 @@ def logit_bias_processor(
                         remaining_tokens = remaining_tokens[i:]
                         returned_tokens += i

-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield self._create_chunk(
+                            index=0,
+                            finish_reason=None,
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=ts,
+                            logprobs_or_none=None,
+                            include_usage=include_usage,
+                        )

         if len(completion_tokens) >= max_tokens:
             text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1579,54 +1630,60 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore")
+
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": "",
-                            "index": 0,
-                            "logprobs": None,
-                            "finish_reason": finish_reason,
-                        }
-                    ],
-                }
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=text,
+                    logprobs_or_none=logprobs_or_none,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                )
+                yield self._create_chunk(
+                    completion_id= completion_id,
+                    created= created,
+                    model_name=model_name,
+                    text="",
+                    index=0,
+                    logprobs_or_none= None,
+                    include_usage=include_usage,
+                    usage=None,
+                    finish_reason=finish_reason)
+
+                if include_usage:
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text="",
+                        logprobs_or_none=None,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                        usage={
+                            "prompt_tokens": len(prompt_tokens),
+                            "completion_tokens": returned_tokens,
+                            "total_tokens": len(prompt_tokens) + returned_tokens,
+                        },
+                    )
         if self.cache:
            if self.verbose:
                print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1735,6 +1792,7 @@ def logit_bias_processor(
             },
         }

+
     def create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1752,6 +1810,7 @@ def create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1815,6 +1874,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1849,6 +1909,7 @@ def __call__(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1912,6 +1973,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1937,6 +1999,7 @@ def create_chat_completion(
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -2010,6 +2073,7 @@ def create_chat_completion(
             logprobs=logprobs,
             top_logprobs=top_logprobs,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             seed=seed,
             response_format=response_format,
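The remaining hunks thread stream_options through create_completion, __call__, and create_chat_completion. Whether the usage field survives conversion to chat-style chunks depends on the handler changes in the other four files of this commit, which are not shown here; the sketch below therefore only shows how a caller would pass the option and defensively read usage if present (assumes an existing Llama instance `llm`):

# Sketch only: passing stream_options at the chat-completion level.
# Whether chat chunks carry "usage" depends on the chat-format handlers
# changed elsewhere in this commit.
stream = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    delta = chunk["choices"][0].get("delta", {})
    print(delta.get("content") or "", end="", flush=True)
    if chunk.get("usage"):
        print("\n", chunk["usage"])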
