masmullin2000/llama_cpp-rs
Commit ff8c829

Move logic for applying SamplerStages into its own method so it can be reused in other samplers
1 parent 261f397 commit ff8c829

File tree

1 file changed: +80 -57 lines changed


crates/llama_cpp/src/standard_sampler.rs

Lines changed: 80 additions & 57 deletions
@@ -105,6 +105,81 @@ pub enum SamplerStage {
     TailFree(f32),
 }
 
+impl SamplerStage {
+    /// Applies this [`SamplerStage`] to the provided token data array.
+    ///
+    /// Ensures that at least `min_keep` tokens remain after the
+    /// [`SamplerStage`]'s are applied.
+    #[allow(clippy::not_unsafe_ptr_arg_deref)]
+    pub fn apply(
+        &self,
+        context: *mut llama_context,
+        tokens: &[Token],
+        mut candidates_p: llama_token_data_array,
+        min_keep: usize,
+    ) -> llama_token_data_array {
+        let p_ptr = addr_of_mut!(candidates_p);
+
+        unsafe {
+            match self {
+                SamplerStage::RepetitionPenalty {
+                    repetition_penalty,
+                    frequency_penalty,
+                    presence_penalty,
+                    last_n,
+                } => {
+                    let last_n = if *last_n < 0 {
+                        tokens.len()
+                    } else {
+                        tokens.len().min(*last_n as usize)
+                    };
+
+                    llama_sample_repetition_penalties(
+                        context,
+                        p_ptr,
+                        tokens[tokens.len() - last_n..].as_ptr() as *const llama_token,
+                        last_n,
+                        *repetition_penalty,
+                        *frequency_penalty,
+                        *presence_penalty,
+                    );
+                }
+                SamplerStage::Temperature(temp) => {
+                    if *temp == 0.0 {
+                        llama_sample_top_k(context, p_ptr, 1, 1);
+                    } else {
+                        llama_sample_temp(context, p_ptr, *temp);
+                    }
+                }
+                SamplerStage::DynamicTemperature {
+                    min_temp,
+                    max_temp,
+                    exponent_val,
+                } => {
+                    llama_sample_entropy(context, p_ptr, *min_temp, *max_temp, *exponent_val);
+                }
+                SamplerStage::TopP(top_p) => {
+                    llama_sample_top_p(context, p_ptr, *top_p, min_keep);
+                }
+                SamplerStage::MinP(min_p) => {
+                    llama_sample_min_p(context, p_ptr, *min_p, min_keep);
+                }
+                SamplerStage::TopK(top_k) => {
+                    llama_sample_top_k(context, p_ptr, *top_k, min_keep);
+                }
+                SamplerStage::Typical(p) => {
+                    llama_sample_typical(context, p_ptr, *p, min_keep);
+                }
+                SamplerStage::TailFree(z) => {
+                    llama_sample_tail_free(context, p_ptr, *z, min_keep);
+                }
+            }
+        }
+
+        candidates_p
+    }
+}
+
 /// Determines how the next token is selected from the distribution produced by
 /// the model and the [`SamplerStage`]'s.
 #[derive(Clone, Debug)]
@@ -227,66 +302,14 @@ impl Sampler for StandardSampler {
         tokens: &[Token],
         mut candidates_p: llama_token_data_array,
     ) -> Token {
-        let p_ptr = addr_of_mut!(candidates_p);
         let min_keep = self.min_keep.max(1);
 
-        unsafe {
-            for stage in &self.stages {
-                match stage {
-                    SamplerStage::RepetitionPenalty {
-                        repetition_penalty,
-                        frequency_penalty,
-                        presence_penalty,
-                        last_n,
-                    } => {
-                        let last_n = if *last_n < 0 {
-                            tokens.len()
-                        } else {
-                            tokens.len().min(*last_n as usize)
-                        };
-
-                        llama_sample_repetition_penalties(
-                            context,
-                            p_ptr,
-                            tokens[tokens.len() - last_n..].as_ptr() as *const llama_token,
-                            last_n,
-                            *repetition_penalty,
-                            *frequency_penalty,
-                            *presence_penalty,
-                        );
-                    }
-                    SamplerStage::Temperature(temp) => {
-                        if *temp == 0.0 {
-                            llama_sample_top_k(context, p_ptr, 1, 1);
-                        } else {
-                            llama_sample_temp(context, p_ptr, *temp);
-                        }
-                    }
-                    SamplerStage::DynamicTemperature {
-                        min_temp,
-                        max_temp,
-                        exponent_val,
-                    } => {
-                        llama_sample_entropy(context, p_ptr, *min_temp, *max_temp, *exponent_val);
-                    }
-                    SamplerStage::TopP(top_p) => {
-                        llama_sample_top_p(context, p_ptr, *top_p, min_keep);
-                    }
-                    SamplerStage::MinP(min_p) => {
-                        llama_sample_min_p(context, p_ptr, *min_p, min_keep);
-                    }
-                    SamplerStage::TopK(top_k) => {
-                        llama_sample_top_k(context, p_ptr, *top_k, min_keep);
-                    }
-                    SamplerStage::Typical(p) => {
-                        llama_sample_typical(context, p_ptr, *p, min_keep);
-                    }
-                    SamplerStage::TailFree(z) => {
-                        llama_sample_tail_free(context, p_ptr, *z, min_keep);
-                    }
-                }
-            }
+        for stage in &self.stages {
+            candidates_p = stage.apply(context, tokens, candidates_p, min_keep);
+        }
 
+        unsafe {
+            let p_ptr = addr_of_mut!(candidates_p);
             let id = match &mut self.token_selector {
                 TokenSelector::Softmax => llama_sample_token(context, p_ptr),
                 TokenSelector::Greedy => llama_sample_token_greedy(context, p_ptr),
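
With `SamplerStage::apply` exposed as a public method, other samplers can run the same stage pipeline without duplicating the unsafe match block above. The sketch below is illustrative only and not part of this commit: `run_stages` is a hypothetical helper, while the `SamplerStage`, `Token`, `llama_context`, and `llama_token_data_array` types and the `apply` signature are taken from the diff above.

// Hypothetical helper (not in this commit): chain every configured stage over
// the candidate array, mirroring what StandardSampler::sample now does.
fn run_stages(
    stages: &[SamplerStage],
    context: *mut llama_context,
    tokens: &[Token],
    mut candidates_p: llama_token_data_array,
    min_keep: usize,
) -> llama_token_data_array {
    // Each stage consumes the candidate array and returns the updated one, so
    // callers never touch the raw-pointer plumbing hidden inside apply().
    for stage in stages {
        candidates_p = stage.apply(context, tokens, candidates_p, min_keep.max(1));
    }
    candidates_p
}

A custom `Sampler` implementation could call such a helper at the top of its `sample` method and then perform its own token selection, exactly as the rewritten `StandardSampler::sample` does with its stage loop.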
