@@ -227,7 +227,6 @@ def __init__(
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ def __init__(
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.
 
         Raises:
@@ -383,12 +381,6 @@ def __init__(
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None
 
     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
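With the three hunks above, `Llama.__init__` no longer accepts or stores a grammar; a `LlamaGrammar` is built separately and handed to individual calls instead. A minimal sketch of the new construction pattern (the import path and file paths below are illustrative, not part of this commit):

```python
from llama_cpp.llama import Llama, LlamaGrammar

# The model is now loaded without any grammar attached.
llm = Llama(model_path="./model.bin")  # illustrative path

# A grammar is a free-standing object, created from a GBNF file.
# LlamaGrammar.from_file is the same helper the removed __init__ code used.
grammar = LlamaGrammar.from_file("./json.gbnf")  # illustrative path
```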
@@ -527,6 +519,7 @@ def _sample(
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -574,11 +567,11 @@ def _sample(
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
 
-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_sample_grammar(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
             )
 
         if temp.value == 0.0:
@@ -650,10 +643,10 @@ def _sample(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
-            if self.grammar is not None:
+            if grammar is not None:
                 llama_cpp.llama_grammar_accept_token(
                     ctx=self.ctx,
-                    grammar=self.grammar.grammar,
+                    grammar=grammar.grammar,
                     token=llama_cpp.ctypes.c_int(id),
                 )
             return id
@@ -672,6 +665,7 @@ def sample(
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.
 
@@ -705,6 +699,7 @@ def sample(
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
 
     def generate(
@@ -723,6 +718,7 @@ def generate(
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -761,8 +757,8 @@ def generate(
         if reset:
             self.reset()
 
-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()
 
         while True:
             self.eval(tokens)
@@ -778,6 +774,7 @@ def generate(
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
@@ -880,6 +877,7 @@ def _create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
@@ -957,6 +955,7 @@ def _create_completion(
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ def create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1345,6 +1345,7 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ def __call__(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1418,6 +1420,7 @@ def __call__(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
 
     def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ def create_chat_completion(
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1540,6 +1544,7 @@ def create_chat_completion(
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
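Since every public entry point (`sample`, `generate`, `_create_completion`, `create_completion`, `__call__`, `create_chat_completion`) now threads `grammar` through to `_sample`, a single `Llama` instance can switch grammars per request. A hedged usage sketch; the GBNF string and prompts are illustrative only, and `LlamaGrammar.from_string` is assumed to be available alongside the `from_file` helper seen above:

```python
from llama_cpp.llama import Llama, LlamaGrammar

llm = Llama(model_path="./model.bin")  # illustrative path

# Constrain one call to the literal strings "yes" or "no"...
yes_no = LlamaGrammar.from_string('root ::= "yes" | "no"')
answer = llm("Is the sky blue? Answer:", grammar=yes_no, max_tokens=4)

# ...while another call on the same instance samples unconstrained.
story = llm("Tell me a short story:", max_tokens=64)
```

Note that `generate` calls `grammar.reset()` at the start of each run, so one `LlamaGrammar` object can be reused across calls without carrying over parser state.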