Add types for all low-level api functions · Dpaste20/llama-cpp-python@b6a9a0b · GitHub

Commit b6a9a0b

Add types for all low-level api functions

1 parent 5be0efa commit b6a9a0b

File tree

2 files changed: 62 additions & 21 deletions

llama_cpp/llama.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def __init__(
         self,
         eval_tokens: Deque[llama_cpp.llama_token],
         eval_logits: Deque[List[llama_cpp.c_float]],
-        llama_state,
+        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: llama_cpp.c_size_t,
     ):
         self.eval_tokens = eval_tokens
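
The `# type: llama_cpp.Array[llama_cpp.c_uint8]` comment added here is a PEP 484 type comment: it documents the intended ctypes array type for static checkers without requiring a runtime annotation. A minimal sketch of allocating a buffer matching that annotation, assuming `Array` refers to `ctypes.Array`; the size is a placeholder (real code would obtain it from `llama_get_state_size`):

    import ctypes

    # Illustrative: a ctypes array of c_uint8, i.e. Array[c_uint8].
    state_size = 4096  # placeholder; real code asks llama_get_state_size(ctx)
    llama_state = (ctypes.c_uint8 * state_size)()  # zero-initialized byte buffer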

llama_cpp/llama_cpp.py

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
# Load the library
20-
def _load_shared_library(lib_base_name):
20+
def _load_shared_library(lib_base_name: str):
2121
# Determine the file extension based on the platform
2222
if sys.platform.startswith("linux"):
2323
lib_ext = ".so"
@@ -252,7 +252,9 @@ def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
 # Copies the state to the specified destination address.
 # Destination needs to have allocated enough memory.
 # Returns the number of bytes copied
-def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
+def llama_copy_state_data(
+    ctx: llama_context_p, dest  # type: Array[c_uint8]
+) -> c_size_t:
     return _lib.llama_copy_state_data(ctx, dest)

@@ -262,7 +264,9 @@ def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
 
 # Set the state reading from the specified address
 # Returns the number of bytes read
-def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
+def llama_set_state_data(
+    ctx: llama_context_p, src  # type: Array[c_uint8]
+) -> c_size_t:
     return _lib.llama_set_state_data(ctx, src)
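
Together with `llama_get_state_size`, these two functions support a snapshot/rollback round trip. A hedged sketch, assuming `ctx` is an already-initialized `llama_context_p` and that `c_uint8` is re-exported by the module, as the type comments suggest:

    import llama_cpp

    n = llama_cpp.llama_get_state_size(ctx)     # bytes required for the snapshot
    buf = (llama_cpp.c_uint8 * n)()             # Array[c_uint8], per the type comment
    llama_cpp.llama_copy_state_data(ctx, buf)   # save: context state -> buf
    # ... evaluate more tokens, then roll back:
    llama_cpp.llama_set_state_data(ctx, buf)    # restore: buf -> context state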

@@ -274,9 +278,9 @@ def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
 def llama_load_session_file(
     ctx: llama_context_p,
     path_session: bytes,
-    tokens_out,
+    tokens_out,  # type: Array[llama_token]
     n_token_capacity: c_size_t,
-    n_token_count_out,
+    n_token_count_out,  # type: Array[c_size_t]
 ) -> c_size_t:
     return _lib.llama_load_session_file(
         ctx, path_session, tokens_out, n_token_capacity, n_token_count_out

@@ -294,7 +298,10 @@ def llama_load_session_file(
 
 
 def llama_save_session_file(
-    ctx: llama_context_p, path_session: bytes, tokens, n_token_count: c_size_t
+    ctx: llama_context_p,
+    path_session: bytes,
+    tokens,  # type: Array[llama_token]
+    n_token_count: c_size_t,
 ) -> c_size_t:
     return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
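
A hedged round-trip sketch for the session-file pair, assuming `ctx` is initialized and `tokens` is a Python list of token ids; the filename is illustrative:

    import llama_cpp

    # Save: tokens go in as a plain ctypes array plus an explicit count.
    token_arr = (llama_cpp.llama_token * len(tokens))(*tokens)   # Array[llama_token]
    llama_cpp.llama_save_session_file(
        ctx, b"session.bin", token_arr, llama_cpp.c_size_t(len(tokens))
    )

    # Load: tokens_out and n_token_count_out are out-parameters.
    capacity = 2048
    tokens_out = (llama_cpp.llama_token * capacity)()            # Array[llama_token]
    n_out = (llama_cpp.c_size_t * 1)()                           # Array[c_size_t]
    llama_cpp.llama_load_session_file(
        ctx, b"session.bin", tokens_out, llama_cpp.c_size_t(capacity), n_out
    )
    loaded = list(tokens_out[: n_out[0]])                        # first n_out[0] entries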

@@ -433,8 +440,8 @@ def llama_token_nl() -> llama_token:
 # @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
 def llama_sample_repetition_penalty(
     ctx: llama_context_p,
-    candidates,
-    last_tokens_data,
+    candidates,  # type: Array[llama_token_data]
+    last_tokens_data,  # type: Array[llama_token]
     last_tokens_size: c_int,
     penalty: c_float,
 ):

@@ -456,8 +463,8 @@ def llama_sample_repetition_penalty(
 # @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
 def llama_sample_frequency_and_presence_penalties(
     ctx: llama_context_p,
-    candidates,
-    last_tokens_data,
+    candidates,  # type: Array[llama_token_data]
+    last_tokens_data,  # type: Array[llama_token]
     last_tokens_size: c_int,
     alpha_frequency: c_float,
     alpha_presence: c_float,
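
Both penalty samplers take the recent history as a raw token array plus an explicit length, which is what the new `Array[llama_token]` comment captures. A sketch of preparing those arguments, assuming `ctx`, a prepared `candidates` array (typed `Array[llama_token_data]` in this commit), and a Python list `last_n_tokens` of recent token ids:

    import llama_cpp

    last_n = (llama_cpp.llama_token * len(last_n_tokens))(*last_n_tokens)
    llama_cpp.llama_sample_repetition_penalty(
        ctx,
        candidates,                           # assumed prepared candidates array
        last_n,                               # Array[llama_token]
        llama_cpp.c_int(len(last_n_tokens)),
        llama_cpp.c_float(1.1),               # values > 1.0 penalize repetition
    )
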
@@ -484,7 +491,10 @@ def llama_sample_frequency_and_presence_penalties(
 
 
 # @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-def llama_sample_softmax(ctx: llama_context_p, candidates):
+def llama_sample_softmax(
+    ctx: llama_context_p,
+    candidates  # type: Array[llama_token_data]
+):
     return _lib.llama_sample_softmax(ctx, candidates)

@@ -497,7 +507,10 @@ def llama_sample_softmax(ctx: llama_context_p, candidates):
 
 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
 def llama_sample_top_k(
-    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    k: c_int,
+    min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)

@@ -513,7 +526,10 @@ def llama_sample_top_k(
 
 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
 def llama_sample_top_p(
-    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    p: c_float,
+    min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)

@@ -529,7 +545,10 @@ def llama_sample_top_p(
 
 # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
 def llama_sample_tail_free(
-    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    z: c_float,
+    min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)

@@ -545,7 +564,10 @@ def llama_sample_tail_free(
 
 # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
 def llama_sample_typical(
-    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    p: c_float,
+    min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)

@@ -559,7 +581,11 @@ def llama_sample_typical(
 _lib.llama_sample_typical.restype = None
 
 
-def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
+def llama_sample_temperature(
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    temp: c_float
+):
     return _lib.llama_sample_temperature(ctx, candidates, temp)

@@ -578,7 +604,12 @@ def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
 # @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
 # @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
 def llama_sample_token_mirostat(
-    ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    tau: c_float,
+    eta: c_float,
+    m: c_int,
+    mu  # type: Array[c_float]
 ) -> llama_token:
     return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)

@@ -600,7 +631,11 @@ def llama_sample_token_mirostat(
 # @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
 # @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
 def llama_sample_token_mirostat_v2(
-    ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
+    ctx: llama_context_p,
+    candidates,  # type: Array[llama_token_data]
+    tau: c_float,
+    eta: c_float,
+    mu  # type: Array[c_float]
 ) -> llama_token:
     return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
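
`mu` is typed `Array[c_float]` because it is an in/out parameter: the C side reads and updates it across calls. Per the `@param mu` comment above, it starts at `2 * tau`. A hedged sketch, assuming `ctx` and a prepared `candidates` array exist:

    import llama_cpp

    tau, eta = 5.0, 0.1
    mu = (llama_cpp.c_float * 1)(2.0 * tau)   # one-element in/out cell, seeded at 2 * tau
    token = llama_cpp.llama_sample_token_mirostat_v2(
        ctx, candidates, llama_cpp.c_float(tau), llama_cpp.c_float(eta), mu
    )
    # mu[0] now holds the updated cross-entropy bound for the next call.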

@@ -616,7 +651,10 @@ def llama_sample_token_mirostat_v2(
 
 
 # @details Selects the token with the highest probability.
-def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
+def llama_sample_token_greedy(
+    ctx: llama_context_p,
+    candidates  # type: Array[llama_token_data]
+) -> llama_token:
     return _lib.llama_sample_token_greedy(ctx, candidates)

@@ -628,7 +666,10 @@ def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
 
 
 # @details Randomly selects a token from the candidates based on their probabilities.
-def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
+def llama_sample_token(
+    ctx: llama_context_p,
+    candidates  # type: Array[llama_token_data]
+) -> llama_token:
     return _lib.llama_sample_token(ctx, candidates)
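
The samplers modify `candidates` in place (their restype is None), so they compose into a pipeline that ends in one of the token pickers. A hedged sketch of a classic top-k / top-p / temperature chain, assuming `ctx` and a prepared `candidates` array; the parameter values are illustrative:

    import llama_cpp

    llama_cpp.llama_sample_top_k(ctx, candidates, llama_cpp.c_int(40))
    llama_cpp.llama_sample_top_p(ctx, candidates, llama_cpp.c_float(0.95))
    llama_cpp.llama_sample_temperature(ctx, candidates, llama_cpp.c_float(0.8))
    token = llama_cpp.llama_sample_token(ctx, candidates)   # draw from what remains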

0 commit comments