Implement openai api compatible authentication (#1010) · cyberjon/llama-cpp-python@33cc623
Implement openai api compatible authentication (abetlen#1010)
1 parent 788394c commit 33cc623

File tree: 1 file changed, +36 −2 lines

llama_cpp/server/app.py

Lines changed: 36 additions & 2 deletions
@@ -14,11 +14,12 @@
 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request, Response
+from fastapi import Depends, FastAPI, APIRouter, Request, Response, HTTPException, status
 from fastapi.middleware import Middleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.routing import APIRoute
+from fastapi.security import HTTPBearer
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -163,6 +164,10 @@ class Settings(BaseSettings):
         default=True,
         description="Whether to interrupt requests when a new request is received.",
     )
+    api_key: Optional[str] = Field(
+        default=None,
+        description="API key for authentication. If set all requests need to be authenticated."
+    )
 
 
 class ErrorResponse(TypedDict):
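
Because Settings extends pydantic's BaseSettings, the new api_key field can also be supplied through the environment with no extra wiring. A minimal standalone sketch of that behavior, assuming pydantic-settings' default case-insensitive environment mapping; the API_KEY variable name and key value are illustrative, not taken from this commit:

import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # Mirrors the field added in the hunk above.
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication. If set all requests need to be authenticated.",
    )


os.environ["API_KEY"] = "sk-local-123"  # hypothetical key value
print(Settings().api_key)  # -> sk-local-123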
@@ -314,6 +319,9 @@ async def custom_route_handler(request: Request) -> Response:
                 elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                 response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                 return response
+            except HTTPException as unauthorized:
+                # api key check failed
+                raise unauthorized
             except Exception as exc:
                 json_body = await request.json()
                 try:
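
The new except clause is about ordering: HTTPException is itself a subclass of Exception, so the broad handler below it would otherwise convert the auth dependency's 401 into the server's generic JSON error payload. A minimal sketch of the pattern (not the project's code):

from fastapi import HTTPException


def handle(work):
    try:
        return work()
    except HTTPException:
        # Re-raise so FastAPI's exception middleware renders the 401 as-is.
        raise
    except Exception as exc:
        # The generic fallback would otherwise swallow the HTTPException.
        return {"error": str(exc)}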
@@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids(
     return to_bias
 
 
+# Setup Bearer authentication scheme
+bearer_scheme = HTTPBearer(auto_error=False)
+
+
+async def authenticate(settings: Settings = Depends(get_settings), authorization: Optional[str] = Depends(bearer_scheme)):
+    # Skip API key check if it's not set in settings
+    if settings.api_key is None:
+        return True
+
+    # check bearer credentials against the api_key
+    if authorization and authorization.credentials == settings.api_key:
+        # api key is valid
+        return authorization.credentials
+
+    # raise http error 401
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Invalid API key",
+    )
+
+
 @router.post(
     "/v1/completions",
     summary="Completion"
@@ -667,6 +696,7 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
+    authenticated: str = Depends(authenticate),
 ) -> llama_cpp.Completion:
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
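
With api_key set, clients authenticate exactly as they would against OpenAI: a Bearer token in the Authorization header. A sketch using requests; the host, port, and key are assumptions about a local deployment:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    headers={"Authorization": "Bearer sk-local-123"},  # must match the server's api_key
    json={"prompt": "Hello", "max_tokens": 8},
)
print(resp.status_code)  # 200 with the right key, 401 otherwise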
@@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel):
     summary="Embedding"
 )
 async def create_embedding(
-    request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
+    request: CreateEmbeddingRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
+    authenticated: str = Depends(authenticate),
 ):
     return await run_in_threadpool(
         llama.create_embedding, **request.model_dump(exclude={"user"})
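
Note that none of the handlers read the authenticated parameter: the dependency is there purely for its side effect, so FastAPI runs authenticate on every request and the HTTPException short-circuits unauthorized ones before the handler body executes.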
@@ -834,6 +866,7 @@ async def create_chat_completion(
     body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
+    authenticated: str = Depends(authenticate),
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
@@ -895,6 +928,7 @@ class ModelList(TypedDict):
 @router.get("/v1/models", summary="Models")
 async def get_models(
     settings: Settings = Depends(get_settings),
+    authenticated: str = Depends(authenticate),
 ) -> ModelList:
     assert llama is not None
     return {
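
Since the scheme matches OpenAI's, the official client works unchanged once pointed at the server. A sketch assuming the openai v1 Python package and a local deployment; base_url, key, and port are illustrative:

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="sk-local-123",  # sent as "Authorization: Bearer sk-local-123"
)
print(client.models.list())  # hits the now-authenticated /v1/models route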
