diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 3d5238b80..fe36b4060 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -103,10 +103,6 @@ class Settings(BaseSettings):
         default=None,
         description="TEMPORARY",
     )
-    mul_mat_q: Optional[bool] = Field(
-        default=None,
-        description="TEMPORARY",
-    )
 
 
 class ErrorResponse(TypedDict):
@@ -258,7 +254,9 @@ async def custom_route_handler(request: Request) -> Response:
             try:
                 return await original_route_handler(request)
             except Exception as exc:
+
                 json_body = await request.json()
+
                 try:
                     if "messages" in json_body:
                         # Chat completion
@@ -269,13 +267,14 @@ async def custom_route_handler(request: Request) -> Response:
                                 CreateEmbeddingRequest,
                             ]
                         ] = CreateChatCompletionRequest(**json_body)
+
                     elif "prompt" in json_body:
                         # Text completion
                         body = CreateCompletionRequest(**json_body)
                     else:
                         # Embedding
                         body = CreateEmbeddingRequest(**json_body)
-                except Exception:
+                except Exception as e:
                     # Invalid request body
                     body = None
 
@@ -292,6 +291,27 @@ async def custom_route_handler(request: Request) -> Response:
         return custom_route_handler
 
 
+
+def aat_demo_workarround(plain_message):
+    system_request_message = {
+        'role':'system',
+        'content':"You are a helpful assistant. Provide helpful and informative responses in a concise and complete manner. Please avoid using conversational tags and only reply in full sentences. Ensure that your answers are presented directly and without the use of 'Human:' or '###'. Thank you for your cooperation."
+    }
+    user_request_message = {
+        'role':'user',
+        'content':'<<SYS>>Here are some crop profile data: line name:bulk monarch; mean protein:22.851810373282948; mean starch:37.66; mean sol. sugars:4.381759207; total lipid:8.25615294614547 line name:bulk hattrick; mean protein:19.67188553; mean starch:38.7; mean sol. sugars:3.002967393; total lipid:7.658846896704566 line name:boundary; mean protein:21.012795213521386; mean starch:36.3; mean sol. sugars:3.526429066; total lipid:8.794638597328241 line name:ICC 08585; mean protein:18.26596777; mean starch:35.84; mean sol. sugars:3.801993787; total lipid:9.922753918466833 line name:ICC 10382; mean protein:19.77542927; mean starch:32.28; mean sol. sugars:4.29316582; total lipid:8.563684442485943 line name:ICC 10630; mean protein:12.88549951; mean starch:36.06; mean sol. sugars:4.912351782; total lipid:9.10236131510577 line name:ICC 03439; mean protein:17.06102627; mean starch:34.16; mean sol. sugars:3.248215327; total lipid:8.854573854458007 <</SYS>> [INST] ' + plain_message + ' [/INST]'
+    }
+
+    json_body = {
+        'messages':[system_request_message, user_request_message],
+        'max_tokens': 300,
+        'temperature': 0.1,
+        'top_p': 0.95
+    }
+
+    return json_body
+
+
 router = APIRouter(route_class=RouteErrorHandler)
 
 settings: Optional[Settings] = None
@@ -305,9 +325,38 @@ def create_app(settings: Optional[Settings] = None):
         title="🦙 llama.cpp Python API",
         version="0.0.1",
     )
+
+    origins = [
+        "http://localhost",
+        "http://localhost:8066",
+        "http://localhost:6108",
+        "http://velocity-ev",
+        "http://velocity-ev:8066",
+        "http://velocity-ev:6108",
+        "http://vizhead01-ev",
+        "http://vizhead01-ev:8066",
+        "http://vizhead01-ev:6108"
+    ]
+
+    @app.middleware("http")
+    async def add_cors_headers(request, call_next):
+        origin = 'http://localhost:6108'
+        if 'origin' in request.headers:
+            origin = origin + ', ' + request.headers['origin']
+
+        response = await call_next(request)
+        response.headers["Access-Control-Allow-Origin"] = origin
+        response.headers["Access-Control-Allow-Credentials"] = "true"
+        response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
+        response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, cache-control,expires,pragma,x-magda-api-key,x-magda-api-key-id"
+
+        return response
+
+
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=["*"],
+        # allow_origins=["*"],
+        allow_origins=origins,
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
@@ -701,7 +750,6 @@ class CreateChatCompletionRequest(BaseModel):
             }
         }
 
-
 @router.post(
     "/v1/chat/completions",
 )
@@ -711,12 +759,14 @@ async def create_chat_completion(
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
 ) -> llama_cpp.ChatCompletion:
+
     exclude = {
         "n",
         "logit_bias",
         "logit_bias_type",
         "user",
     }
+
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
@@ -750,6 +800,77 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
     else:
         return iterator_or_completion
 
+@router.post(
+    "/v1/chat/plaincompletion",
+)
+async def create_chat_plain_completion(
+    request: Request,
+    llama: llama_cpp.Llama = Depends(get_llama),
+    settings: Settings = Depends(get_settings),
+):
+    ## NOTE: applied AI team workaround for demo
+    if request.url.path == '/v1/chat/plaincompletion' and request.headers['content-type'] == "text/plain":
+        req_body = await request.body()
+        req_body = str(req_body, encoding='utf-8')
+        json_body = aat_demo_workarround(req_body)
+
+    # Chat completion
+    body: Optional[
+        Union[
+            CreateChatCompletionRequest,
+            CreateCompletionRequest,
+            CreateEmbeddingRequest,
+        ]
+    ] = CreateChatCompletionRequest(**json_body)
+
+
+    exclude = {
+        "n",
+        "logit_bias",
+        "logit_bias_type",
+        "user",
+    }
+
+    kwargs = body.model_dump(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
+    else:
+        response_text = iterator_or_completion['choices'][0]['message']['content'].strip()
+        response_text = response_text.replace('"', '')
+        #return iterator_or_completion
+        return response_text
+
+
+
 
 class ModelData(TypedDict):
     id: str
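
For reference, below is a minimal client-side sketch (not part of the diff) of how the new /v1/chat/plaincompletion route added above might be called. The server address, port, and the use of the `requests` package are assumptions; the route expects a text/plain request body, wraps it in the hard-coded crop-profile prompt via aat_demo_workarround(), and in the non-streaming case returns the completion text, which FastAPI serializes as a JSON-encoded string.

import requests

# Assumed address of the running llama-cpp-python server; adjust as needed.
BASE_URL = "http://localhost:8000"

def ask_plain(question: str) -> str:
    """POST a plain-text question to the demo endpoint and return the reply."""
    response = requests.post(
        f"{BASE_URL}/v1/chat/plaincompletion",
        data=question.encode("utf-8"),
        headers={"Content-Type": "text/plain"},
        timeout=120,
    )
    response.raise_for_status()
    # The handler returns a bare Python string, so the response body is a JSON
    # string; response.json() decodes it back to plain text.
    return response.json()

if __name__ == "__main__":
    print(ask_plain("Which line has the highest mean protein?"))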