Add streaming and embedding endpoints to fastapi example · jooray/llama-cpp-python@ed6f2a0
Commit ed6f2a0

Add streaming and embedding endpoints to fastapi example

1 parent 0503e7f

1 file changed

examples/fastapi_server.py

Lines changed: 57 additions & 7 deletions
@@ -1,11 +1,14 @@
 """Example FastAPI server for llama.cpp.
 """
-from typing import List, Optional
+import json
+from typing import List, Optional, Iterator
 
-from llama_cpp import Llama
+import llama_cpp
 
 from fastapi import FastAPI
-from pydantic import BaseModel, BaseSettings, Field
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
+from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
@@ -16,11 +19,24 @@ class Settings(BaseSettings):
     title="🦙 llama.cpp Python API",
     version="0.0.1",
 )
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 settings = Settings()
-llama = Llama(settings.model)
+llama = llama_cpp.Llama(
+    settings.model,
+    f16_kv=True,
+    use_mlock=True,
+    n_threads=6,
+    n_batch=2048,
+)
 
 
-class CompletionRequest(BaseModel):
+class CreateCompletionRequest(BaseModel):
     prompt: str
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
@@ -31,6 +47,7 @@ class CompletionRequest(BaseModel):
     stop: List[str] = []
     repeat_penalty: float = 1.1
     top_k: int = 40
+    stream: bool = False
 
     class Config:
         schema_extra = {
@@ -41,6 +58,39 @@ class Config:
         }
 
 
-@app.post("/v1/completions")
-def completions(request: CompletionRequest):
+CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
+
+
+@app.post(
+    "/v1/completions",
+    response_model=CreateCompletionResponse,
+)
+def create_completion(request: CreateCompletionRequest):
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
     return llama(**request.dict())
+
+
+class CreateEmbeddingRequest(BaseModel):
+    model: Optional[str]
+    input: str
+    user: Optional[str]
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "input": "The food was delicious and the waiter...",
+            }
+        }
+
+
+CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
+
+
+@app.post(
+    "/v1/embeddings",
+    response_model=CreateEmbeddingResponse,
+)
+def create_embedding(request: CreateEmbeddingRequest):
+    return llama.create_embedding(**request.dict())
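
As a quick way to exercise the new streaming path, here is a minimal client sketch. It assumes the example server is running locally on uvicorn's default port 8000 and that the third-party requests package is installed; the endpoint path and payload fields come from the diff above, while the URL and prompt are illustrative:

    import json

    import requests

    payload = {
        "prompt": "Q: What is the capital of France? A:",
        "max_tokens": 32,
        "stream": True,  # triggers the EventSourceResponse branch in create_completion
    }

    with requests.post(
        "http://localhost:8000/v1/completions", json=payload, stream=True
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            # sse-starlette frames each event as a "data: <json>" line
            if line.startswith(b"data: "):
                chunk = json.loads(line[len(b"data: "):])
                print(chunk["choices"][0]["text"], end="", flush=True)

Each decoded chunk should match llama_cpp.CompletionChunk, with the newly generated text under choices[0]["text"].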

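A matching sketch for the embeddings endpoint, under the same local-server assumption and assuming llama_cpp.Embedding follows the OpenAI-style response shape (vector under data[0]["embedding"]):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={"input": "The food was delicious and the waiter..."},
    )
    resp.raise_for_status()
    vector = resp.json()["data"][0]["embedding"]  # assumed OpenAI-style response shape
    print(f"embedding dimension: {len(vector)}")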