Use std locks instead of tokio locks · masmullin2000/llama_cpp-rs@09a7aa4 · GitHub

Commit 09a7aa4

Use std locks instead of tokio locks

1 parent 9c098a7 · commit 09a7aa4

File tree

4 files changed: +27 -121 lines changed

  crates/llama_cpp/src/detail.rs
  crates/llama_cpp/src/model/mod.rs
  crates/llama_cpp/src/session/mod.rs
  crates/llama_cpp/src/standard_sampler.rs
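
The core change is mechanical: call sites that previously entered tokio's async locks through an executor (`block_on(lock())` or `blocking_lock()`) now take `std::sync` locks directly. A minimal sketch of the before/after pattern, using a placeholder value rather than the crate's actual context type:

use std::sync::Mutex;

fn main() {
    // Placeholder for the state the crate guards (e.g. the llama context).
    let ctx = Mutex::new(0u32);

    // Before (tokio): the async mutex had to be entered via an executor even on
    // synchronous paths, e.g. `block_on(ctx.lock())` or `ctx.blocking_lock()`.
    // After (std): the lock is taken directly; `unwrap` turns lock poisoning
    // into a panic.
    *ctx.lock().unwrap() += 1;
    assert_eq!(*ctx.lock().unwrap(), 1);
}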

crates/llama_cpp/src/detail.rs

Lines changed: 1 addition & 60 deletions
@@ -5,73 +5,14 @@
 #![allow(non_snake_case)]
 
 use std::ffi::{c_char, c_void, CStr};
-use std::ptr::slice_from_raw_parts;
 
-use tokio::sync::mpsc::UnboundedSender;
 use tracing::{error, info, trace, warn};
 
 use llama_cpp_sys::{
     ggml_log_level, ggml_log_level_GGML_LOG_LEVEL_ERROR, ggml_log_level_GGML_LOG_LEVEL_INFO,
-    ggml_log_level_GGML_LOG_LEVEL_WARN, llama_beams_state,
+    ggml_log_level_GGML_LOG_LEVEL_WARN,
 };
 
-use crate::Token;
-
-pub(crate) struct BeamSearchState {
-    pub(crate) tx: UnboundedSender<Token>,
-}
-
-#[no_mangle]
-pub(crate) unsafe extern "C" fn llama_beam_search_callback(
-    shared_state_ptr: *mut c_void,
-    beam_state: llama_beams_state,
-) {
-    let shared_state = unsafe {
-        // SAFETY: `channel` has this type and hasn't been de-allocated.
-        &mut *(shared_state_ptr as *mut BeamSearchState)
-    };
-
-    if shared_state.tx.is_closed() {
-        // Close all beams to terminate the search.
-        for i in 0..beam_state.n_beams {
-            unsafe {
-                // SAFETY: beam_views[i] exists where 0 <= i <= n_beams.
-                *beam_state.beam_views.add(i)
-            }
-            .eob = true;
-        }
-    }
-
-    // Llama.cpp trims the common prefix after every invocation; the presence of
-    // `common_prefix_length > 0` means the first `common_prefix_length` tokens have been
-    // settled upon.
-    if beam_state.common_prefix_length > 0 {
-        let first_beam = unsafe {
-            // SAFETY: At least one beam always exists.
-            &*(beam_state.beam_views)
-        };
-
-        let beam_tokens = unsafe {
-            // SAFETY: If all beams share a common prefix, at least that many tokens exist in
-            // every beam.
-            &*slice_from_raw_parts(first_beam.tokens, beam_state.common_prefix_length)
-        };
-
-        for unshared_token in beam_tokens {
-            let _ = shared_state.tx.send(Token(*unshared_token));
-        }
-    }
-
-    if beam_state.last_call {
-        unsafe {
-            // SAFETY: `channel` is heap-allocated, and this is the only time we'll construct
-            // a `Box` back over it; this is the last time this function will be called, and
-            // the last time this pointer will be seen.
-            let _ = Box::from_raw(shared_state);
-        }
-    }
-}
-
 #[no_mangle]
 pub(crate) unsafe extern "C" fn llama_log_callback(
     level: ggml_log_level,
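
The removed callback handed a heap-allocated `BeamSearchState` to llama.cpp as an opaque pointer via `Box::into_raw` and reclaimed it with `Box::from_raw` on the final invocation. A minimal standalone sketch of that ownership handoff, with a hypothetical `CallbackState` standing in for the deleted type:

use std::ffi::c_void;

struct CallbackState {
    count: u32,
}

// A C-style callback receiving its state as an opaque pointer, as the removed
// `llama_beam_search_callback` did.
unsafe extern "C" fn callback(state_ptr: *mut c_void, last_call: bool) {
    // SAFETY: the pointer originates from `Box::into_raw` below and is not used
    // again after the final call frees it.
    let state = unsafe { &mut *(state_ptr as *mut CallbackState) };
    state.count += 1;

    if last_call {
        // Reclaim ownership exactly once so the allocation is freed.
        unsafe { drop(Box::from_raw(state_ptr as *mut CallbackState)) };
    }
}

fn main() {
    // Hand the allocation to the "C side" as an opaque pointer.
    let state_ptr = Box::into_raw(Box::new(CallbackState { count: 0 })) as *mut c_void;

    // Simulate the library invoking the callback; the last call cleans up.
    unsafe {
        callback(state_ptr, false);
        callback(state_ptr, true);
    }
}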

crates/llama_cpp/src/model/mod.rs

Lines changed: 1 addition & 3 deletions
@@ -5,14 +5,12 @@ use std::cmp::min;
 use std::ffi::{c_char, CStr, CString};
 use std::path::{Path, PathBuf};
 use std::ptr::slice_from_raw_parts;
-use std::sync::{atomic::AtomicUsize, Arc};
+use std::sync::{atomic::AtomicUsize, Arc, Mutex, RwLock};
 use std::usize;
 
 use derive_more::{Deref, DerefMut};
 use futures::executor::block_on;
 use thiserror::Error;
-use tokio::sync::Mutex;
-use tokio::sync::RwLock;
 use tracing::{error, info, trace, warn};
 
 use backend::BackendRef;
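
With tokio's `Mutex` and `RwLock` no longer imported, the shared fields behind `Arc` presumably switch to the `std::sync` equivalents. A rough sketch of such a layout, assuming field names mirroring those used later in session/mod.rs (`ctx`, `tokens`, `last_batch_size`); the struct names and the `i32` token stand-in are illustrative only, not the crate's actual definitions:

use std::sync::{atomic::AtomicUsize, Arc, Mutex, RwLock};

// Illustrative stand-ins for the crate's internal context/session state.
struct ContextInner {
    ptr: usize, // placeholder for the raw `llama_context` pointer
}

struct SessionInner {
    ctx: Mutex<ContextInner>,
    tokens: RwLock<Vec<i32>>,
    last_batch_size: AtomicUsize,
}

fn main() {
    let inner = Arc::new(SessionInner {
        ctx: Mutex::new(ContextInner { ptr: 0 }),
        tokens: RwLock::new(Vec::new()),
        last_batch_size: AtomicUsize::new(0),
    });

    // Call sites now lock synchronously and unwrap potential poisoning.
    inner.tokens.write().unwrap().push(42);
    let _ctx = inner.ctx.lock().unwrap();
    assert_eq!(inner.tokens.read().unwrap().len(), 1);
}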

crates/llama_cpp/src/session/mod.rs

Lines changed: 24 additions & 54 deletions
@@ -1,24 +1,23 @@
 //! Functionality for the [`LlamaSession`] struct
 
 use std::cmp::min;
-use std::ffi::c_void;
 use std::ops::{Bound, RangeBounds};
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex, RwLock};
 use std::thread;
 
-use futures::executor::block_on;
 use thiserror::Error;
-use tokio::sync::{mpsc::unbounded_channel, Mutex, RwLock};
+use tokio::sync::mpsc::unbounded_channel;
 use tracing::{error, info, trace, warn};
 
 use llama_cpp_sys::{
-    llama_beam_search, llama_context, llama_copy_state_data, llama_decode, llama_free,
+    llama_context, llama_copy_state_data, llama_decode, llama_free,
     llama_get_logits_ith, llama_get_state_size, llama_kv_cache_seq_rm, llama_set_state_data,
     llama_token_data, llama_token_data_array,
 };
 
-use crate::{detail, LlamaModel, LlamaTokenizationError, Sampler, Token};
+use crate::standard_sampler::StandardSampler;
+use crate::{LlamaModel, LlamaTokenizationError, Sampler, Token};
 
 mod completion;
 mod params;
@@ -172,7 +171,7 @@ impl LlamaSession {
         let err = unsafe {
             // SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
             // initialized above.
-            llama_decode(block_on(self.inner.ctx.lock()).ptr, batch.handle())
+            llama_decode(self.inner.ctx.lock().unwrap().ptr, batch.handle())
         };
         if err != 0 {
             return Err(LlamaContextError::DecodeFailed(err));
@@ -182,7 +181,7 @@
             last_batch_size = sequence.len();
         }
 
-        block_on(self.inner.tokens.write()).extend_from_slice(tokens);
+        self.inner.tokens.write().unwrap().extend_from_slice(tokens);
 
         self.inner
             .last_batch_size
@@ -236,34 +235,13 @@
             .unwrap()
     }
 
-    /// Starts generating tokens at the end of the context using llama.cpp's built-in Beam search.
-    /// TODO fix: beam search keeps going even after it should have ended
+    /// Starts generating tokens at the end of the context using a greedy
+    /// sampler
     pub fn start_completing(&mut self) -> CompletionHandle {
-        let (tx, rx) = unbounded_channel();
-        let history_size = self.context_size();
-        let session = self.clone();
-
-        info!("Generating completions with {history_size} tokens of history");
-
-        thread::spawn(move || unsafe {
-            let state = Box::new(detail::BeamSearchState { tx });
-            // SAFETY: `state_ptr` is converted back to a [`Box`] and freed in [`detail::llama_beam_search_callback`]
-            let state_ptr = Box::into_raw(state);
-
-            llama_beam_search(
-                block_on(session.inner.ctx.lock()).ptr,
-                Some(detail::llama_beam_search_callback),
-                state_ptr as *mut _ as *mut c_void,
-                1,
-                history_size as i32,
-                32_768,
-            );
-        });
-
-        CompletionHandle {
-            rx,
-            model: self.model(),
-        }
+        self.start_completing_with(
+            StandardSampler::new_greedy(),
+            self.params().n_ctx as usize - self.context_size(),
+        )
     }
 
     /// Start completion.
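
After this hunk, `start_completing` simply delegates to `start_completing_with` using a greedy `StandardSampler` and however many tokens still fit in the context. A usage sketch, assuming the crate is imported as `llama_cpp` and that `params()` and `context_size()` are callable as shown in the diff:

use llama_cpp::{standard_sampler::StandardSampler, LlamaSession};

// Sketch only: after this commit, `session.start_completing()` is shorthand
// for the explicit call below.
fn complete_greedily(session: &mut LlamaSession) {
    let remaining = session.params().n_ctx as usize - session.context_size();
    let _handle = session.start_completing_with(StandardSampler::new_greedy(), remaining);
}
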
@@ -282,10 +260,10 @@
         info!("Generating completions with {history_size} tokens of history");
 
         thread::spawn(move || {
-            let context = block_on(session.inner.ctx.lock());
+            let context = session.inner.ctx.lock().unwrap();
             let vocab = session.model().vocabulary_size();
             let end_of_stream = session.model().eos();
-            let mut token_buf = block_on(session.inner.tokens.write());
+            let mut token_buf = session.inner.tokens.write().unwrap();
             let mut count = 0;
             let mut batch = Batch::new(1, 0, 1);
             let mut i = session.inner.last_batch_size.load(Ordering::SeqCst);
@@ -366,12 +344,12 @@
 
     /// Returns the number of tokens currently in this session's context
     pub fn context_size(&self) -> usize {
-        block_on(self.inner.tokens.read()).len()
+        self.inner.tokens.read().unwrap().len()
     }
 
     /// Returns the list of tokens in the current context
     pub fn context(&self) -> Vec<Token> {
-        block_on(self.inner.tokens.read()).clone()
+        self.inner.tokens.read().unwrap().clone()
     }
 
     /// Removes all tokens within the given range without performing any prompt
@@ -393,12 +371,12 @@
             Bound::Unbounded => -1,
         };
 
-        let context = block_on(self.inner.ctx.lock());
+        let context = self.inner.ctx.lock().unwrap();
 
         // -1 here to match all sequences
         unsafe { llama_kv_cache_seq_rm(context.ptr, -1, start_bound, end_bound) }
 
-        block_on(self.inner.tokens.write()).drain(range);
+        self.inner.tokens.write().unwrap().drain(range);
     }
 
     /// Removes all but the first `n_tokens` tokens from the context.
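
The enclosing function maps a `RangeBounds` argument onto the start/end indices that `llama_kv_cache_seq_rm` expects, with `-1` meaning unbounded (and `-1` for the sequence id meaning all sequences). A self-contained sketch of that bound-conversion pattern, not the crate's exact code:

use std::ops::{Bound, RangeBounds};

// Convert a Rust range into the (start, end) pair expected by a C API where
// -1 stands for "unbounded", mirroring the pattern around llama_kv_cache_seq_rm.
fn to_c_bounds(range: impl RangeBounds<usize>) -> (i32, i32) {
    let start = match range.start_bound() {
        Bound::Included(&n) => n as i32,
        Bound::Excluded(&n) => n as i32 + 1,
        Bound::Unbounded => -1,
    };
    let end = match range.end_bound() {
        Bound::Included(&n) => n as i32 + 1,
        Bound::Excluded(&n) => n as i32,
        Bound::Unbounded => -1,
    };
    (start, end)
}

fn main() {
    assert_eq!(to_c_bounds(2..5), (2, 5));
    assert_eq!(to_c_bounds(..), (-1, -1));
}
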
@@ -415,7 +393,7 @@
         new_tokens: impl AsRef<[Token]>,
     ) -> Result<(), LlamaContextError> {
         let new_tokens = new_tokens.as_ref();
-        let old_tokens = block_on(self.inner.tokens.read());
+        let old_tokens = self.inner.tokens.read().unwrap();
 
         let shared_prefix = old_tokens
             .iter()
@@ -480,7 +458,7 @@
     /// This differs from [`LlamaSession::clone`] in that [`LlamaSession::clone`] creates a new
     /// reference to the same underlying [`LlamaSession`].
     pub fn deep_copy(&self) -> Result<LlamaSession, LlamaContextError> {
-        let ctx = self.inner.ctx.blocking_lock();
+        let ctx = self.inner.ctx.lock().unwrap();
 
         #[allow(unused_mut)]
         let mut copy = self.model().create_session(self.inner.params.clone())?;
@@ -496,13 +474,13 @@
             let copy_size = llama_copy_state_data(ctx.ptr, buf.as_mut_ptr());
             assert!(copy_size <= size);
             let set_size =
-                llama_set_state_data(copy.inner.ctx.blocking_lock().ptr, buf.as_mut_ptr());
+                llama_set_state_data(copy.inner.ctx.lock().unwrap().ptr, buf.as_mut_ptr());
             assert_eq!(copy_size, set_size);
         }
 
         // NOTE: Any changes to the fields of a LlamaSession may require that
         // those changes are mirrored here
-        *block_on(copy.inner.tokens.write()) = block_on(self.inner.tokens.read()).clone();
+        *copy.inner.tokens.write().unwrap() = self.inner.tokens.read().unwrap().clone();
         copy.inner.last_batch_size.store(
             self.inner.last_batch_size.load(Ordering::SeqCst),
             Ordering::SeqCst,
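
For reference, the save/restore sequence this hunk goes through (query the state size, copy the state out, copy it into the new context) can be sketched on its own. The function below is an illustration built from the `llama_cpp_sys` calls named in the diff, assuming raw context pointers obtained elsewhere; locking and error handling are elided:

// Minimal sketch of the llama.cpp state copy protocol used by deep_copy.
unsafe fn copy_context_state(
    src: *mut llama_cpp_sys::llama_context,
    dst: *mut llama_cpp_sys::llama_context,
) {
    let size = llama_cpp_sys::llama_get_state_size(src);
    let mut buf = vec![0u8; size];
    let copied = llama_cpp_sys::llama_copy_state_data(src, buf.as_mut_ptr());
    assert!(copied <= size);
    let restored = llama_cpp_sys::llama_set_state_data(dst, buf.as_mut_ptr());
    assert_eq!(copied, restored);
}
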
@@ -512,16 +490,8 @@
     }
 
     /// Returns the maximum size in bytes this session is occupying in memory.
-    ///
-    /// This function may **NOT*** be called in async environments, for an async version see [`async_memory_size`].
     pub fn memory_size(&self) -> usize {
-        let ctx = self.inner.ctx.blocking_lock();
-        unsafe { llama_get_state_size(ctx.ptr) }
-    }
-
-    /// Asynchronously returns the maximum size in bytes this session is occupying in memory.
-    pub async fn async_memory_size(&self) -> usize {
-        let ctx = self.inner.ctx.lock().await;
+        let ctx = self.inner.ctx.lock().unwrap();
         unsafe { llama_get_state_size(ctx.ptr) }
     }
 }

crates/llama_cpp/src/standard_sampler.rs

Lines changed: 1 addition & 4 deletions
@@ -296,10 +296,7 @@ impl StandardSampler {
     ///
     /// Ensures that at least `min_keep` tokens remain after the
    /// [`SamplerStage`]'s are applied.
-    pub fn new_softmax(
-        stages: Vec<SamplerStage>,
-        min_keep: usize,
-    ) -> StandardSampler {
+    pub fn new_softmax(stages: Vec<SamplerStage>, min_keep: usize) -> StandardSampler {
         StandardSampler {
             stages,
             min_keep,
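
Both constructors touched by this commit can be used directly. A minimal sketch, assuming the module path shown in the session imports; an empty stage list and `min_keep = 1` are used purely to avoid assuming particular `SamplerStage` variants:

use llama_cpp::standard_sampler::StandardSampler;

fn main() {
    // Greedy sampler, now what LlamaSession::start_completing uses internally.
    let _greedy = StandardSampler::new_greedy();

    // Softmax sampler from this file, keeping at least one candidate token.
    let _softmax = StandardSampler::new_softmax(Vec::new(), 1);
}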

0 commit comments