@@ -1,11 +1,14 @@
 """Example FastAPI server for llama.cpp.
 """
-from typing import List, Optional
+import json
+from typing import List, Optional, Iterator

-from llama_cpp import Llama
+import llama_cpp

 from fastapi import FastAPI
-from pydantic import BaseModel, BaseSettings, Field
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
+from sse_starlette.sse import EventSourceResponse


 class Settings(BaseSettings):
@@ -16,11 +19,24 @@ class Settings(BaseSettings):
     title="🦙 llama.cpp Python API",
     version="0.0.1",
 )
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 settings = Settings()
-llama = Llama(settings.model)
+llama = llama_cpp.Llama(
+    settings.model,
+    f16_kv=True,
+    use_mlock=True,
+    n_threads=6,
+    n_batch=2048,
+)


-class CompletionRequest(BaseModel):
+class CreateCompletionRequest(BaseModel):
     prompt: str
     suffix: Optional[str] = Field(None)
     max_tokens: int = 16
@@ -31,6 +47,7 @@ class CompletionRequest(BaseModel):
     stop: List[str] = []
     repeat_penalty: float = 1.1
     top_k: int = 40
+    stream: bool = False

     class Config:
         schema_extra = {
@@ -41,6 +58,39 @@ class Config:
         }


-@app.post("/v1/completions")
-def completions(request: CompletionRequest):
+CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
+
+
+@app.post(
+    "/v1/completions",
+    response_model=CreateCompletionResponse,
+)
+def create_completion(request: CreateCompletionRequest):
+    if request.stream:
+        chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
+        return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
     return llama(**request.dict())
+
+
+class CreateEmbeddingRequest(BaseModel):
+    model: Optional[str]
+    input: str
+    user: Optional[str]
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "input": "The food was delicious and the waiter...",
+            }
+        }
+
+
+CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
+
+
+@app.post(
+    "/v1/embeddings",
+    response_model=CreateEmbeddingResponse,
+)
+def create_embedding(request: CreateEmbeddingRequest):
+    return llama.create_embedding(**request.dict())
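A quick way to exercise the endpoints this commit adds (a sketch, not part of the commit): start the server with something like `MODEL=./models/ggml-model.bin uvicorn fastapi_server:app` — pydantic's `BaseSettings` fills `Settings.model` from the `MODEL` environment variable — then run the client below. The base URL, model path, module name, and prompts are assumptions; the response shapes follow the OpenAI-style dicts that llama-cpp-python returns.

# Hypothetical stdlib-only client for the example server above.
import json
import urllib.request

BASE = "http://localhost:8000"  # assumed host/port, not part of the commit


def post(path, payload):
    # POST a JSON body and decode the JSON response.
    req = urllib.request.Request(
        BASE + path,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    return json.loads(urllib.request.urlopen(req).read())


# Plain completion: one JSON object with an OpenAI-style "choices" list.
out = post("/v1/completions", {"prompt": "Q: Name the planets. A: ", "max_tokens": 32})
print(out["choices"][0]["text"])

# Embedding: the vector lives at data[0]["embedding"].
emb = post("/v1/embeddings", {"input": "The food was delicious and the waiter..."})
print(len(emb["data"][0]["embedding"]))

# Streaming completion: the server answers with Server-Sent Events, one JSON
# chunk per "data:" line (a real client would use an SSE library instead).
req = urllib.request.Request(
    BASE + "/v1/completions",
    data=json.dumps(
        {"prompt": "Q: Name the planets. A: ", "max_tokens": 32, "stream": True}
    ).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw in resp:
        line = raw.decode("utf-8").strip()
        if line.startswith("data:"):
            chunk = json.loads(line[len("data:"):])
            print(chunk["choices"][0]["text"], end="", flush=True)
print()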