Move grammar to function call argument · hackensun/llama-cpp-python@66fb034 · GitHub
Commit 66fb034

Move grammar to function call argument
1 parent 1e844d3 commit 66fb034

1 file changed: +19 -14

llama_cpp/llama.py

Lines changed: 19 additions & 14 deletions
@@ -227,7 +227,6 @@ def __init__(
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ def __init__(
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.

         Raises:
@@ -383,12 +381,6 @@ def __init__(
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None

     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -527,6 +519,7 @@ def _sample(
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -574,11 +567,11 @@ def _sample(
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)

-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_sample_grammar(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
             )

         if temp.value == 0.0:
@@ -650,10 +643,10 @@ def _sample(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_grammar_accept_token(
                 ctx=self.ctx,
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
                 token=llama_cpp.ctypes.c_int(id),
             )
         return id
@@ -672,6 +665,7 @@ def sample(
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.

@@ -705,6 +699,7 @@ def sample(
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def generate(
@@ -723,6 +718,7 @@ def generate(
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -761,8 +757,8 @@ def generate(
         if reset:
             self.reset()

-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()

         while True:
             self.eval(tokens)
@@ -778,6 +774,7 @@ def generate(
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
@@ -880,6 +877,7 @@ def _create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None

@@ -957,6 +955,7 @@ def _create_completion(
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ def create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1345,6 +1345,7 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ def __call__(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1418,6 +1420,7 @@ def __call__(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ def create_chat_completion(
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1540,6 +1544,7 @@ def create_chat_completion(
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
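
In practical terms, this commit turns grammar-constrained sampling from a model-level setting into a per-request argument: the Llama constructor no longer accepts a grammar path, while __call__, create_completion, create_chat_completion, generate, sample, and _sample each gain an optional LlamaGrammar parameter. A minimal usage sketch of the resulting API (the model and grammar file paths below are placeholders, not part of this commit):

from llama_cpp import Llama, LlamaGrammar

# Placeholder model path, for illustration only.
llm = Llama(model_path="./models/7B/ggml-model-q4_0.bin")

# The caller now builds the grammar itself; LlamaGrammar.from_file is the
# same helper the removed constructor code used.
grammar = LlamaGrammar.from_file("./grammars/json.gbnf")

# The grammar is passed per request instead of at construction time.
output = llm(
    "JSON describing a black cat:",
    max_tokens=128,
    grammar=grammar,
)
print(output["choices"][0]["text"])

Because generate() now calls grammar.reset() at the start of each run, the same LlamaGrammar object can be reused across requests, and different requests against a single Llama instance can use different grammars or none at all.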

0 commit comments
