@@ -105,13 +105,25 @@ pub enum SamplerStage {
105
105
TailFree ( f32 ) ,
106
106
107
107
/// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given
108
- /// grammar.
108
+ /// grammar. Since this stage has to handle mutable state, an instance of this stage should
109
+ /// only be used in one completion.
109
110
///
110
111
/// See [`GrammarStage`] and [`LlamaGrammar`] for more information.
111
112
Grammar ( GrammarStage ) ,
112
113
}
113
114
114
115
impl SamplerStage {
116
+ /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`].
117
+ ///
118
+ /// `start_position` indicates the token position to begin applying the grammar at. [`None`]
119
+ /// indicates that the grammar begins at the end of context.
120
+ pub fn from_grammar ( grammar : LlamaGrammar , start_position : Option < usize > ) -> Self {
121
+ SamplerStage :: Grammar ( GrammarStage {
122
+ grammar,
123
+ accepted_to : start_position,
124
+ } )
125
+ }
126
+
115
127
/// Applies this [`SamplerStage`] to the provided token data array.
116
128
///
117
129
/// Ensures that at least `min_keep` tokens remain after the
@@ -192,45 +204,23 @@ impl SamplerStage {
192
204
/// Opaque internals for [`SamplerStage::Grammar`].
193
205
#[ derive( Clone , Debug ) ]
194
206
pub struct GrammarStage {
195
- original_grammar : LlamaGrammar ,
196
207
grammar : LlamaGrammar ,
197
- tokens : Vec < Token > ,
208
+ accepted_to : Option < usize > ,
198
209
}
199
210
200
211
impl GrammarStage {
201
- /// Creates a new [`GrammarStage`] from a [`LlamaGrammar`]
202
- pub fn new ( grammar : LlamaGrammar ) -> Self {
203
- Self {
204
- original_grammar : grammar. clone ( ) ,
205
- grammar,
206
- tokens : Vec :: new ( )
207
- }
208
- }
209
-
210
- /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`]
211
- pub fn new_stage ( grammar : LlamaGrammar ) -> SamplerStage {
212
- SamplerStage :: Grammar ( Self :: new ( grammar) )
213
- }
214
-
215
212
fn apply (
216
213
& mut self ,
217
214
context : * mut llama_context ,
218
215
tokens : & [ Token ] ,
219
216
mut candidates_p : llama_token_data_array ,
220
217
_min_keep : usize ,
221
218
) {
222
- let new_tokens = if let Some ( suffix) = tokens. strip_prefix ( self . tokens . as_slice ( ) ) {
223
- suffix
224
- } else {
225
- self . tokens . clear ( ) ;
226
- self . grammar = self . original_grammar . clone ( ) ;
227
- tokens
228
- } ;
229
-
230
- for token in new_tokens {
219
+ let accepted_to = self . accepted_to . unwrap_or ( tokens. len ( ) ) ;
220
+ for token in & tokens[ accepted_to..] {
231
221
unsafe { llama_grammar_accept_token ( context, self . grammar . grammar . as_ptr ( ) , token. 0 ) }
232
222
}
233
- self . tokens . extend_from_slice ( new_tokens ) ;
223
+ self . accepted_to = Some ( tokens. len ( ) ) ;
234
224
235
225
let p_ptr = addr_of_mut ! ( candidates_p) ;
236
226
unsafe { llama_sample_grammar ( context, p_ptr, self . grammar . grammar . as_ptr ( ) ) } ;
0 commit comments