@@ -14,12 +14,13 @@ use crate::{grammar::LlamaGrammar, Sampler, Token};
 ///
 /// Standard ordering for samplers (taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)):
 ///
-/// 1. [`SamplerStage::RepetitionPenalty`]
-/// 2. [`SamplerStage::Temperature`], [SamplerStage::DynamicTemperature]
-/// 3. [`SamplerStage::TopK`]
-/// 4. [`SamplerStage::TailFree`]
-/// 5. [`SamplerStage::Typical`]
-/// 6. [`SamplerStage::TopP`], [`SamplerStage::MinP`]
+/// 1. [`SamplerStage::Grammar`]
+/// 2. [`SamplerStage::RepetitionPenalty`]
+/// 3. [`SamplerStage::Temperature`], [`SamplerStage::DynamicTemperature`]
+/// 4. [`SamplerStage::TopK`]
+/// 5. [`SamplerStage::TailFree`]
+/// 6. [`SamplerStage::Typical`]
+/// 7. [`SamplerStage::TopP`], [`SamplerStage::MinP`]
 #[derive(Clone, Debug)]
 #[non_exhaustive]
 pub enum SamplerStage {
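For orientation, the sketch below assembles a `Vec<SamplerStage>` in the order documented above. It is a sketch only: it uses just the variants whose shapes are visible in this diff (RepetitionPenalty, TopK, Typical, and TopP are omitted), the numeric values are illustrative placeholders, and the `use` paths assume the crate's usual re-exports.

```rust
// Sketch only: a stage list following the documented standard ordering.
// Module paths and numeric values are assumptions, not taken from this diff.
use llama_cpp::grammar::LlamaGrammar;
use llama_cpp::standard_sampler::SamplerStage;

fn ordered_stages(grammar: LlamaGrammar) -> Vec<SamplerStage> {
    vec![
        // 1. Grammar constraints first, so later stages only see legal tokens.
        SamplerStage::from_grammar(grammar, None),
        // 3. Temperature scaling (placeholder value).
        SamplerStage::Temperature(0.8),
        // 5. Tail-free sampling; a z of 1.0 effectively disables it.
        SamplerStage::TailFree(1.0),
        // 7. Min-p truncation (placeholder value).
        SamplerStage::MinP(0.05),
    ]
}
```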
@@ -103,16 +104,34 @@ pub enum SamplerStage {
     ///
     /// See: <https://www.trentonbricken.com/Tail-Free-Sampling/>
     TailFree(f32),
+
+    /// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given
+    /// grammar. Since this stage has to handle mutable state, an instance of this stage should
+    /// only be used in one completion.
+    ///
+    /// See [`GrammarStage`] and [`LlamaGrammar`] for more information.
+    Grammar(GrammarStage),
 }

 impl SamplerStage {
+    /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`].
+    ///
+    /// `start_position` indicates the token position to begin applying the grammar at. [`None`]
+    /// indicates that the grammar begins at the end of context.
+    pub fn from_grammar(grammar: LlamaGrammar, start_position: Option<usize>) -> Self {
+        SamplerStage::Grammar(GrammarStage {
+            grammar,
+            accepted_up_to: start_position,
+        })
+    }
+
     /// Applies this [`SamplerStage`] to the provided token data array.
     ///
     /// Ensures that at least `min_keep` tokens remain after the
     /// [`SamplerStage`]'s are applied.
     #[allow(clippy::not_unsafe_ptr_arg_deref)]
     pub fn apply(
-        &self,
+        &mut self,
         context: *mut llama_context,
         tokens: &[Token],
         mut candidates_p: llama_token_data_array,
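To illustrate the `start_position` parameter of `from_grammar`, here is a hedged sketch. The `LlamaGrammar` values and the `prefix_start` index are assumed to come from elsewhere, and the `use` paths are assumptions about the crate's re-exports.

```rust
// Sketch only: the two ways to anchor a grammar stage via `start_position`.
use llama_cpp::grammar::LlamaGrammar;
use llama_cpp::standard_sampler::SamplerStage;

fn grammar_stages(
    grammar_a: LlamaGrammar,
    grammar_b: LlamaGrammar,
    prefix_start: usize,
) -> (SamplerStage, SamplerStage) {
    // Start at the end of context: only newly generated tokens are checked.
    let from_now_on = SamplerStage::from_grammar(grammar_a, None);

    // Start at an earlier position: tokens already in context from `prefix_start`
    // onward are accepted into the grammar on the first apply, which can be useful
    // when the prompt already contains the beginning of the constrained output.
    let including_prefix = SamplerStage::from_grammar(grammar_b, Some(prefix_start));

    (from_now_on, including_prefix)
}
```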
@@ -173,13 +192,48 @@ impl SamplerStage {
                 SamplerStage::TailFree(z) => {
                     llama_sample_tail_free(context, p_ptr, *z, min_keep);
                 }
+                SamplerStage::Grammar(stage) => {
+                    candidates_p = stage.apply(context, tokens, candidates_p, min_keep)
+                }
             }
         }

         candidates_p
     }
 }

+/// Opaque internals for [`SamplerStage::Grammar`].
+#[derive(Clone, Debug)]
+pub struct GrammarStage {
+    grammar: LlamaGrammar,
+    accepted_up_to: Option<usize>,
+}
+
+impl GrammarStage {
+    fn apply(
+        &mut self,
+        context: *mut llama_context,
+        tokens: &[Token],
+        mut candidates_p: llama_token_data_array,
+        _min_keep: usize,
+    ) -> llama_token_data_array {
+        // If `accepted_up_to` is `None`, assume that we should start at the end of context.
+        let accepted_up_to = self.accepted_up_to.unwrap_or(tokens.len());
+
+        // Accept all new tokens until the end of context.
+        for token in &tokens[accepted_up_to..] {
+            unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) }
+        }
+        self.accepted_up_to = Some(tokens.len());
+
+        // Apply grammar sampling to `candidates_p`.
+        let p_ptr = addr_of_mut!(candidates_p);
+        unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) };
+
+        candidates_p
+    }
+}
+
 /// Determines how the next token is selected from the distribution produced by
 /// the model and the [`SamplerStage`]'s.
 #[derive(Clone, Debug)]
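The essential behaviour of `GrammarStage::apply` is the `accepted_up_to` cursor: each call feeds the grammar only the tokens added since the previous call, then records the new context length, which is why one instance should serve a single completion. The toy sketch below isolates that bookkeeping pattern; the `IncrementalConsumer` type is invented for illustration and involves no llama.cpp FFI.

```rust
// Toy illustration of the `accepted_up_to` bookkeeping used by GrammarStage:
// each call sees only the tokens appended since the previous call.
struct IncrementalConsumer {
    accepted_up_to: Option<usize>,
}

impl IncrementalConsumer {
    fn consume(&mut self, tokens: &[u32]) -> Vec<u32> {
        // `None` means "start at the current end of context": nothing is replayed.
        let start = self.accepted_up_to.unwrap_or(tokens.len());
        let newly_accepted = tokens[start..].to_vec();
        self.accepted_up_to = Some(tokens.len());
        newly_accepted
    }
}

fn main() {
    let mut consumer = IncrementalConsumer { accepted_up_to: Some(0) };
    assert_eq!(consumer.consume(&[1, 2, 3]), vec![1, 2, 3]);
    // On the next call, only the tokens added since the previous call are seen.
    assert_eq!(consumer.consume(&[1, 2, 3, 4, 5]), vec![4, 5]);
}
```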
@@ -232,7 +286,6 @@ impl TokenSelector {
 pub struct StandardSampler {
     stages: Vec<SamplerStage>,
     min_keep: usize,
-    grammar: Option<LlamaGrammar>,
     token_selector: TokenSelector,
 }

@@ -246,12 +299,10 @@ impl StandardSampler {
     pub fn new_softmax(
         stages: Vec<SamplerStage>,
         min_keep: usize,
-        grammar: Option<LlamaGrammar>,
     ) -> StandardSampler {
         StandardSampler {
             stages,
             min_keep,
-            grammar: grammar,
             token_selector: TokenSelector::Softmax,
         }
     }
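Since `new_softmax` no longer takes an `Option<LlamaGrammar>`, grammar constraints are now passed as an ordinary stage. A hedged sketch of the call after this change, with assumed module paths and illustrative stage values:

```rust
// Sketch only: grammar constraints are supplied as a SamplerStage rather than
// as a separate Option<LlamaGrammar> argument. Paths and values are assumptions.
use llama_cpp::grammar::LlamaGrammar;
use llama_cpp::standard_sampler::{SamplerStage, StandardSampler};

fn constrained_sampler(grammar: LlamaGrammar) -> StandardSampler {
    let stages = vec![
        SamplerStage::from_grammar(grammar, None),
        SamplerStage::Temperature(0.8),
    ];
    // Before this change, the grammar would have been a third argument here.
    StandardSampler::new_softmax(stages, 1)
}
```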
@@ -262,7 +313,6 @@ impl StandardSampler {
         StandardSampler {
             stages: Vec::new(),
             min_keep: 0,
-            grammar: None,
             token_selector: TokenSelector::Greedy,
         }
     }
@@ -279,7 +329,6 @@ impl StandardSampler {
         StandardSampler {
             stages,
             min_keep,
-            grammar: None,
             token_selector: TokenSelector::Mirostat {
                 tau,
                 eta,
@@ -300,7 +349,6 @@ impl StandardSampler {
         StandardSampler {
             stages,
             min_keep,
-            grammar: None,
             token_selector: TokenSelector::MirostatV2 {
                 tau,
                 eta,
@@ -325,7 +373,6 @@ impl Default for StandardSampler {
                 SamplerStage::MinP(0.05),
                 SamplerStage::Temperature(0.8),
             ],
-            grammar: None,
             min_keep: 1,
             token_selector: TokenSelector::Softmax,
         }
@@ -340,25 +387,12 @@ impl Sampler for StandardSampler {
         tokens: &[Token],
         mut candidates_p: llama_token_data_array,
     ) -> Token {
-        let p_ptr = addr_of_mut!(candidates_p);
         let min_keep = self.min_keep.max(1);

-        // Note: We should sample grammar before applying other sampling stages.
-        if let Some(grammar) = self.grammar.as_mut() {
-            unsafe { llama_sample_grammar(context, p_ptr, grammar.grammar.as_ptr()) };
-        }
-
-        for stage in &self.stages {
+        for stage in &mut self.stages {
             candidates_p = stage.apply(context, tokens, candidates_p, min_keep);
         }

-        let token = self.token_selector.select(context, candidates_p);
-
-        // Note: We must accept the token into the grammar after sampling if a grammar is provided.
-        if let Some(grammar) = self.grammar.as_mut() {
-            unsafe { llama_grammar_accept_token(context, grammar.grammar.as_ptr(), token.0) }
-        }
-
-        token
+        self.token_selector.select(context, candidates_p)
     }
 }