@@ -1,24 +1,23 @@
 //! Functionality for the [`LlamaSession`] struct
 
 use std::cmp::min;
-use std::ffi::c_void;
 use std::ops::{Bound, RangeBounds};
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex, RwLock};
 use std::thread;
 
-use futures::executor::block_on;
 use thiserror::Error;
-use tokio::sync::{mpsc::unbounded_channel, Mutex, RwLock};
+use tokio::sync::mpsc::unbounded_channel;
 use tracing::{error, info, trace, warn};
 
 use llama_cpp_sys::{
-    llama_beam_search, llama_context, llama_copy_state_data, llama_decode, llama_free,
+    llama_context, llama_copy_state_data, llama_decode, llama_free,
     llama_get_logits_ith, llama_get_state_size, llama_kv_cache_seq_rm, llama_set_state_data,
     llama_token_data, llama_token_data_array,
 };
 
-use crate::{detail, LlamaModel, LlamaTokenizationError, Sampler, Token};
+use crate::standard_sampler::StandardSampler;
+use crate::{LlamaModel, LlamaTokenizationError, Sampler, Token};
 
 mod completion;
 mod params;
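// A minimal sketch of the locking change this commit applies at every call
// site below: tokio's async `Mutex`/`RwLock` (acquired via `block_on`) are
// swapped for their `std::sync` counterparts (acquired via `.unwrap()`, which
// propagates lock poisoning). `Inner` and its fields here are stand-ins, not
// the crate's real session internals.

use std::sync::{Mutex, RwLock};

struct Inner {
    ctx: Mutex<usize>,        // stand-in for the guarded context pointer
    tokens: RwLock<Vec<u32>>, // stand-in for RwLock<Vec<Token>>
}

fn lock_sites(inner: &Inner) {
    // Before: let ctx = block_on(inner.ctx.lock()); — an async acquire that
    // had to be driven by futures::executor::block_on on non-async paths.
    // After: a plain blocking acquire; unwrap() surfaces lock poisoning.
    let _ctx = inner.ctx.lock().unwrap();
    let _len = inner.tokens.read().unwrap().len();
}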
@@ -172,7 +171,7 @@ impl LlamaSession {
             let err = unsafe {
                 // SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
                 // initialized above.
-                llama_decode(block_on(self.inner.ctx.lock()).ptr, batch.handle())
+                llama_decode(self.inner.ctx.lock().unwrap().ptr, batch.handle())
             };
             if err != 0 {
                 return Err(LlamaContextError::DecodeFailed(err));
@@ -182,7 +181,7 @@ impl LlamaSession {
             last_batch_size = sequence.len();
         }
 
-        block_on(self.inner.tokens.write()).extend_from_slice(tokens);
+        self.inner.tokens.write().unwrap().extend_from_slice(tokens);
 
         self.inner
             .last_batch_size
@@ -236,34 +235,13 @@ impl LlamaSession {
             .unwrap()
     }
 
-    /// Starts generating tokens at the end of the context using llama.cpp's built-in Beam search.
-    /// TODO fix: beam search keeps going even after it should have ended
+    /// Starts generating tokens at the end of the context using a greedy
+    /// sampler
     pub fn start_completing(&mut self) -> CompletionHandle {
-        let (tx, rx) = unbounded_channel();
-        let history_size = self.context_size();
-        let session = self.clone();
-
-        info!("Generating completions with {history_size} tokens of history");
-
-        thread::spawn(move || unsafe {
-            let state = Box::new(detail::BeamSearchState { tx });
-            // SAFETY: `state_ptr` is converted back to a [`Box`] and freed in [`detail::llama_beam_search_callback`]
-            let state_ptr = Box::into_raw(state);
-
-            llama_beam_search(
-                block_on(session.inner.ctx.lock()).ptr,
-                Some(detail::llama_beam_search_callback),
-                state_ptr as *mut _ as *mut c_void,
-                1,
-                history_size as i32,
-                32_768,
-            );
-        });
-
-        CompletionHandle {
-            rx,
-            model: self.model(),
-        }
+        self.start_completing_with(
+            StandardSampler::new_greedy(),
+            self.params().n_ctx as usize - self.context_size(),
+        )
     }
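// A hedged usage sketch of the rewritten `start_completing`: it now simply
// delegates to `start_completing_with`, pairing a greedy `StandardSampler`
// with a token budget of whatever space remains in the context window.
// Assuming a `session` built through the crate's usual entry points:
//
//     session.advance_context("The sky above the port")?;
//     let handle = session.start_completing();
//     // ...which is shorthand for:
//     // session.start_completing_with(
//     //     StandardSampler::new_greedy(),
//     //     session.params().n_ctx as usize - session.context_size(),
//     // );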
 
     /// Start completion.
@@ -282,10 +260,10 @@ impl LlamaSession {
         info!("Generating completions with {history_size} tokens of history");
 
         thread::spawn(move || {
-            let context = block_on(session.inner.ctx.lock());
+            let context = session.inner.ctx.lock().unwrap();
             let vocab = session.model().vocabulary_size();
             let end_of_stream = session.model().eos();
-            let mut token_buf = block_on(session.inner.tokens.write());
+            let mut token_buf = session.inner.tokens.write().unwrap();
             let mut count = 0;
             let mut batch = Batch::new(1, 0, 1);
             let mut i = session.inner.last_batch_size.load(Ordering::SeqCst);
@@ -366,12 +344,12 @@ impl LlamaSession {
 
     /// Returns the number of tokens currently in this session's context
     pub fn context_size(&self) -> usize {
-        block_on(self.inner.tokens.read()).len()
+        self.inner.tokens.read().unwrap().len()
     }
 
     /// Returns the list of tokens in the current context
     pub fn context(&self) -> Vec<Token> {
-        block_on(self.inner.tokens.read()).clone()
+        self.inner.tokens.read().unwrap().clone()
     }
 
     /// Removes all tokens within the given range without performing any prompt
@@ -393,12 +371,12 @@ impl LlamaSession {
             Bound::Unbounded => -1,
         };
 
-        let context = block_on(self.inner.ctx.lock());
+        let context = self.inner.ctx.lock().unwrap();
 
         // -1 here to match all sequences
         unsafe { llama_kv_cache_seq_rm(context.ptr, -1, start_bound, end_bound) }
 
-        block_on(self.inner.tokens.write()).drain(range);
+        self.inner.tokens.write().unwrap().drain(range);
     }
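// Sketch of the bound mapping used by the method above. llama.cpp's
// `llama_kv_cache_seq_rm` takes half-open `[p0, p1)` token positions as
// `i32`, with `-1` meaning "unbounded" (and `-1` as the sequence id matching
// every sequence). Only the `Unbounded` arms are visible in this hunk; the
// `Included`/`Excluded` arms below are the natural reconstruction, not code
// shown in the patch:

use std::ops::{Bound, RangeBounds};

fn to_kv_bounds(range: impl RangeBounds<usize>) -> (i32, i32) {
    let start = match range.start_bound() {
        Bound::Included(&n) => n as i32,
        Bound::Excluded(&n) => n as i32 + 1,
        Bound::Unbounded => -1,
    };
    let end = match range.end_bound() {
        Bound::Included(&n) => n as i32 + 1, // inclusive end becomes exclusive
        Bound::Excluded(&n) => n as i32,
        Bound::Unbounded => -1,
    };
    (start, end)
}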
 
     /// Removes all but the first `n_tokens` tokens from the context.
@@ -415,7 +393,7 @@ impl LlamaSession {
         new_tokens: impl AsRef<[Token]>,
     ) -> Result<(), LlamaContextError> {
         let new_tokens = new_tokens.as_ref();
-        let old_tokens = block_on(self.inner.tokens.read());
+        let old_tokens = self.inner.tokens.read().unwrap();
 
         let shared_prefix = old_tokens
             .iter()
@@ -480,7 +458,7 @@ impl LlamaSession {
     /// This differs from [`LlamaSession::clone`] in that [`LlamaSession::clone`] creates a new
     /// reference to the same underlying [`LlamaSession`].
     pub fn deep_copy(&self) -> Result<LlamaSession, LlamaContextError> {
-        let ctx = self.inner.ctx.blocking_lock();
+        let ctx = self.inner.ctx.lock().unwrap();
 
         #[allow(unused_mut)]
         let mut copy = self.model().create_session(self.inner.params.clone())?;
@@ -496,13 +474,13 @@ impl LlamaSession {
             let copy_size = llama_copy_state_data(ctx.ptr, buf.as_mut_ptr());
             assert!(copy_size <= size);
             let set_size =
-                llama_set_state_data(copy.inner.ctx.blocking_lock().ptr, buf.as_mut_ptr());
+                llama_set_state_data(copy.inner.ctx.lock().unwrap().ptr, buf.as_mut_ptr());
             assert_eq!(copy_size, set_size);
         }
 
         // NOTE: Any changes to the fields of a LlamaSession may require that
         // those changes are mirrored here
-        *block_on(copy.inner.tokens.write()) = block_on(self.inner.tokens.read()).clone();
+        *copy.inner.tokens.write().unwrap() = self.inner.tokens.read().unwrap().clone();
         copy.inner.last_batch_size.store(
             self.inner.last_batch_size.load(Ordering::SeqCst),
             Ordering::SeqCst,
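// Hedged usage sketch for `deep_copy`: the state round-trip above serializes
// the whole llama.cpp session (KV cache included) into a scratch buffer and
// restores it into a fresh context, so the copy diverges independently,
// whereas `clone` merely adds another handle to the same session. Assuming an
// existing `session`:
//
//     let fork = session.deep_copy()?;
//     assert_eq!(fork.context_size(), session.context_size());
//     // Feeding more tokens to `fork` leaves `session` untouched.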
@@ -512,16 +490,8 @@ impl LlamaSession {
     }
 
     /// Returns the maximum size in bytes this session is occupying in memory.
-    ///
-    /// This function may **NOT*** be called in async environments, for an async version see [`async_memory_size`].
     pub fn memory_size(&self) -> usize {
-        let ctx = self.inner.ctx.blocking_lock();
-        unsafe { llama_get_state_size(ctx.ptr) }
-    }
-
-    /// Asynchronously returns the maximum size in bytes this session is occupying in memory.
-    pub async fn async_memory_size(&self) -> usize {
-        let ctx = self.inner.ctx.lock().await;
+        let ctx = self.inner.ctx.lock().unwrap();
         unsafe { llama_get_state_size(ctx.ptr) }
     }
 }
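// With the tokio locks gone, the blocking and async variants of `memory_size`
// collapse into one plain method. Its result is the same worst-case figure
// `deep_copy` sizes its scratch buffer with: the actual copied state may be
// smaller, hence `assert!(copy_size <= size)` above. Sketch:
//
//     let mut buf = vec![0u8; session.memory_size()]; // worst-case capacity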