@@ -198,6 +198,34 @@ enum TokenSelector {
198
198
MirostatV2 { tau : f32 , eta : f32 , mu : f32 } ,
199
199
}
200
200
201
+ impl TokenSelector {
202
+ /// Select and and return a token from a given distribution.
203
+ ///
204
+ /// Note: while this function may take a mutable reference to `self`, the internal state *shouldn't* be altered.
205
+ #[ allow( clippy:: not_unsafe_ptr_arg_deref) ]
206
+ pub fn select (
207
+ & mut self ,
208
+ context : * mut llama_context ,
209
+ mut candidates_p : llama_token_data_array ,
210
+ ) -> Token {
211
+ unsafe {
212
+ let p_ptr = addr_of_mut ! ( candidates_p) ;
213
+ let id = match self {
214
+ TokenSelector :: Softmax => llama_sample_token ( context, p_ptr) ,
215
+ TokenSelector :: Greedy => llama_sample_token_greedy ( context, p_ptr) ,
216
+ TokenSelector :: Mirostat { tau, eta, m, mu } => {
217
+ llama_sample_token_mirostat ( context, p_ptr, * tau, * eta, * m, addr_of_mut ! ( * mu) )
218
+ }
219
+ TokenSelector :: MirostatV2 { tau, eta, mu } => {
220
+ llama_sample_token_mirostat_v2 ( context, p_ptr, * tau, * eta, addr_of_mut ! ( * mu) )
221
+ }
222
+ } ;
223
+
224
+ Token ( id)
225
+ }
226
+ }
227
+ }
228
+
201
229
/// Selects a token after applying multiple [`SamplerStage`]'s to the
202
230
/// probability distribution output by the model.
203
231
#[ derive( Clone , Debug ) ]
@@ -308,20 +336,6 @@ impl Sampler for StandardSampler {
308
336
candidates_p = stage. apply ( context, tokens, candidates_p, min_keep) ;
309
337
}
310
338
311
- unsafe {
312
- let p_ptr = addr_of_mut ! ( candidates_p) ;
313
- let id = match & mut self . token_selector {
314
- TokenSelector :: Softmax => llama_sample_token ( context, p_ptr) ,
315
- TokenSelector :: Greedy => llama_sample_token_greedy ( context, p_ptr) ,
316
- TokenSelector :: Mirostat { tau, eta, m, mu } => {
317
- llama_sample_token_mirostat ( context, p_ptr, * tau, * eta, * m, addr_of_mut ! ( * mu) )
318
- }
319
- TokenSelector :: MirostatV2 { tau, eta, mu } => {
320
- llama_sample_token_mirostat_v2 ( context, p_ptr, * tau, * eta, addr_of_mut ! ( * mu) )
321
- }
322
- } ;
323
-
324
- Token ( id)
325
- }
339
+ self . token_selector . select ( context, candidates_p)
326
340
}
327
341
}
0 commit comments