10000 moved token selection to its own function, following SamplerStage's d… · masmullin2000/llama_cpp-rs@17dd65f · GitHub
[go: up one dir, main page]

Skip to content

Commit 17dd65f

Browse files
committed
moved token selection to its own function, following SamplerStage's design
1 parent 0ab877f commit 17dd65f

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

crates/llama_cpp/src/standard_sampler.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,34 @@ enum TokenSelector {
198198
MirostatV2 { tau: f32, eta: f32, mu: f32 },
199199
}
200200

201+
impl TokenSelector {
202+
/// Select and and return a token from a given distribution.
203+
///
204+
/// Note: while this function may take a mutable reference to `self`, the internal state *shouldn't* be altered.
205+
#[allow(clippy::not_unsafe_ptr_arg_deref)]
206+
pub fn select(
207+
&mut self,
208+
context: *mut llama_context,
209+
mut candidates_p: llama_token_data_array,
210+
) -> Token {
211+
unsafe {
212+
let p_ptr = addr_of_mut!(candidates_p);
213+
let id = match self {
214+
TokenSelector::Softmax => llama_sample_token(context, p_ptr),
215+
TokenSelector::Greedy => llama_sample_token_greedy(context, p_ptr),
216+
TokenSelector::Mirostat { tau, eta, m, mu } => {
217+
llama_sample_token_mirostat(context, p_ptr, *tau, *eta, *m, addr_of_mut!(*mu))
218+
}
219+
TokenSelector::MirostatV2 { tau, eta, mu } => {
220+
llama_sample_token_mirostat_v2(context, p_ptr, *tau, *eta, addr_of_mut!(*mu))
221+
}
222+
};
223+
224+
Token(id)
225+
}
226+
}
227+
}
228+
201229
/// Selects a token after applying multiple [`SamplerStage`]'s to the
202230
/// probability distribution output by the model.
203231
#[derive(Clone, Debug)]
@@ -308,20 +336,6 @@ impl Sampler for StandardSampler {
308336
candidates_p = stage.apply(context, tokens, candidates_p, min_keep);
309337
}
310338

311-
unsafe {
312-
let p_ptr = addr_of_mut!(candidates_p);
313-
let id = match &mut self.token_selector {
314-
TokenSelector::Softmax => llama_sample_token(context, p_ptr),
315-
TokenSelector::Greedy => llama_sample_token_greedy(context, p_ptr),
316-
TokenSelector::Mirostat { tau, eta, m, mu } => {
317-
llama_sample_token_mirostat(context, p_ptr, *tau, *eta, *m, addr_of_mut!(*mu))
318-
}
319-
TokenSelector::MirostatV2 { tau, eta, mu } => {
320-
llama_sample_token_mirostat_v2(context, p_ptr, *tau, *eta, addr_of_mut!(*mu))
321-
}
322-
};
323-
324-
Token(id)
325-
}
339+
self.token_selector.select(context, candidates_p)
326340
}
327341
}

0 commit comments

Comments
 (0)
0