8000 Fix bugs; Simplify api · masmullin2000/llama_cpp-rs@62ee326 · GitHub
[go: up one dir, main page]

Skip to content

Commit 62ee326

Browse files
committed
Fix bugs; Simplify api
1 parent 65238d3 commit 62ee326

File tree

2 files changed

+18
-28
lines changed

2 files changed

+18
-28
lines changed

crates/llama_cpp/src/session/params.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ pub struct SessionParams {
105105
pub pooling: PoolingType,
106106

107107
/// defragment the KV cache if holes/size > thold, < 0 disabled (default)
108-
defrag_threshold: f32,
108+
pub defrag_threshold: f32,
109109
}
110110

111111
impl Default for SessionParams {

crates/llama_cpp/src/standard_sampler.rs

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,25 @@ pub enum SamplerStage {
105105
TailFree(f32),
106106

107107
/// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given
108-
/// grammar.
108+
/// grammar. Since this stage has to handle mutable state, an instance of this stage should
109+
/// only be used in one completion.
109110
///
110111
/// See [`GrammarStage`] and [`LlamaGrammar`] for more information.
111112
Grammar(GrammarStage),
112113
}
113114

114115
impl SamplerStage {
116+
/// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`].
117+
///
118+
/// `start_position` indicates the token position to begin applying the grammar at. [`None`]
119+
/// indicates that the grammar begins at the end of context.
120+
pub fn from_grammar(grammar: LlamaGrammar, start_position: Option<usize>) -> Self {
121+
SamplerStage::Grammar(GrammarStage {
122+
grammar,
123+
accepted_to: start_position,
124+
})
125+
}
126+
115127
/// Applies this [`SamplerStage`] to the provided token data array.
116128
///
117129
/// Ensures that at least `min_keep` tokens remain after the
@@ -192,45 +204,23 @@ impl SamplerStage {
192204
/// Opaque internals for [`SamplerStage::Grammar`].
193205
#[derive(Clone, Debug)]
194206
pub struct GrammarStage {
195-
original_grammar: LlamaGrammar,
196207
grammar: LlamaGrammar,
197-
tokens: Vec<Token>,
208+
accepted_to: Option<usize>,
198209
}
199210

200211
impl GrammarStage {
201-
/// Creates a new [`GrammarStage`] from a [`LlamaGrammar`]
202-
pub fn new(grammar: LlamaGrammar) -> Self {
203-
Self {
204-
original_grammar: grammar.clone(),
205-
grammar,
206-
tokens: Vec::new()
207-
}
208-
}
209-
210-
/// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`]
211-
pub fn new_stage(grammar: LlamaGrammar) -> SamplerStage {
212-
SamplerStage::Grammar(Self::new(grammar))
213-
}
214-
215212
fn apply(
216213
&mut self,
217214
context: *mut llama_context,
218215
tokens: &[Token],
219216
mut candidates_p: llama_token_data_array,
220217
_min_keep: usize,
221218
) {
222-
let new_tokens = if let Some(suffix) = tokens.strip_prefix(self.tokens.as_slice()) {
223-
suffix
224-
} else {
225-
self.tokens.clear();
226-
self.grammar = self.original_grammar.clone();
227-
tokens
228-
};
229-
230-
for token in new_tokens {
219+
let accepted_to = self.accepted_to.unwrap_or(tokens.len());
220+
for token in &tokens[accepted_to..] {
231221
unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) }
232222
}
233-
self.tokens.extend_from_slice(new_tokens);
223+
self.accepted_to = Some(tokens.len());
234224

235225
let p_ptr = addr_of_mut!(candidates_p);
236226
unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) };

0 commit comments

Comments
 (0)
0