Add server as a subpackage · lapnd/llama-cpp-python@44448fb · GitHub


Commit 44448fb

Add server as a subpackage
1 parent e1b5b9b commit 44448fb

File tree: llama_cpp/server/__main__.py, setup.py (2 files changed: +268 -1 lines)


llama_cpp/server/__main__.py

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
"""Example FastAPI server for llama.cpp.

To run this example:

```bash
pip install fastapi uvicorn sse-starlette
export MODEL=../models/7B/...
uvicorn fastapi_server_chat:app --reload
```

Then visit http://localhost:8000/docs to see the interactive API docs.

"""
import os
import json
from typing import List, Optional, Literal, Union, Iterator, Dict
from typing_extensions import TypedDict

import llama_cpp

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse


class Settings(BaseSettings):
    model: str
    n_ctx: int = 2048
    n_batch: int = 2048
    n_threads: int = os.cpu_count() or 1
    f16_kv: bool = True
    use_mlock: bool = True
    embedding: bool = True
    last_n_tokens_size: int = 64


app = FastAPI(
    title="🦙 llama.cpp Python API",
    version="0.0.1",
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
settings = Settings()
llama = llama_cpp.Llama(
    settings.model,
    f16_kv=settings.f16_kv,
    use_mlock=settings.use_mlock,
    embedding=settings.embedding,
    n_threads=settings.n_threads,
    n_batch=settings.n_batch,
    n_ctx=settings.n_ctx,
    last_n_tokens_size=settings.last_n_tokens_size,
)


class CreateCompletionRequest(BaseModel):
    prompt: str
    suffix: Optional[str] = Field(None)
    max_tokens: int = 16
    temperature: float = 0.8
    top_p: float = 0.95
    echo: bool = False
    stop: List[str] = []
    stream: bool = False

    # ignored or currently unsupported
    model: Optional[str] = Field(None)
    n: Optional[int] = 1
    logprobs: Optional[int] = Field(None)
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    best_of: Optional[int] = 1
    logit_bias: Optional[Dict[str, float]] = Field(None)
    user: Optional[str] = Field(None)

    # llama.cpp specific parameters
    top_k: int = 40
    repeat_penalty: float = 1.1

    class Config:
        schema_extra = {
            "example": {
                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                "stop": ["\n", "###"],
            }
        }


CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


@app.post(
    "/v1/completions",
    response_model=CreateCompletionResponse,
)
def create_completion(request: CreateCompletionRequest):
    if request.stream:
        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
    return llama(
        **request.dict(
            exclude={
                "model",
                "n",
                "logprobs",
                "frequency_penalty",
                "presence_penalty",
                "best_of",
                "logit_bias",
                "user",
            }
        )
    )


class CreateEmbeddingRequest(BaseModel):
    model: Optional[str]
    input: str
    user: Optional[str]

    class Config:
        schema_extra = {
            "example": {
                "input": "The food was delicious and the waiter...",
            }
        }


CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)


@app.post(
    "/v1/embeddings",
    response_model=CreateEmbeddingResponse,
)
def create_embedding(request: CreateEmbeddingRequest):
    return llama.create_embedding(**request.dict(exclude={"model", "user"}))


class ChatCompletionRequestMessage(BaseModel):
    role: Union[Literal["system"], Literal["user"], Literal["assistant"]]
    content: str
    user: Optional[str] = None


class CreateChatCompletionRequest(BaseModel):
    model: Optional[str]
    messages: List[ChatCompletionRequestMessage]
    temperature: float = 0.8
    top_p: float = 0.95
    stream: bool = False
    stop: List[str] = []
    max_tokens: int = 128

    # ignored or currently unsupported
    model: Optional[str] = Field(None)
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    logit_bias: Optional[Dict[str, float]] = Field(None)
    user: Optional[str] = Field(None)

    # llama.cpp specific parameters
    repeat_penalty: float = 1.1

    class Config:
        schema_extra = {
            "example": {
                "messages": [
                    ChatCompletionRequestMessage(
                        role="system", content="You are a helpful assistant."
                    ),
                    ChatCompletionRequestMessage(
                        role="user", content="What is the capital of France?"
                    ),
                ]
            }
        }


CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)


@app.post(
    "/v1/chat/completions",
    response_model=CreateChatCompletionResponse,
)
async def create_chat_completion(
    request: CreateChatCompletionRequest,
) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
    completion_or_chunks = llama.create_chat_completion(
        **request.dict(
            exclude={
                "model",
                "n",
                "presence_penalty",
                "frequency_penalty",
                "logit_bias",
                "user",
            }
        ),
    )

    if request.stream:

        async def server_sent_events(
            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
        ):
            for chat_chunk in chat_chunks:
                yield dict(data=json.dumps(chat_chunk))
            yield dict(data="[DONE]")

        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore

        return EventSourceResponse(
            server_sent_events(chunks),
        )
    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
    return completion


class ModelData(TypedDict):
    id: str
    object: Literal["model"]
    owned_by: str
    permissions: List[str]


class ModelList(TypedDict):
    object: Literal["list"]
    data: List[ModelData]


GetModelResponse = create_model_from_typeddict(ModelList)


@app.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
    return {
        "object": "list",
        "data": [
            {
                "id": llama.model_path,
                "object": "model",
                "owned_by": "me",
                "permissions": [],
            }
        ],
    }


if __name__ == "__main__":
    import os
    import uvicorn

    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
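For reference, here is a minimal client sketch against the endpoints defined above. It assumes the server is already running on http://localhost:8000, that the `requests` package is installed, and that the response bodies follow the OpenAI-style shapes described by `llama_cpp.Completion` and `llama_cpp.ChatCompletion`; the prompts and messages are placeholders.

```python
# Minimal client sketch for the endpoints above (not part of the commit).
# Assumes the server is running on http://localhost:8000 and that the
# `requests` package is installed.
import requests

BASE = "http://localhost:8000"

# Text completion: POST /v1/completions
completion = requests.post(
    f"{BASE}/v1/completions",
    json={
        "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
        "stop": ["\n", "###"],
        "max_tokens": 32,
    },
).json()
print(completion["choices"][0]["text"])

# Chat completion: POST /v1/chat/completions
chat = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "max_tokens": 64,
    },
).json()
print(chat["choices"][0]["message"]["content"])

# Model listing: GET /v1/models
print(requests.get(f"{BASE}/v1/models").json())
```

Setting `"stream": true` in either request switches the endpoint to the server-sent-event responses built with `EventSourceResponse`; plain `requests` does not parse SSE, so a streaming client would need to read the response incrementally or use an SSE client library.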

setup.py

Lines changed: 6 additions & 1 deletion
@@ -14,10 +14,15 @@
     author="Andrei Betlen",
     author_email="abetlen@gmail.com",
     license="MIT",
-    packages=["llama_cpp"],
+    package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
+    packages=["llama_cpp", "llama_cpp.server"],
+    entry_points={"console_scripts": ["llama_cpp.server=llama_cpp.server:main"]},
     install_requires=[
         "typing-extensions>=4.5.0",
     ],
+    extras_require={
+        "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
+    },
     python_requires=">=3.7",
     classifiers=[
         "Programming Language :: Python :: 3",

0 commit comments
