8000 Fix llama_cpp types · Dpaste20/llama-cpp-python@5e7ddfc · GitHub
[go: up one dir, main page]

Skip to content

Commit 5e7ddfc

Browse files
committed
Fix llama_cpp types
1 parent b6a9a0b commit 5e7ddfc

File tree

1 file changed

+33
-41
lines changed

1 file changed

+33
-41
lines changed

llama_cpp/llama_cpp.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
c_void_p,
99
c_bool,
1010
POINTER,
11+
_Pointer, # type: ignore
1112
Structure,
1213
Array,
1314
c_uint8,
@@ -252,9 +253,7 @@ def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
252253
# Copies the state to the specified destination address.
253254
# Destination needs to have allocated enough memory.
254255
# Returns the number of bytes copied
255-
def llama_copy_state_data(
256-
ctx: llama_context_p, dest # type: Array[c_uint8]
257-
) -> c_size_t:
256+
def llama_copy_state_data(ctx: llama_context_p, dest: Array[c_uint8]) -> c_size_t:
258257
return _lib.llama_copy_state_data(ctx, dest)
259258

260259

@@ -278,9 +277,9 @@ def llama_set_state_data(
278277
def llama_load_session_file(
279278
ctx: llama_context_p,
280279
path_session: bytes,
281-
tokens_out, # type: Array[llama_token]
280+
tokens_out: Array[llama_token],
282281
n_token_capacity: c_size_t,
283-
n_token_count_out, # type: Array[c_size_t]
282+
n_token_count_out: _Pointer[c_size_t],
284283
) -> c_size_t:
285284
return _lib.llama_load_session_file(
286285
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
@@ -300,7 +299,7 @@ def llama_load_session_file(
300299
def llama_save_session_file(
301300
ctx: llama_context_p,
302301
path_session: bytes,
303-
tokens, # type: Array[llama_token]
302+
tokens: Array[llama_token],
304303
n_token_count: c_size_t,
305304
) -> c_size_t:
306305
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
@@ -321,7 +320,7 @@ def llama_save_session_file(
321320
# Returns 0 on success
322321
def llama_eval(
323322
ctx: llama_context_p,
324-
tokens, # type: Array[llama_token]
323+
tokens: Array[llama_token],
325324
n_tokens: c_int,
326325
n_past: c_int,
327326
n_threads: c_int,
@@ -440,8 +439,8 @@ def llama_token_nl() -> llama_token:
440439
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
441440
def llama_sample_repetition_penalty(
442441
ctx: llama_context_p,
443-
candidates, # type: Array[llama_token_data]
444-
last_tokens_data, # type: Array[llama_token]
442+
candidates: _Pointer[llama_token_data],
443+
last_tokens_data: Array[llama_token],
445444
last_tokens_size: c_int,
446445
penalty: c_float,
447446
):
@@ -463,8 +462,8 @@ def llama_sample_repetition_penalty(
463462
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
464463
def llama_sample_frequency_and_presence_penalties(
465464
ctx: llama_context_p,
466-
candidates, # type: Array[llama_token_data]
467-
last_tokens_data, # type: Array[llama_token]
465+
candidates: _Pointer[llama_token_data],
466+
last_tokens_data: Array[llama_token],
468467
last_tokens_size: c_int,
469468
alpha_frequency: c_float,
470469
alpha_presence: c_float,
@@ -491,10 +490,7 @@ def llama_sample_frequency_and_presence_penalties(
491490

492491

493492
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
494-
def llama_sample_softmax(
495-
ctx: llama_context_p,
496-
candidates # type: Array[llama_token_data]
497-
):
493+
def llama_sample_softmax(ctx: llama_context_p, candidates: _Pointer[llama_token_data]):
498494
return _lib.llama_sample_softmax(ctx, candidates)
499495

500496

@@ -507,10 +503,10 @@ def llama_sample_softmax(
507503

508504
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
509505
def llama_sample_top_k(
510-
ctx: llama_context_p,
511-
candidates, # type: Array[llama_token_data]
512-
k: c_int,
513-
min_keep: c_size_t = c_size_t(1)
506+
ctx: llama_context_p,
507+
candidates: _Pointer[llama_token_data],
508+
k: c_int,
509+
min_keep: c_size_t = c_size_t(1),
514510
):
515511
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
516512

@@ -526,10 +522,10 @@ def llama_sample_top_k(
526522

527523
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
528524
def llama_sample_top_p(
529-
ctx: llama_context_p,
530-
candidates, # type: Array[llama_token_data]
531-
p: c_float,
532-
min_keep: c_size_t = c_size_t(1)
525+
ctx: llama_context_p,
526+
candidates: _Pointer[llama_token_data],
527+
p: c_float,
528+
min_keep: c_size_t = c_size_t(1),
533529
):
534530
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
535531

@@ -546,9 +542,9 @@ def llama_sample_top_p(
546542
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
547543
def llama_sample_tail_free(
548544
ctx: llama_context_p,
549-
candidates, # type: Array[llama_token_data]
545+
candidates: _Pointer[llama_token_data],
550546
z: c_float,
551-
min_keep: c_size_t = c_size_t(1)
547+
min_keep: c_size_t = c_size_t(1),
552548
):
553549
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
554550

@@ -565,9 +561,9 @@ def llama_sample_tail_free(
565561
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
566562
def llama_sample_typical(
567563
ctx: llama_context_p,
568-
candidates, # type: Array[llama_token_data]
569-
p: c_float,
570-
min_keep: c_size_t = c_size_t(1)
564+
candidates: _Pointer[llama_token_data],
565+
p: c_float,
566+
min_keep: c_size_t = c_size_t(1),
571567
):
572568
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
573569

@@ -582,9 +578,7 @@ def llama_sample_typical(
582578

583579

584580
def llama_sample_temperature(
585-
ctx: llama_context_p,
586-
candidates, # type: Array[llama_token_data]
587-
temp: c_float
581+
ctx: llama_context_p, candidates: _Pointer[llama_token_data], temp: c_float
588582
):
589583
return _lib.llama_sample_temperature(ctx, candidates, temp)
590584

@@ -605,11 +599,11 @@ def llama_sample_temperature(
605599
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
606600
def llama_sample_token_mirostat(
607601
ctx: llama_context_p,
608-
candidates, # type: Array[llama_token_data]
602+
candidates: _Pointer[llama_token_data],
609603
tau: c_float,
610-
eta: c_float,
604+
eta: c_float,
611605
m: c_int,
612-
mu # type: Array[c_float]
606+
mu: _Pointer[c_float],
613607
) -> llama_token:
614608
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
615609

@@ -632,10 +626,10 @@ def llama_sample_token_mirostat(
632626
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
633627
def llama_sample_token_mirostat_v2(
634628
ctx: llama_context_p,
635-
candidates, # type: Array[llama_token_data]
636-
tau: c_float,
629+
candidates: _Pointer[llama_token_data],
630+
tau: c_float,
637631
eta: c_float,
638-
mu # type: Array[c_float]
632+
mu: _Pointer[c_float],
639633
) -> llama_token:
640634
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
641635

@@ -652,8 +646,7 @@ def llama_sample_token_mirostat_v2(
652646

653647
# @details Selects the token with the highest probability.
654648
def llama_sample_token_greedy(
655-
ctx: llama_context_p,
656-
candidates # type: Array[llama_token_data]
649+
ctx: llama_context_p, candidates: _Pointer[llama_token_data]
657650
) -> llama_token:
658651
return _lib.llama_sample_token_greedy(ctx, candidates)
659652

@@ -667,8 +660,7 @@ def llama_sample_token_greedy(
667660

668661
# @details Randomly selects a token from the candidates based on their probabilities.
669662
def llama_sample_token(
670-
ctx: llama_context_p,
671-
candidates # type: Array[llama_token_data]
663+
ctx: llama_context_p, candidates: _Pointer[llama_token_data]
672664
) -> llama_token:
673665
return _lib.llama_sample_token(ctx, candidates)
674666

0 commit comments

Comments
 (0)
0