 """
 import os
 import json
+from threading import Lock
 from typing import List, Optional, Literal, Union, Iterator, Dict
 from typing_extensions import TypedDict

 import llama_cpp

-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -59,6 +60,13 @@ class Settings(BaseSettings):
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
+


 class CreateCompletionRequest(BaseModel):
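The `get_llama` dependency added above is a FastAPI "dependency with yield": the lock is acquired before the endpoint body runs and released once the response has been produced, so the single shared `llama_cpp.Llama` instance is only ever used by one request at a time. Since these endpoints are plain `def` functions, FastAPI runs them in a threadpool, which is why concurrent requests could otherwise hit the shared model simultaneously. Below is a minimal, self-contained sketch of the same pattern; the `FakeModel` class, the `/generate` endpoint, and its parameter names are illustrative stand-ins, not part of this PR.

    from threading import Lock

    from fastapi import Depends, FastAPI


    class FakeModel:
        """Stand-in for the single shared model instance (here: llama_cpp.Llama)."""

        def generate(self, prompt: str) -> str:
            return f"echo: {prompt}"


    app = FastAPI()
    model = FakeModel()
    model_lock = Lock()


    def get_model():
        # FastAPI treats a generator dependency like a context manager: the
        # lock is held while the endpoint runs and released after the
        # response is produced, serializing access to the shared object.
        with model_lock:
            yield model


    @app.get("/generate")
    def generate(prompt: str, model: FakeModel = Depends(get_model)):
        return {"text": model.generate(prompt)}
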
@@ -101,7 +109,7 @@ class Config:
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(request: CreateCompletionRequest):
+def create_completion(request: CreateCompletionRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     if isinstance(request.prompt, list):
         request.prompt = "".join(request.prompt)

@@ -146,7 +154,7 @@ class Config:
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(request: CreateEmbeddingRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     return llama.create_embedding(**request.dict(exclude={"model", "user"}))


@@ -200,6 +208,7 @@ class Config:
 )
 def create_chat_completion(
     request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama=Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
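For completeness, a hypothetical client-side check that two simultaneous requests are handled one after the other rather than interleaving inside the model. The host/port, the max_tokens field, and the OpenAI-style choices[0].text response shape are assumptions; only the prompt field is visible in the diff above.

    import threading

    import requests


    def ask(prompt: str) -> None:
        # Assumed server address and request/response schema.
        resp = requests.post(
            "http://localhost:8000/v1/completions",
            json={"prompt": prompt, "max_tokens": 16},
            timeout=120,
        )
        print(resp.status_code, resp.json()["choices"][0]["text"])


    # Fire two requests at once; the lock-guarded dependency serializes them.
    threads = [threading.Thread(target=ask, args=(p,)) for p in ("Hello", "Goodbye")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()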