@@ -1115,6 +1115,56 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output

+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Optional[CompletionLogprobs],
+        include_usage: bool,
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Union[Dict[str, Any], None] = None,
+    ) -> CreateChatCompletionStreamResponse:
+        """
+        Create a chunk for the streaming API; when usage reporting is
+        requested, the chunk carries an additional "usage" field.
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
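For reference, a minimal sketch of the two chunk shapes `_create_chunk` produces, using plain dicts with illustrative values in place of the library's typed responses:

```python
# Hypothetical illustration of the two shapes built by _create_chunk.
# Without include_usage the "usage" key is absent entirely; with it,
# the key is always present (None on all but the final summary chunk).
chunk_without_usage = {
    "id": "cmpl-123",
    "object": "text_completion",
    "created": 1700000000,
    "model": "llama-2-7b",
    "choices": [
        {"text": "Hello", "index": 0, "logprobs": None, "finish_reason": None}
    ],
}

chunk_with_usage = {
    **chunk_without_usage,
    "usage": None,  # populated only on the final extra chunk
}
```

This mirrors the OpenAI streaming convention for `stream_options={"include_usage": true}`: every chunk carries a `usage` key that is `null` except on the last one.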
@@ -1132,6 +1182,7 @@ def _create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1362,6 +1413,11 @@ def logit_bias_processor(
                 break

         if stream:
+            if stream_options and "include_usage" in stream_options:
+                include_usage = stream_options["include_usage"]
+            else:
+                include_usage = False
+
             remaining_tokens = completion_tokens[returned_tokens:]
             remaining_text = self.detokenize(
                 remaining_tokens,
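A minimal sketch of the `include_usage` resolution added above, assuming `StreamOptions` is a TypedDict with one optional `include_usage` key (the class below is a hypothetical stand-in for the library's type):

```python
from typing import Optional, TypedDict

class StreamOptions(TypedDict, total=False):  # hypothetical stand-in
    include_usage: bool

def resolve_include_usage(stream_options: Optional[StreamOptions]) -> bool:
    # Mirrors the branch in the diff: default to False when the options
    # block is absent or the key is missing.
    if stream_options and "include_usage" in stream_options:
        return stream_options["include_usage"]
    return False

assert resolve_include_usage(None) is False
assert resolve_include_usage({}) is False
assert resolve_include_usage({"include_usage": True}) is True
```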
@@ -1441,24 +1497,23 @@ def logit_bias_processor(
                             "top_logprobs": [top_logprob],
                         }
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": self.detokenize(
-                                    [token],
-                                    prev_tokens=prompt_tokens
-                                    + completion_tokens[:returned_tokens],
-                                ).decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = self.detokenize(
+                        [token],
+                        prev_tokens=prompt_tokens
+                        + completion_tokens[:returned_tokens],
+                    ).decode("utf-8", errors="ignore")
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        finish_reason=None,
+                        index=0,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                    )
             else:
                 while len(remaining_tokens) > 0:
                     decode_success = False
@@ -1487,20 +1542,16 @@ def logit_bias_processor(
                     remaining_tokens = remaining_tokens[i:]
                     returned_tokens += i

-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": ts,
-                                "index": 0,
-                                "logprobs": None,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield self._create_chunk(
+                        index=0,
+                        finish_reason=None,
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=ts,
+                        logprobs_or_none=None,
+                        include_usage=include_usage,
+                    )

             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1579,54 +1630,60 @@ def logit_bias_processor(
                         if token_end_position == end - 1:
                             break
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": last_text[
-                                        : len(last_text) - (token_end_position - end)
-                                    ].decode("utf-8", errors="ignore"),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = last_text[
+                            : len(last_text) - (token_end_position - end)
+                        ].decode("utf-8", errors="ignore")
+
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                            index=0,
+                            finish_reason=None,
+                        )
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": self.detokenize([token]).decode(
-                                    "utf-8", errors="ignore"
-                                ),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                    text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                index=0,
+                logprobs_or_none=None,
+                include_usage=include_usage,
+                usage=None,
+                finish_reason=finish_reason,
+            )
+
+            if include_usage:
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                    usage={
+                        "prompt_tokens": len(prompt_tokens),
+                        "completion_tokens": returned_tokens,
+                        "total_tokens": len(prompt_tokens) + returned_tokens,
+                    },
+                )
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
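End to end, a consumer of the new option sees the usage summary in the final chunk. A hedged usage sketch of the API this diff adds (model path and prompt are placeholders):

```python
# Hypothetical usage sketch: stream a completion with
# stream_options={"include_usage": True} and read the final usage chunk.
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder model path

for chunk in llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=32,
    stream=True,
    stream_options={"include_usage": True},
):
    if chunk.get("usage") is not None:
        # The extra final chunk: empty text, populated token counters.
        print("\n", chunk["usage"])
    else:
        print(chunk["choices"][0]["text"], end="")
```

Intermediate chunks carry `"usage": None`, so checking `is not None` isolates the summary chunk.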
@@ -1735,6 +1792,7 @@ def logit_bias_processor(
                 },
             }

+
     def create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1752,6 +1810,7 @@ def create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1815,6 +1874,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1849,6 +1909,7 @@ def __call__(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1912,6 +1973,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1937,6 +1999,7 @@ def create_chat_completion(
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -2010,6 +2073,7 @@ def create_chat_completion(
             logprobs=logprobs,
             top_logprobs=top_logprobs,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             seed=seed,
             response_format=response_format,
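Assuming the chat handler forwards `stream_options` unchanged, the same option should work on the chat path as well; a sketch under that assumption (model path and chat format are placeholders):

```python
# Hypothetical sketch for the chat API path, mirroring the completion
# example above; works only if the chat handler propagates the option.
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", chat_format="llama-2")  # placeholders

for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},
):
    usage = chunk.get("usage")
    if usage is not None:
        print("usage:", usage)
```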