Merge pull request #1 from nkoppel/grammar_sampler_stage · masmullin2000/llama_cpp-rs@539ea2c

Commit 539ea2c

Merge pull request edgenai#1 from nkoppel/grammar_sampler_stage
Implement grammar sampling as a SamplerStage.
2 parents: b6a8a06 + e7e0f93

File tree

2 files changed (+64, -30 lines)

crates/llama_cpp/src/session/params.rs

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ pub struct SessionParams {
     pub pooling: PoolingType,
 
     /// defragment the KV cache if holes/size > thold, < 0 disabled (default)
-    defrag_threshold: f32,
+    pub defrag_threshold: f32,
 }
 
 impl Default for SessionParams {
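
With `defrag_threshold` now `pub`, downstream code can tune KV-cache defragmentation directly. A minimal sketch, assuming the crate re-exports `SessionParams` from its root and keeps the `Default` impl shown above:

    use llama_cpp::SessionParams;

    // Defragment the KV cache once holes/size exceeds 0.1; per the doc
    // comment above, a negative value disables defragmentation.
    let params = SessionParams {
        defrag_threshold: 0.1,
        ..Default::default()
    };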

crates/llama_cpp/src/standard_sampler.rs

Lines changed: 63 additions & 29 deletions
@@ -14,12 +14,13 @@ use crate::{grammar::LlamaGrammar, Sampler, Token};
 ///
 /// Standard ordering for samplers (taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)):
 ///
-/// 1. [`SamplerStage::RepetitionPenalty`]
-/// 2. [`SamplerStage::Temperature`], [SamplerStage::DynamicTemperature]
-/// 3. [`SamplerStage::TopK`]
-/// 4. [`SamplerStage::TailFree`]
-/// 5. [`SamplerStage::Typical`]
-/// 6. [`SamplerStage::TopP`], [`SamplerStage::MinP`]
+/// 1. [`SamplerStage::Grammar`]
+/// 2. [`SamplerStage::RepetitionPenalty`]
+/// 3. [`SamplerStage::Temperature`], [SamplerStage::DynamicTemperature]
+/// 4. [`SamplerStage::TopK`]
+/// 5. [`SamplerStage::TailFree`]
+/// 6. [`SamplerStage::Typical`]
+/// 7. [`SamplerStage::TopP`], [`SamplerStage::MinP`]
 #[derive(Clone, Debug)]
 #[non_exhaustive]
 pub enum SamplerStage {
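
Taken together with the new constructor below, a stage list in this order might look like the following sketch. Only `Temperature`, `MinP`, `TailFree`, and `from_grammar` appear in this diff; the values and the `grammar` binding are illustrative assumptions:

    // Stages run in the order they appear in the Vec, matching the
    // documented ordering: grammar first, probability cutoffs last.
    let stages = vec![
        SamplerStage::from_grammar(grammar, None), // 1. Grammar
        SamplerStage::Temperature(0.8),            // 3. Temperature
        SamplerStage::TailFree(1.0),               // 5. TailFree
        SamplerStage::MinP(0.05),                  // 7. MinP
    ];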
@@ -103,16 +104,34 @@ pub enum SamplerStage {
     ///
     /// See: <https://www.trentonbricken.com/Tail-Free-Sampling/>
     TailFree(f32),
+
+    /// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given
+    /// grammar. Since this stage has to handle mutable state, an instance of this stage should
+    /// only be used in one completion.
+    ///
+    /// See [`GrammarStage`] and [`LlamaGrammar`] for more information.
+    Grammar(GrammarStage),
 }
 
 impl SamplerStage {
+    /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`].
+    ///
+    /// `start_position` indicates the token position to begin applying the grammar at. [`None`]
+    /// indicates that the grammar begins at the end of context.
+    pub fn from_grammar(grammar: LlamaGrammar, start_position: Option<usize>) -> Self {
+        SamplerStage::Grammar(GrammarStage {
+            grammar,
+            accepted_up_to: start_position,
+        })
+    }
+
     /// Applies this [`SamplerStage`] to the provided token data array.
     ///
     /// Ensures that at least `min_keep` tokens remain after the
     /// [`SamplerStage`]'s are applied.
     #[allow(clippy::not_unsafe_ptr_arg_deref)]
     pub fn apply(
-        &self,
+        &mut self,
         context: *mut llama_context,
         tokens: &[Token],
         mut candidates_p: llama_token_data_array,
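
A brief sketch of the two `start_position` modes (the `grammar` and `prompt_len` bindings are hypothetical):

    // `None`: the grammar starts matching at the end of the current
    // context, so only newly generated tokens are constrained.
    let stage_new_tokens = SamplerStage::from_grammar(grammar.clone(), None);

    // `Some(n)`: tokens[n..] are first accepted into the grammar, so
    // everything from position `n` onward must conform to it.
    let stage_from_prompt = SamplerStage::from_grammar(grammar, Some(prompt_len));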
@@ -173,13 +192,48 @@ impl SamplerStage {
                 SamplerStage::TailFree(z) => {
                     llama_sample_tail_free(context, p_ptr, *z, min_keep);
                 }
+                SamplerStage::Grammar(stage) => {
+                    candidates_p = stage.apply(context, tokens, candidates_p, min_keep)
+                }
             }
         }
 
         candidates_p
     }
 }
 
+/// Opaque internals for [`SamplerStage::Grammar`].
+#[derive(Clone, Debug)]
+pub struct GrammarStage {
+    grammar: LlamaGrammar,
+    accepted_up_to: Option<usize>,
+}
+
+impl GrammarStage {
+    fn apply(
+        &mut self,
+        context: *mut llama_context,
+        tokens: &[Token],
+        mut candidates_p: llama_token_data_array,
+        _min_keep: usize,
+    ) -> llama_token_data_array {
+        // If `accepted_up_to` is `None`, assume that we should start at the end of context.
+        let accepted_up_to = self.accepted_up_to.unwrap_or(tokens.len());
+
+        // Accept all new tokens until the end of context.
+        for token in &tokens[accepted_up_to..] {
+            unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) }
+        }
+        self.accepted_up_to = Some(tokens.len());
+
+        // Apply grammar sampling to `candidates_p`.
+        let p_ptr = addr_of_mut!(candidates_p);
+        unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) };
+
+        candidates_p
+    }
+}
+
 /// Determines how the next token is selected from the distribution produced by
 /// the model and the [`SamplerStage`]'s.
 #[derive(Clone, Debug)]
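
Design note: `GrammarStage::apply` folds both halves of the old grammar flow into one stage. Accepting the tokens generated since the last call replaces the old post-selection `llama_grammar_accept_token` step, and `llama_sample_grammar` then filters the candidates as before. Because `accepted_up_to` only advances through a single token stream, one instance serves one completion. A sketch of the intended pattern, with assumed `prompts` and `grammar` bindings:

    // Clone the grammar into a fresh stage for each completion so that
    // `accepted_up_to` starts from that completion's own context.
    for prompt in &prompts {
        let stage = SamplerStage::from_grammar(grammar.clone(), None);
        // ... build a sampler with `stage` and run the completion ...
    }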
@@ -232,7 +286,6 @@ impl TokenSelector {
 pub struct StandardSampler {
     stages: Vec<SamplerStage>,
     min_keep: usize,
-    grammar: Option<LlamaGrammar>,
     token_selector: TokenSelector,
 }

@@ -246,12 +299,10 @@ impl StandardSampler {
     pub fn new_softmax(
         stages: Vec<SamplerStage>,
         min_keep: usize,
-        grammar: Option<LlamaGrammar>,
     ) -> StandardSampler {
        StandardSampler {
             stages,
             min_keep,
-            grammar: grammar,
             token_selector: TokenSelector::Softmax,
         }
     }
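
For callers, the migration is mechanical: drop the third argument and pass the grammar as a leading stage instead. A hedged before/after sketch (assumes a mutable `stages` Vec and a `grammar` binding):

    // Before: StandardSampler::new_softmax(stages, min_keep, Some(grammar));
    // After: the grammar travels as the first stage.
    stages.insert(0, SamplerStage::from_grammar(grammar, None));
    let sampler = StandardSampler::new_softmax(stages, min_keep);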
@@ -262,7 +313,6 @@ impl StandardSampler {
         StandardSampler {
             stages: Vec::new(),
             min_keep: 0,
-            grammar: None,
             token_selector: TokenSelector::Greedy,
         }
     }
@@ -279,7 +329,6 @@ impl StandardSampler {
         StandardSampler {
             stages,
             min_keep,
-            grammar: None,
             token_selector: TokenSelector::Mirostat {
                 tau,
                 eta,
@@ -300,7 +349,6 @@ impl StandardSampler {
         StandardSampler {
             stages,
             min_keep,
-            grammar: None,
             token_selector: TokenSelector::MirostatV2 {
                 tau,
                 eta,
@@ -325,7 +373,6 @@ impl Default for StandardSampler {
                 SamplerStage::MinP(0.05),
                 SamplerStage::Temperature(0.8),
             ],
-            grammar: None,
             min_keep: 1,
             token_selector: TokenSelector::Softmax,
         }
@@ -340,25 +387,12 @@ impl Sampler for StandardSampler {
         tokens: &[Token],
         mut candidates_p: llama_token_data_array,
     ) -> Token {
-        let p_ptr = addr_of_mut!(candidates_p);
         let min_keep = self.min_keep.max(1);
 
-        // Note: We should sample grammar before applying other sampling stages.
-        if let Some(grammar) = self.grammar.as_mut() {
-            unsafe { llama_sample_grammar(context, p_ptr, grammar.grammar.as_ptr()) };
-        }
-
-        for stage in &self.stages {
+        for stage in &mut self.stages {
             candidates_p = stage.apply(context, tokens, candidates_p, min_keep);
         }
 
-        let token = self.token_selector.select(context, candidates_p);
-
-        // Note: We must accept the token into the grammar after sampling if a grammar is provided.
-        if let Some(grammar) = self.grammar.as_mut() {
-            unsafe { llama_grammar_accept_token(context, grammar.grammar.as_ptr(), token.0) }
-        }
-
-        token
+        self.token_selector.select(context, candidates_p)
     }
 }
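
Net effect: `sample` is now a single uniform pipeline, which is why `SamplerStage::apply` takes `&mut self` and the loop iterates over `&mut self.stages`. The grammar's pre-filter and post-accept special cases now live inside `GrammarStage`, so no stage needs bespoke handling here.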
