fix(misc): Fix type errors · coderonion/llama-cpp-python@387d01d · GitHub

Commit 387d01d

fix(misc): Fix type errors
1 parent 4fb6fc1 commit 387d01d

File tree

1 file changed: +27 -16 lines changed

llama_cpp/server/app.py

Lines changed: 27 additions & 16 deletions
@@ -2,6 +2,7 @@
 
 import os
 import json
+import typing
 import contextlib
 
 from threading import Lock
@@ -88,11 +89,12 @@ def get_llama_proxy():
             llama_outer_lock.release()
 
 
-_ping_message_factory = None
+_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
 
-def set_ping_message_factory(factory):
-    global _ping_message_factory
-    _ping_message_factory = factory
+
+def set_ping_message_factory(factory: typing.Callable[[], bytes]):
+    global _ping_message_factory
+    _ping_message_factory = factory
 
 
 def create_app(
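
For context, the hook annotated above lets callers override the keep-alive ping used by the server's streaming responses. A minimal sketch of how the newly typed hook could be exercised; the make_ping helper and the assumption that the bytes form an SSE comment line are illustrative, not part of this commit:

import typing

_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None


def set_ping_message_factory(factory: typing.Callable[[], bytes]) -> None:
    global _ping_message_factory
    _ping_message_factory = factory


def make_ping() -> bytes:
    # SSE comment lines start with ":" and are ignored by clients, so they
    # work as harmless keep-alive traffic (an assumption about intent).
    return b": ping\r\n\r\n"


set_ping_message_factory(make_ping)

# Because the module-level variable is Optional, callers must narrow it before
# calling it, which is exactly what the new annotation lets a type checker enforce.
if _ping_message_factory is not None:
    ping_bytes = _ping_message_factory()
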
@@ -155,20 +157,19 @@ def create_app(
 
 async def get_event_publisher(
     request: Request,
-    inner_send_chan: MemoryObjectSendStream,
-    iterator: Iterator,
-    on_complete=None,
+    inner_send_chan: MemoryObjectSendStream[typing.Any],
+    iterator: Iterator[typing.Any],
+    on_complete: typing.Optional[typing.Callable[[], None]] = None,
 ):
+    server_settings = next(get_server_settings())
+    interrupt_requests = server_settings.interrupt_requests if server_settings else False
     async with inner_send_chan:
         try:
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
                     raise anyio.get_cancelled_exc_class()()
-                if (
-                    next(get_server_settings()).interrupt_requests
-                    and llama_outer_lock.locked()
-                ):
+                if interrupt_requests and llama_outer_lock.locked():
                     await inner_send_chan.send(dict(data="[DONE]"))
                     raise anyio.get_cancelled_exc_class()()
             await inner_send_chan.send(dict(data="[DONE]"))
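
The new annotations refer to anyio's generic memory object streams: the publisher owns the send side while a consumer elsewhere drains the receive side. A self-contained sketch of that producer/consumer shape, using illustrative payloads rather than code from app.py:

import json
import typing

import anyio
from anyio.streams.memory import MemoryObjectSendStream


async def publish(
    send_chan: MemoryObjectSendStream[typing.Any],
    chunks: typing.List[dict],
) -> None:
    # Mirrors the publisher's shape: own the send side, push each chunk as an
    # SSE-style payload, then close the stream so the consumer's loop ends.
    async with send_chan:
        for chunk in chunks:
            await send_chan.send(dict(data=json.dumps(chunk)))
        await send_chan.send(dict(data="[DONE]"))


async def main() -> None:
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    async with anyio.create_task_group() as tg:
        tg.start_soon(publish, send_chan, [{"token": "hello"}, {"token": "world"}])
        async with recv_chan:
            async for event in recv_chan:
                print(event)


anyio.run(main)
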
@@ -268,6 +269,11 @@ async def create_completion(
     llama_proxy = await run_in_threadpool(
         lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
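
The added None check turns a proxy that has not been set up yet into a clean 503 instead of an AttributeError deeper in the handler (the exact failure mode is inferred, not stated in the commit). A standalone sketch of the same guard pattern; the /v1/health route and get_backend helper are hypothetical, not part of app.py:

import typing

from fastapi import FastAPI, HTTPException, status

app = FastAPI()

_backend: typing.Optional[object] = None  # imagine this gets set during startup


def get_backend() -> typing.Optional[object]:
    return _backend


@app.get("/v1/health")
async def health():
    backend = get_backend()
    if backend is None:
        # Same idea as the new llama_proxy check: fail fast with 503 instead of
        # letting a None slip through and blow up later in the handler.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is not available",
        )
    return {"status": "ok"}
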
@@ -409,7 +415,7 @@ async def create_chat_completion(
                         {"role": "system", "content": "You are a helpful assistant."},
                         {"role": "user", "content": "Who won the world series in 2020"},
                     ],
-                    "response_format": { "type": "json_object" }
+                    "response_format": {"type": "json_object"},
                 },
             },
             "tool_calling": {
@@ -434,15 +440,15 @@ async def create_chat_completion(
                                     },
                                     "required": ["name", "age"],
                                 },
-                            }
+                            },
                         }
                     ],
                     "tool_choice": {
                         "type": "function",
                         "function": {
                             "name": "User",
-                        }
-                    }
+                        },
+                    },
                 },
             },
             "logprobs": {
@@ -454,7 +460,7 @@ async def create_chat_completion(
                         {"role": "user", "content": "What is the capital of France?"},
                     ],
                     "logprobs": True,
-                    "top_logprobs": 10
+                    "top_logprobs": 10,
                 },
             },
         }
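
The trailing-comma edits above only touch the openapi_examples dictionaries that FastAPI renders as request samples in the interactive docs; they have no runtime effect on validation. A minimal sketch of how such a dictionary attaches to an endpoint via Body(openapi_examples=...); the ChatRequest model and /v1/echo route are hypothetical, and this assumes a FastAPI version that supports the openapi_examples parameter:

import typing

from fastapi import Body, FastAPI
from pydantic import BaseModel

app = FastAPI()


class ChatRequest(BaseModel):
    model: str
    messages: typing.List[dict]
    logprobs: bool = False
    top_logprobs: typing.Optional[int] = None


@app.post("/v1/echo")
async def echo(
    body: ChatRequest = Body(
        openapi_examples={
            "logprobs": {
                "summary": "Logprobs",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                    "logprobs": True,
                    "top_logprobs": 10,
                },
            },
        }
    ),
):
    # The examples only shape the generated OpenAPI schema / Swagger UI;
    # request validation still comes from the ChatRequest model itself.
    return body
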
@@ -468,6 +474,11 @@ async def create_chat_completion(
     llama_proxy = await run_in_threadpool(
         lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
     exclude = {
         "n",
         "logit_bias_type",

0 commit comments
