17
17
18
18
19
19
# Load the library
20
- def _load_shared_library (lib_base_name ):
20
+ def _load_shared_library (lib_base_name : str ):
21
21
# Determine the file extension based on the platform
22
22
if sys .platform .startswith ("linux" ):
23
23
lib_ext = ".so"
@@ -252,7 +252,9 @@ def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
252
252
# Copies the state to the specified destination address.
253
253
# Destination needs to have allocated enough memory.
254
254
# Returns the number of bytes copied
255
- def llama_copy_state_data (ctx : llama_context_p , dest ) -> c_size_t :
255
+ def llama_copy_state_data (
256
+ ctx : llama_context_p , dest # type: Array[c_uint8]
257
+ ) -> c_size_t :
256
258
return _lib .llama_copy_state_data (ctx , dest )
257
259
258
260
@@ -262,7 +264,9 @@ def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
262
264
263
265
# Set the state reading from the specified address
264
266
# Returns the number of bytes read
265
- def llama_set_state_data (ctx : llama_context_p , src ) -> c_size_t :
267
+ def llama_set_state_data (
268
+ ctx : llama_context_p , src # type: Array[c_uint8]
269
+ ) -> c_size_t :
266
270
return _lib .llama_set_state_data (ctx , src )
267
271
268
272
@@ -274,9 +278,9 @@ def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
274
278
def llama_load_session_file (
275
279
ctx : llama_context_p ,
276
280
path_session : bytes ,
277
- tokens_out ,
281
+ tokens_out , # type: Array[llama_token]
278
282
n_token_capacity : c_size_t ,
279
- n_token_count_out ,
283
+ n_token_count_out , # type: Array[c_size_t]
280
284
) -> c_size_t :
281
285
return _lib .llama_load_session_file (
282
286
ctx , path_session , tokens_out , n_token_capacity , n_token_count_out
@@ -294,7 +298,10 @@ def llama_load_session_file(
294
298
295
299
296
300
def llama_save_session_file (
297
- ctx : llama_context_p , path_session : bytes , tokens , n_token_count : c_size_t
301
+ ctx : llama_context_p ,
302
+ path_session : bytes ,
303
+ tokens , # type: Array[llama_token]
304
+ n_token_count : c_size_t ,
298
305
) -> c_size_t :
299
306
return _lib .llama_save_session_file (ctx , path_session , tokens , n_token_count )
300
307
@@ -433,8 +440,8 @@ def llama_token_nl() -> llama_token:
433
440
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
434
441
def llama_sample_repetition_penalty (
435
442
ctx : llama_context_p ,
436
- candidates ,
437
- last_tokens_data ,
443
+ candidates , # type: Array[llama_token_data]
444
+ last_tokens_data , # type: Array[llama_token]
438
445
last_tokens_size : c_int ,
439
446
penalty : c_float ,
440
447
):
@@ -456,8 +463,8 @@ def llama_sample_repetition_penalty(
456
463
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
457
464
def llama_sample_frequency_and_presence_penalties (
458
465
ctx : llama_context_p ,
459
- candidates ,
460
- last_tokens_data ,
466
+ candidates , # type: Array[llama_token_data]
467
+ last_tokens_data , # type: Array[llama_token]
461
468
last_tokens_size : c_int ,
462
469
alpha_frequency : c_float ,
463
470
alpha_presence : c_float ,
@@ -484,7 +491,10 @@ def llama_sample_frequency_and_presence_penalties(
484
491
485
492
486
493
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
487
- def llama_sample_softmax (ctx : llama_context_p , candidates ):
494
+ def llama_sample_softmax (
495
+ ctx : llama_context_p ,
496
+ candidates # type: Array[llama_token_data]
497
+ ):
488
498
return _lib .llama_sample_softmax (ctx , candidates )
489
499
490
500
@@ -497,7 +507,10 @@ def llama_sample_softmax(ctx: llama_context_p, candidates):
497
507
498
508
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
499
509
def llama_sample_top_k (
500
- ctx : llama_context_p , candidates , k : c_int , min_keep : c_size_t = c_size_t (1 )
510
+ ctx : llama_context_p ,
511
+ candidates , # type: Array[llama_token_data]
512
+ k : c_int ,
513
+ min_keep : c_size_t = c_size_t (1 )
501
514
):
502
515
return _lib .llama_sample_top_k (ctx , candidates , k , min_keep )
503
516
@@ -513,7 +526,10 @@ def llama_sample_top_k(
513
526
514
527
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
515
528
def llama_sample_top_p (
516
- ctx : llama_context_p , candidates , p : c_float , min_keep : c_size_t = c_size_t (1 )
529
+ ctx : llama_context_p ,
530
+ candidates , # type: Array[llama_token_data]
531
+ p : c_float ,
532
+ min_keep : c_size_t = c_size_t (1 )
517
533
):
518
534
return _lib .llama_sample_top_p (ctx , candidates , p , min_keep )
519
535
@@ -529,7 +545,10 @@ def llama_sample_top_p(
529
545
530
546
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
531
547
def llama_sample_tail_free (
532
- ctx : llama_context_p , candidates , z : c_float , min_keep : c_size_t = c_size_t (1 )
548
+ ctx : llama_context_p ,
549
+ candidates , # type: Array[llama_token_data]
550
+ z : c_float ,
551
+ min_keep : c_size_t = c_size_t (1 )
533
552
):
534
553
return _lib .llama_sample_tail_free (ctx , candidates , z , min_keep )
535
554
@@ -545,7 +564,10 @@ def llama_sample_tail_free(
545
564
546
565
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
547
566
def llama_sample_typical (
548
- ctx : llama_context_p , candidates , p : c_float , min_keep : c_size_t = c_size_t (1 )
567
+ ctx : llama_context_p ,
568
+ candidates , # type: Array[llama_token_data]
569
+ p : c_float ,
570
+ min_keep : c_size_t = c_size_t (1 )
549
571
):
550
572
return _lib .llama_sample_typical (ctx , candidates , p , min_keep )
551
573
@@ -559,7 +581,11 @@ def llama_sample_typical(
559
581
_lib .llama_sample_typical .restype = None
560
582
561
583
562
- def llama_sample_temperature (ctx : llama_context_p , candidates , temp : c_float ):
584
+ def llama_sample_temperature (
585
+ ctx : llama_context_p ,
586
+ candidates , # type: Array[llama_token_data]
587
+ temp : c_float
588
+ ):
563
589
return _lib .llama_sample_temperature (ctx , candidates , temp )
564
590
565
591
@@ -578,7 +604,12 @@ def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
578
604
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
579
605
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
580
606
def llama_sample_token_mirostat (
581
- ctx : llama_context_p , candidates , tau : c_float , eta : c_float , m : c_int , mu
607
+ ctx : llama_context_p ,
608
+ candidates , # type: Array[llama_token_data]
609
+ tau : c_float ,
610
+ eta : c_float ,
611
+ m : c_int ,
612
+ mu # type: Array[c_float]
582
613
) -> llama_token :
583
614
return _lib .llama_sample_token_mirostat (ctx , candidates , tau , eta , m , mu )
584
615
@@ -600,7 +631,11 @@ def llama_sample_token_mirostat(
600
631
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
601
632
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
602
633
def llama_sample_token_mirostat_v2 (
603
- ctx : llama_context_p , candidates , tau : c_float , eta : c_float , mu
634
+ ctx : llama_context_p ,
635
+ candidates , # type: Array[llama_token_data]
636
+ tau : c_float ,
637
+ eta : c_float ,
638
+ mu # type: Array[c_float]
604
639
) -> llama_token :
605
640
return _lib .llama_sample_token_mirostat_v2 (ctx , candidates , tau , eta , mu )
606
641
@@ -616,7 +651,10 @@ def llama_sample_token_mirostat_v2(
616
651
617
652
618
653
# @details Selects the token with the highest probability.
619
- def llama_sample_token_greedy (ctx : llama_context_p , candidates ) -> llama_token :
654
+ def llama_sample_token_greedy (
655
+ ctx : llama_context_p ,
656
+ candidates # type: Array[llama_token_data]
657
+ ) -> llama_token :
620
658
return _lib .llama_sample_token_greedy (ctx , candidates )
621
659
622
660
@@ -628,7 +666,10 @@ def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
628
666
629
667
630
668
# @details Randomly selects a token from the candidates based on their probabilities.
631
- def llama_sample_token (ctx : llama_context_p , candidates ) -> llama_token :
669
+ def llama_sample_token (
670
+ ctx : llama_context_p ,
671
+ candidates # type: Array[llama_token_data]
672
+ ) -> llama_token :
632
673
return _lib .llama_sample_token (ctx , candidates )
633
674
634
675
0 commit comments