@@ -518,6 +518,10 @@ async def get_event_publisher(
518
518
default = 0.1 , ge = 0.001 , le = 1.0 , description = "Mirostat learning rate"
519
519
)
520
520
521
+ grammar = Field (
522
+ default = None ,
523
+ description = "A GBNF grammar (as string) to be used for formatting the model's output."
524
+ )
521
525
522
526
class CreateCompletionRequest (BaseModel ):
523
527
prompt : Union [str , List [str ]] = Field (
@@ -533,6 +537,7 @@ class CreateCompletionRequest(BaseModel):
533
537
mirostat_mode : int = mirostat_mode_field
534
538
mirostat_tau : float = mirostat_tau_field
535
539
mirostat_eta : float = mirostat_eta_field
540
+ grammar : Optional [str ] = None
536
541
echo : bool = Field (
537
542
default = False ,
538
543
description = "Whether to echo the prompt in the generated text. Useful for chatbots." ,
@@ -634,6 +639,9 @@ async def create_completion(
634
639
]
635
640
)
636
641
642
+ if body .grammar is not None :
643
+ kwargs ["grammar" ] = llama_cpp .LlamaGrammar .from_string (body .grammar )
644
+
637
645
iterator_or_completion : Union [
638
646
llama_cpp .Completion , Iterator [llama_cpp .CompletionChunk ]
639
647
] = await run_in_threadpool (llama , ** kwargs )
@@ -714,6 +722,7 @@ class CreateChatCompletionRequest(BaseModel):
714
722
mirostat_mode : int = mirostat_mode_field
715
723
mirostat_tau : float = mirostat_tau_field
716
724
mirostat_eta : float = mirostat_eta_field
725
+ grammar : Optional [str ] = None
717
726
stop : Optional [List [str ]] = stop_field
718
727
stream : bool = stream_field
719
728
presence_penalty : Optional [float ] = presence_penalty_field
@@ -772,6 +781,9 @@ async def create_chat_completion(
772
781
]
773
782
)
774
783
784
+ if body .grammar is not None :
785
+ kwargs ["grammar" ] = llama_cpp .LlamaGrammar .from_string (body .grammar )
786
+
775
787
iterator_or_completion : Union [
776
788
llama_cpp .ChatCompletion , Iterator [llama_cpp .ChatCompletionChunk ]
777
789
] = await run_in_threadpool (llama .create_chat_completion , ** kwargs )
0 commit comments