bugfix: pydantic v2 fields · wdshin/llama-cpp-python@de4cc5a · GitHub

Commit de4cc5a

bugfix: pydantic v2 fields
1 parent 896ab7b commit de4cc5a
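
The diff below moves the server's request models from pydantic v1 idioms to pydantic v2 ones: optional Field defaults are passed explicitly, the nested class Config with schema_extra becomes model_config with "json_schema_extra", and .dict() calls become .model_dump(). A minimal sketch of the same pattern, using a hypothetical ExampleRequest model that is not part of this commit:

from typing import Optional

from pydantic import BaseModel, Field


class ExampleRequest(BaseModel):
    # pydantic v2: optional fields get an explicit default
    user: Optional[str] = Field(default=None)
    seed: int = Field(default=1337, description="Random seed. -1 for random.")

    # pydantic v2: model_config replaces the nested `class Config`, and
    # "json_schema_extra" replaces the old "schema_extra" key
    model_config = {
        "json_schema_extra": {
            "examples": [{"seed": 42}],
        }
    }


req = ExampleRequest()
print(req.model_dump(exclude={"user"}))  # v2 replacement for req.dict(exclude=...)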

File tree

1 file changed: 50 additions, 58 deletions


llama_cpp/server/app.py

Lines changed: 50 additions & 58 deletions
@@ -31,9 +31,7 @@ class Settings(BaseSettings):
         ge=0,
         description="The number of layers to put on the GPU. The rest will be on the CPU.",
     )
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
-    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
@@ -80,12 +78,8 @@ class Settings(BaseSettings):
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
     )
-    host: str = Field(
-        default="localhost", description="Listen address"
-    )
-    port: int = Field(
-        default=8000, description="Listen port"
-    )
+    host: str = Field(default="localhost", description="Listen address")
+    port: int = Field(default=8000, description="Listen port")
     interrupt_requests: bool = Field(
         default=True,
         description="Whether to interrupt requests when a new request is received.",
@@ -178,7 +172,7 @@ def get_settings():
     yield settings
 
 
-model_field = Field(description="The model to use for generating completions.")
+model_field = Field(description="The model to use for generating completions.", default=None)
 
 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
@@ -242,21 +236,18 @@ def get_settings():
     default=0,
     ge=0,
     le=2,
-    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)"
+    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
 )
 
 mirostat_tau_field = Field(
     default=5.0,
     ge=0.0,
     le=10.0,
-    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text"
+    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
 )
 
 mirostat_eta_field = Field(
-    default=0.1,
-    ge=0.001,
-    le=1.0,
-    description="Mirostat learning rate"
+    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
 )
 
 
@@ -294,22 +285,23 @@ class CreateCompletionRequest(BaseModel):
     model: Optional[str] = model_field
     n: Optional[int] = 1
     best_of: Optional[int] = 1
-    user: Optional[str] = Field(None)
+    user: Optional[str] = Field(default=None)
 
     # llama.cpp specific parameters
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
-    class Config:
-        schema_extra = {
-            "example": {
-                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"],
-            }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
+                    "stop": ["\n", "###"],
+                }
+            ]
         }
-
-
+    }
 
 
 def make_logit_bias_processor(
@@ -328,7 +320,7 @@ def make_logit_bias_processor(
 
     elif logit_bias_type == "tokens":
         for token, score in logit_bias.items():
-            token = token.encode('utf-8')
+            token = token.encode("utf-8")
             for input_id in llama.tokenize(token, add_bos=False):
                 to_bias[input_id] = score
 
@@ -352,7 +344,7 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
-):
+) -> llama_cpp.Completion:
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
@@ -364,7 +356,7 @@ async def create_completion(
         "logit_bias_type",
         "user",
     }
-    kwargs = body.dict(exclude=exclude)
+    kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
         kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@@ -396,7 +388,7 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
 
         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )
+        ) # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
         return completion
@@ -405,16 +397,17 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
 class CreateEmbeddingRequest(BaseModel):
     model: Optional[str] = model_field
     input: Union[str, List[str]] = Field(description="The input to embed.")
-    user: Optional[str]
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "input": "The food was delicious and the waiter...",
-            }
+    user: Optional[str] = Field(default=None)
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "input": "The food was delicious and the waiter...",
+                }
+            ]
         }
-
-
+    }
 
 
 @router.post(
@@ -424,7 +417,7 @@ async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
     return await run_in_threadpool(
-        llama.create_embedding, **request.dict(exclude={"user"})
+        llama.create_embedding, **request.model_dump(exclude={"user"})
     )
 
 
@@ -461,21 +454,22 @@ class CreateChatCompletionRequest(BaseModel):
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
-    class Config:
-        schema_extra = {
-            "example": {
-                "messages": [
-                    ChatCompletionRequestMessage(
-                        role="system", content="You are a helpful assistant."
-                    ),
-                    ChatCompletionRequestMessage(
-                        role="user", content="What is the capital of France?"
-                    ),
-                ]
-            }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "messages": [
+                        ChatCompletionRequestMessage(
+                            role="system", content="You are a helpful assistant."
+                        ).model_dump(),
+                        ChatCompletionRequestMessage(
+                            role="user", content="What is the capital of France?"
+                        ).model_dump(),
+                    ]
+                }
+            ]
         }
-
-
+    }
 
 
 @router.post(
@@ -486,14 +480,14 @@ async def create_chat_completion(
     body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
-) -> Union[llama_cpp.ChatCompletion]: # type: ignore
+) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
         "logit_bias",
         "logit_bias_type",
         "user",
     }
-    kwargs = body.dict(exclude=exclude)
+    kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
         kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@@ -526,7 +520,7 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(event_publisher, send_chan),
-        )
+        ) # type: ignore
     else:
         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
             llama.create_chat_completion, **kwargs # type: ignore
@@ -546,8 +540,6 @@ class ModelList(TypedDict):
     data: List[ModelData]
 
 
-
-
 @router.get("/v1/models")
 async def get_models(
     settings: Settings = Depends(get_settings),
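
One detail from the chat-completion hunk: the example messages are now passed through .model_dump() before being placed under "json_schema_extra", presumably so the OpenAPI example holds plain JSON data rather than model instances. A rough sketch of that pattern, using a hypothetical Message/ChatRequest pair rather than the actual server models:

from typing import List

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        # dump model instances to dicts so the schema example
                        # is plain, JSON-serializable data
                        Message(role="system", content="You are a helpful assistant.").model_dump(),
                        Message(role="user", content="What is the capital of France?").model_dump(),
                    ]
                }
            ]
        }
    }


# the extra keys from json_schema_extra are merged into the generated schema
print(ChatRequest.model_json_schema().get("examples"))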

0 commit comments
