fix: require `llama_context` is accessed from behind a mutex · edgenai/llama_cpp-rs@81e5de9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 81e5de9

Browse files
committed
fix: require llama_context is accessed from behind a mutex
This solves a race condition when several `get_completions` threads are spawned at the same time
1 parent 27706de commit 81e5de9

File tree

3 files changed

+49
-8
lines changed

3 files changed

+49
-8
lines changed

Cargo.lock

Lines changed: 41 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/llama_cpp/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ derive_more = "0.99.17"
1919
flume = "0.11.0"
2020
llama_cpp_sys = { version = "^0.2.0", path = "../llama_cpp_sys" }
2121
num_cpus = "1.16.0"
22+
parking_lot = "0.12.1"
2223
thiserror = "1.0.49"
2324
tokio = { version = "1.33.0", features = ["sync"] }
2425
tracing = "0.1.39"

crates/llama_cpp/src/lib.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ use std::ffi::{c_void, CStr, CString};
7878
use std::path::{Path, PathBuf};
7979
use std::sync::Arc;
8080
use std::{ptr, thread};
81-
use tokio::sync::RwLock;
81+
use tokio::sync::{Mutex, RwLock};
8282

8383
use ctor::{ctor, dtor};
8484
use derive_more::{Deref, DerefMut};
@@ -388,7 +388,7 @@ impl LlamaModel {
388388

389389
LlamaSession {
390390
model: self.clone(),
391-
inner: Arc::new(LlamaContextInner { ptr: ctx }),
391+
inner: Arc::new(Mutex::new(LlamaContextInner { ptr: ctx }) ),
392392
history_size: 0,
393393
}
394394
}
@@ -460,7 +460,7 @@ pub struct LlamaSession {
460460
model: LlamaModel,
461461

462462
/// A pointer to the llama.cpp side of the model context.
463-
inner: Arc<LlamaContextInner>,
463+
inner: Arc<Mutex<LlamaContextInner>>,
464464

465465
/// The number of tokens present in this model's context.
466466
history_size: usize,
@@ -553,7 +553,7 @@ impl LlamaSession {
553553
if unsafe {
554554
// SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
555555
// initialized above.
556-
llama_decode(self.inner.ptr, batch)
556+
llama_decode(self.inner.blocking_lock().ptr, batch)
557557
} != 0
558558
{
559559
return Err(LlamaInternalError.into());
@@ -592,12 +592,12 @@ impl LlamaSession {
592592
self.history_size,
593593
);
594594

595-
let inner = self.inner.clone();
596595
let past_tokens = self.history_size;
596+
let mutex = self.inner.clone();
597597

598598
thread::spawn(move || unsafe {
599599
llama_beam_search(
600-
inner.ptr,
600+
mutex.blocking_lock().ptr,
601601
Some(detail::llama_beam_search_callback),
602602
Box::leak(Box::new(detail::BeamSearchState { tx })) as *mut _ as *mut c_void,
603603
1,
@@ -690,6 +690,7 @@ mod detail {
690690

691691
use std::ffi::{c_char, c_void, CStr};
692692
use std::ptr::slice_from_raw_parts;
693+
use tokio::sync::OwnedSemaphorePermit;
693694

694695
use tracing::{error, info, trace, warn};
695696

0 commit comments

Comments (0)