Merge branch 'main' into better-server-params-and-fields · Stonelinks/llama-cpp-python@7ab08b8 · GitHub

Commit 7ab08b8

Merge branch 'main' into better-server-params-and-fields
2 parents dbbfc4b + 46e3c4b commit 7ab08b8

5 files changed: +50 -39 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ This package is under active development and I welcome any contributions.
 To get started, clone the repository and install the package in development mode:
 
 ```bash
-git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
+git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
 # Will need to be re-run any time vendor/llama.cpp is updated
 python3 setup.py develop
 ```

llama_cpp/llama.py

Lines changed: 2 additions & 5 deletions
@@ -306,7 +306,7 @@ def _sample_top_p_top_k(
         llama_cpp.llama_sample_typical(
             ctx=self.ctx,
             candidates=llama_cpp.ctypes.pointer(candidates),
-            p=llama_cpp.c_float(1.0)
+            p=llama_cpp.c_float(1.0),
         )
         llama_cpp.llama_sample_top_p(
             ctx=self.ctx,
@@ -637,10 +637,7 @@ def _create_completion(
             self.detokenize([token]).decode("utf-8", errors="ignore")
             for token in all_tokens
         ]
-        all_logprobs = [
-            Llama._logits_to_logprobs(row)
-            for row in self.eval_logits
-        ]
+        all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs
         ):
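For reference, `Llama._logits_to_logprobs` turns a row of raw logits into log-probabilities. Below is a minimal sketch of what such a conversion typically looks like (a numerically stable log-softmax); the library's actual implementation may differ in detail:

```python
import math

def logits_to_logprobs(row):
    # Stable log-softmax: subtract the max before exponentiating,
    # then shift each logit by log-sum-exp to get log-probabilities.
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum_exp for x in row]

print(logits_to_logprobs([1.0, 2.0, 3.0]))  # exponentials of these sum to 1.0
```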

llama_cpp/server/__main__.py

Lines changed: 2 additions & 2 deletions
@@ -24,10 +24,10 @@
 import os
 import uvicorn
 
-from llama_cpp.server.app import app, init_llama
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    init_llama()
+    app = create_app()
 
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
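With the module-level `app` gone, the server object is now built by calling `create_app()`. A minimal sketch of launching it programmatically with explicit settings instead of environment variables (the model path below is a placeholder):

```python
import uvicorn

from llama_cpp.server.app import create_app, Settings

# Placeholder model path; point this at a real ggml model file.
app = create_app(Settings(model="./models/ggml-model-q4_0.bin"))
uvicorn.run(app, host="localhost", port=8000)
```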

llama_cpp/server/app.py

Lines changed: 29 additions & 23 deletions
@@ -2,18 +2,18 @@
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
+
+llama: Optional[llama_cpp.Llama] = None
 
-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+
 
 llama_lock = Lock()
+
+
 def get_llama():
     with llama_lock:
         yield llama
@@ -117,8 +125,6 @@ def get_llama():
         "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
     )
 
-
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
         default="",
@@ -162,7 +168,7 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -204,7 +210,7 @@ class Config:
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -257,7 +263,7 @@ class Config:
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -306,7 +312,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",

tests/test_llama.py

Lines changed: 16 additions & 8 deletions
@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
-
+
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def mock_sample(*args, **kwargs):
 def test_llama_pickle():
     import pickle
     import tempfile
+
     fp = tempfile.TemporaryFile()
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
 
     assert llama.detokenize(llama.tokenize(text)) == text
 
+
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def mock_eval(*args, **kwargs):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,13 @@ def mock_sample(*args, **kwargs):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    from llama_cpp.server.app import app, init_llama, Settings
-    s = Settings()
-    s.model = MODEL
-    s.vocab_only = True
-    init_llama(s)
+    from llama_cpp.server.app import create_app, Settings
+
+    settings = Settings(
+        model=MODEL,
+        vocab_only=True,
+    )
+    app = create_app(settings)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
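The reformatted `mock_get_logits` relies on the ctypes array pattern `(c_float * n_vocab)(*values)`. A quick standalone illustration of how that construct works:

```python
import ctypes

n = 4
FloatArray = ctypes.c_float * n             # multiplying a ctypes type by an int creates an array type
arr = FloatArray(*[0.0 for _ in range(n)])  # calling it with unpacked values fills the array
print(list(arr))                            # [0.0, 0.0, 0.0, 0.0]
```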

0 commit comments