8000 Add grammar_lazy sampler · utilityai/llama-cpp-rs@dabcb10 · GitHub
[go: up one dir, main page]

Skip to content

Commit dabcb10

Browse files
committed
Add grammar_lazy sampler
1 parent 8b11c5c commit dabcb10

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

llama-cpp-2/src/sampling.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,49 @@ impl LlamaSampler {
239239
Self { sampler }
240240
}
241241

242+
/// Lazy grammar sampler, introduced in <https://github.com/ggerganov/llama.cpp/pull/9639>
243+
///
244+
/// This sampler enforces grammar rules only when specific trigger words or tokens are encountered.
245+
///
246+
/// # Panics
247+
/// - If `grammar_str` or `grammar_root` contain null bytes
248+
/// - If any trigger word contains null bytes
249+
#[must_use]
250+
pub fn grammar_lazy(
251+
model: &LlamaModel,
252+
grammar_str: &str,
253+
grammar_root: &str,
254+
trigger_words: impl IntoIterator<Item = impl AsRef<[u8]>>,
255+
trigger_tokens: &[LlamaToken],
256+
) -> Self {
257+
let grammar_str = CString::< EB49 span class=pl-en>new(grammar_str).unwrap();
258+
let grammar_root = CString::new(grammar_root).unwrap();
259+
260+
let trigger_word_cstrings: Vec<CString> = trigger_words
261+
.into_iter()
262+
.map(|word| CString::new(word.as_ref()).unwrap())
263+
.collect();
264+
265+
let mut trigger_word_ptrs: Vec<*const c_char> = trigger_word_cstrings
266+
.iter()
267+
.map(|cs| cs.as_ptr())
268+
.collect();
269+
270+
let sampler = unsafe {
271+
llama_cpp_sys_2::llama_sampler_init_grammar_lazy(
272+
model.vocab_ptr(),
273+
grammar_str.as_ptr(),
274+
grammar_root.as_ptr(),
275+
trigger_word_ptrs.as_mut_ptr(),
276+
trigger_word_ptrs.len(),
277+
trigger_tokens.as_ptr().cast(),
278+
trigger_tokens.len(),
279+
)
280+
};
281+
282+
Self { sampler }
283+
}
284+
242285
/// DRY sampler, designed by p-e-w, as described in:
243286
/// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
244287
/// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>

0 commit comments

Comments
 (0)
0