@@ -78,7 +78,7 @@ use std::ffi::{c_void, CStr, CString};
78
78
use std:: path:: { Path , PathBuf } ;
79
79
use std:: sync:: Arc ;
80
80
use std:: { ptr, thread} ;
81
- use tokio:: sync:: RwLock ;
81
+ use tokio:: sync:: { Mutex , RwLock } ;
82
82
83
83
use ctor:: { ctor, dtor} ;
84
84
use derive_more:: { Deref , DerefMut } ;
@@ -388,7 +388,7 @@ impl LlamaModel {
388
388
389
389
LlamaSession {
390
390
model : self . clone ( ) ,
391
- inner : Arc :: new ( LlamaContextInner { ptr : ctx } ) ,
391
+ inner : Arc :: new ( Mutex :: new ( LlamaContextInner { ptr : ctx } ) ) ,
392
392
history_size : 0 ,
393
393
}
394
394
}
@@ -460,7 +460,7 @@ pub struct LlamaSession {
460
460
model : LlamaModel ,
461
461
462
462
/// A pointer to the llama.cpp side of the model context.
463
- inner : Arc < LlamaContextInner > ,
463
+ inner : Arc < Mutex < LlamaContextInner > > ,
464
464
465
465
/// The number of tokens present in this model's context.
466
466
history_size : usize ,
@@ -553,7 +553,7 @@ impl LlamaSession {
553
553
if unsafe {
554
554
// SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
555
555
// initialized above.
556
- llama_decode ( self . inner . ptr , batch)
556
+ llama_decode ( self . inner . blocking_lock ( ) . ptr , batch)
557
557
} != 0
558
558
{
559
559
return Err ( LlamaInternalError . into ( ) ) ;
@@ -592,12 +592,12 @@ impl LlamaSession {
592
592
self . history_size,
593
593
) ;
594
594
595
- let inner = self . inner . clone ( ) ;
596
595
let past_tokens = self . history_size ;
596
+ let mutex = self . inner . clone ( ) ;
597
597
598
598
thread:: spawn ( move || unsafe {
599
599
llama_beam_search (
600
- inner . ptr ,
600
+ mutex . blocking_lock ( ) . ptr ,
601
601
Some ( detail:: llama_beam_search_callback) ,
602
602
Box :: leak ( Box :: new ( detail:: BeamSearchState { tx } ) ) as * mut _ as * mut c_void ,
603
603
1 ,
@@ -690,6 +690,7 @@ mod detail {
690
690
691
691
use std:: ffi:: { c_char, c_void, CStr } ;
692
692
use std:: ptr:: slice_from_raw_parts;
693
+ use tokio:: sync:: OwnedSemaphorePermit ;
693
694
694
695
use tracing:: { error, info, trace, warn} ;
695
696
0 commit comments