14
14
import anyio
15
15
from anyio .streams .memory import MemoryObjectSendStream
16
16
from starlette .concurrency import run_in_threadpool , iterate_in_threadpool
17
- from fastapi import Depends , FastAPI , APIRouter , Request , Response
17
+ from fastapi import Depends , FastAPI , APIRouter , Request , Response , HTTPException , status
18
18
from fastapi .middleware import Middleware
19
19
from fastapi .middleware .cors import CORSMiddleware
20
20
from fastapi .responses import JSONResponse
21
21
from fastapi .routing import APIRoute
22
+ from fastapi .security import HTTPBearer
22
23
from pydantic import BaseModel , Field
23
24
from pydantic_settings import BaseSettings
24
25
from sse_starlette .sse import EventSourceResponse
@@ -163,6 +164,10 @@ class Settings(BaseSettings):
163
164
default = True ,
164
165
description = "Whether to interrupt requests when a new request is received." ,
165
166
)
167
+ api_key : Optional [str ] = Field (
168
+ default = None ,
169
+ description = "API key for authentication. If set all requests need to be authenticated."
170
+ )
166
171
167
172
168
173
class ErrorResponse (TypedDict ):
@@ -314,6 +319,9 @@ async def custom_route_handler(request: Request) -> Response:
314
319
elapsed_time_ms = int ((time .perf_counter () - start_sec ) * 1000 )
315
320
response .headers ["openai-processing-ms" ] = f"{ elapsed_time_ms } "
316
321
return response
322
+ except HTTPException as unauthorized :
323
+ # propagate HTTP errors (e.g. a failed API key check) unchanged
324
+ raise unauthorized
317
325
except Exception as exc :
318
326
json_body = await request .json ()
319
327
try :
@@ -658,6 +666,27 @@ def _logit_bias_tokens_to_input_ids(
658
666
return to_bias
659
667
660
668
669
# Set up the Bearer authentication scheme. With auto_error=False the scheme
# is expected to yield None when no usable Authorization header is present
# (rather than rejecting the request itself), so `authenticate` below decides
# whether the request is allowed.
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: Settings = Depends(get_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """Check the request's bearer credentials against the configured API key.

    Args:
        settings: Server settings; ``settings.api_key`` is the expected key,
            or ``None`` when authentication is disabled.
        authorization: Value produced by ``bearer_scheme``. It is read via
            ``authorization.credentials`` when present.
            NOTE(review): the ``Optional[str]`` annotation is kept for
            interface compatibility, but the value is used as an object with
            a ``.credentials`` attribute, not as a plain string.

    Returns:
        ``True`` when no API key is configured, otherwise the bearer token
        that matched the configured key.

    Raises:
        HTTPException: 401 with detail ``"Invalid API key"`` when a key is
            configured and the credentials are missing or do not match.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import secrets

    # Skip the API key check entirely if no key is set in settings.
    if settings.api_key is None:
        return True

    # Compare in constant time so response timing does not leak how much of
    # a guessed key was correct. Bytes are compared because compare_digest
    # only accepts ASCII-only str arguments.
    if authorization is not None and secrets.compare_digest(
        authorization.credentials.encode("utf-8"),
        settings.api_key.encode("utf-8"),
    ):
        return authorization.credentials

    # Missing or wrong credentials: reject with 401 and advertise the
    # expected authentication scheme to the client.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "Bearer"},
    )
688
+
689
+
661
690
@router .post (
662
691
"/v1/completions" ,
663
692
summary = "Completion"
@@ -667,6 +696,7 @@ async def create_completion(
667
696
request : Request ,
668
697
body : CreateCompletionRequest ,
669
698
llama : llama_cpp .Llama = Depends (get_llama ),
699
+ authenticated : str = Depends (authenticate ),
670
700
) -> llama_cpp .Completion :
671
701
if isinstance (body .prompt , list ):
672
702
assert len (body .prompt ) <= 1
@@ -740,7 +770,9 @@ class CreateEmbeddingRequest(BaseModel):
740
770
summary = "Embedding"
741
771
)
742
772
async def create_embedding (
743
- request : CreateEmbeddingRequest , llama : llama_cpp .Llama = Depends (get_llama )
773
+ request : CreateEmbeddingRequest ,
774
+ llama : llama_cpp .Llama = Depends (get_llama ),
775
+ authenticated : str = Depends (authenticate ),
744
776
):
745
777
return await run_in_threadpool (
746
778
llama .create_embedding , ** request .model_dump (exclude = {"user" })
@@ -834,6 +866,7 @@ async def create_chat_completion(
834
866
body : CreateChatCompletionRequest ,
835
867
llama : llama_cpp .Llama = Depends (get_llama ),
836
868
settings : Settings = Depends (get_settings ),
869
+ authenticated : str = Depends (authenticate ),
837
870
) -> llama_cpp .ChatCompletion :
838
871
exclude = {
839
872
"n" ,
@@ -895,6 +928,7 @@ class ModelList(TypedDict):
895
928
@router .get ("/v1/models" , summary = "Models" )
896
929
async def get_models (
897
930
settings : Settings = Depends (get_settings ),
931
+ authenticated : str = Depends (authenticate ),
898
932
) -> ModelList :
899
933
assert llama is not None
900
934
return {
0 commit comments