Fix threading bug. Closes #62 · Co-Simulation/llama-cpp-python@19598ac · GitHub
Commit 19598ac

Fix threading bug. Closes abetlen#62
1 parent 005c78d commit 19598ac

1 file changed: +12 −3 lines

llama_cpp/server/__main__.py

Lines changed: 12 additions & 3 deletions
@@ -13,12 +13,13 @@
 """
 import os
 import json
+from threading import Lock
 from typing import List, Optional, Literal, Union, Iterator, Dict
 from typing_extensions import TypedDict
 
 import llama_cpp
 
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -59,6 +60,13 @@ class Settings(BaseSettings):
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+llama_lock = Lock()
+
+
+def get_llama():
+    with llama_lock:
+        yield llama
+
 
 
 class CreateCompletionRequest(BaseModel):
@@ -101,7 +109,7 @@ class Config:
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(request: CreateCompletionRequest):
+def create_completion(request: CreateCompletionRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     if isinstance(request.prompt, list):
         request.prompt = "".join(request.prompt)
 
@@ -146,7 +154,7 @@ class Config:
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(request: CreateEmbeddingRequest):
+def create_embedding(request: CreateEmbeddingRequest, llama: llama_cpp.Llama=Depends(get_llama)):
     return llama.create_embedding(**request.dict(exclude={"model", "user"}))
 
@@ -200,6 +208,7 @@ class Config:
 )
 def create_chat_completion(
     request: CreateChatCompletionRequest,
+    llama: llama_cpp.Llama=Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
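
For context, below is a minimal, self-contained sketch of the pattern this commit introduces: a module-level threading.Lock plus a FastAPI generator dependency that holds the lock while yielding the shared, non-thread-safe object to the endpoint. The SharedModel class, the /generate route, and the prompt parameter are illustrative placeholders, not part of this repository; only the Lock-plus-Depends pattern mirrors the diff above.

from threading import Lock

from fastapi import Depends, FastAPI

app = FastAPI()


class SharedModel:
    """Stand-in for a non-thread-safe object such as llama_cpp.Llama."""

    def generate(self, prompt: str) -> str:
        # Pretend to run inference; a real model would do heavy work here.
        return f"echo: {prompt}"


model = SharedModel()
model_lock = Lock()


def get_model():
    # Generator dependency: the lock is acquired before the endpoint runs
    # and released only after it finishes, so access to the shared model
    # is serialized across concurrent requests.
    with model_lock:
        yield model


@app.post("/generate")
def generate(prompt: str, model: SharedModel = Depends(get_model)) -> dict:
    return {"completion": model.generate(prompt)}

Because the dependency yields inside the with block, FastAPI releases the lock only after the handler returns, so two requests can never drive the shared model at the same time. You can try the sketch by saving it to a file and running it with uvicorn (e.g. uvicorn yourmodule:app, where yourmodule is whatever filename you chose).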
