@@ -2,6 +2,7 @@
 
 import os
 import json
+import typing
 import contextlib
 
 from threading import Lock
@@ -88,11 +89,12 @@ def get_llama_proxy():
             llama_outer_lock.release()
 
 
-_ping_message_factory = None
+_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
 
-def set_ping_message_factory(factory):
-    global _ping_message_factory
-    _ping_message_factory = factory
+
+def set_ping_message_factory(factory: typing.Callable[[], bytes]):
+    global _ping_message_factory
+    _ping_message_factory = factory
 
 
 def create_app(
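
For reference, a minimal sketch of how the newly typed factory might be registered and used; the keep-alive payload below is an illustrative assumption, not something taken from this change.

import typing

_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None

def set_ping_message_factory(factory: typing.Callable[[], bytes]):
    global _ping_message_factory
    _ping_message_factory = factory

# Hypothetical registration: an SSE comment line as the keep-alive ping payload.
set_ping_message_factory(lambda: b": ping\r\n\r\n")

if _ping_message_factory is not None:
    print(_ping_message_factory())  # b': ping\r\n\r\n'
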
@@ -155,20 +157,19 @@ def create_app(
 
 async def get_event_publisher(
     request: Request,
-    inner_send_chan: MemoryObjectSendStream,
-    iterator: Iterator,
-    on_complete=None,
+    inner_send_chan: MemoryObjectSendStream[typing.Any],
+    iterator: Iterator[typing.Any],
+    on_complete: typing.Optional[typing.Callable[[], None]] = None,
 ):
+    server_settings = next(get_server_settings())
+    interrupt_requests = server_settings.interrupt_requests if server_settings else False
     async with inner_send_chan:
         try:
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
                     raise anyio.get_cancelled_exc_class()()
-                if (
-                    next(get_server_settings()).interrupt_requests
-                    and llama_outer_lock.locked()
-                ):
+                if interrupt_requests and llama_outer_lock.locked():
                     await inner_send_chan.send(dict(data="[DONE]"))
                     raise anyio.get_cancelled_exc_class()()
             await inner_send_chan.send(dict(data="[DONE]"))
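
The refactor reads the server settings once before streaming starts, so the per-chunk check only touches a local flag and the outer lock. As context, a sketch of how such a publisher is typically attached to a streaming route, assuming anyio memory object streams and sse_starlette's EventSourceResponse; it references get_event_publisher from the hunk above, and names like make_streaming_response and chunk_iterator are placeholders, not taken from this commit.

from functools import partial
from typing import Any, Iterator

import anyio
from fastapi import Request
from sse_starlette.sse import EventSourceResponse

def make_streaming_response(request: Request, chunk_iterator: Iterator[Any]):
    # Bounded buffer between the worker thread producing chunks and the SSE writer.
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=partial(
            get_event_publisher,  # the coroutine modified in the hunk above
            request=request,
            inner_send_chan=send_chan,
            iterator=chunk_iterator,
            on_complete=None,
        ),
    )
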
@@ -268,6 +269,11 @@ async def create_completion(
     llama_proxy = await run_in_threadpool(
         lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
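
A hypothetical test illustrating the effect of the new guard: when no model proxy is available, the endpoint answers 503 up front instead of failing deeper in the handler. The app fixture and endpoint path are assumptions for illustration only.

from fastapi.testclient import TestClient

def test_completion_returns_503_when_no_model_is_loaded(app):
    # `app` is assumed to be a FastAPI instance whose llama proxy is unavailable.
    client = TestClient(app)
    resp = client.post("/v1/completions", json={"prompt": "hello"})
    assert resp.status_code == 503
    assert resp.json()["detail"] == "Service is not available"
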
@@ -409,7 +415,7 @@ async def create_chat_completion(
                     {"role": "system", "content": "You are a helpful assistant."},
                     {"role": "user", "content": "Who won the world series in 2020"},
                 ],
-                "response_format": { "type": "json_object" }
+                "response_format": {"type": "json_object"},
             },
         },
         "tool_calling": {
@@ -434,15 +440,15 @@ async def create_chat_completion(
                                 },
                                 "required": ["name", "age"],
                             },
-                        }
+                        },
                     }
                 ],
                 "tool_choice": {
                     "type": "function",
                     "function": {
                         "name": "User",
-                    }
-                }
+                    },
+                },
             },
         },
         "logprobs": {
@@ -454,7 +460,7 @@ async def create_chat_completion(
                 {"role": "user", "content": "What is the capital of France?"},
             ],
             "logprobs": True,
-            "top_logprobs": 10
+            "top_logprobs": 10,
         },
     },
 }
@@ -468,6 +474,11 @@ async def create_chat_completion(
     llama_proxy = await run_in_threadpool(
         lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
     exclude = {
         "n",
         "logit_bias_type",