8000 docs: Improve low-level docstrings · qeleb/llama-cpp-python@396dbf0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 396dbf0

Browse files
committed
docs: Improve low-level docstrings
1 parent 9c68b18 commit 396dbf0

File tree

1 file changed

+95
-2
lines changed

1 file changed

+95
-2
lines changed

llama_cpp/llama_cpp.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ def _load_shared_library(lib_base_name: str):
212212
# float p; // probability of the token
213213
# } llama_token_data;
214214
class llama_token_data(Structure):
215+
"""Used to store token data
216+
217+
Attributes:
218+
id (llama_token): token id
219+
logit (float): log-odds of the token
220+
p (float): probability of the token"""
215221
_fields_ = [
216222
("id", llama_token),
217223
("logit", c_float),
@@ -228,6 +234,12 @@ class llama_token_data(Structure):
228234
# bool sorted;
229235
# } llama_token_data_array;
230236
class llama_token_data_array(Structure):
237+
"""Used to sample tokens given logits
238+
239+
Attributes:
240+
data (ctypes.Array[llama_token_data]): token data
241+
size (int): size of the array
242+
sorted (bool): whether the array is sorted"""
231243
_fields_ = [
232244
("data", llama_token_data_p),
233245
("size", c_size_t),
@@ -282,8 +294,7 @@ class llama_batch(Structure):
282294
token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
283295
embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
284296
pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
285-
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
286-
"""
297+
seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs"""
287298

288299
_fields_ = [
289300
("n_tokens", c_int32),
@@ -316,6 +327,17 @@ class llama_batch(Structure):
316327
# bool use_mlock; // force system to keep model in RAM
317328
# };
318329
class llama_model_params(Structure):
330+
"""Parameters for llama_model
331+
332+
Attributes:
333+
n_gpu_layers (int): number of layers to store in VRAM
334+
main_gpu (int): the GPU that is used for scratch and small tensors
335+
tensor_split (ctypes.Array[ctypes.c_float]): how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
336+
progress_callback (llama_progress_callback): called with a progress value between 0 and 1, pass NULL to disable
337+
progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
338+
vocab_only (bool): only load the vocabulary, no weights
339+
use_mmap (bool): use mmap if possible
340+
use_mlock (bool): force system to keep model in RAM"""
319341
_fields_ = [
320342
("n_gpu_layers", c_int32),
321343
("main_gpu", c_int32),
@@ -353,6 +375,26 @@ class llama_model_params(Structure):
353375
# bool embedding; // embedding mode only
354376
# };
355377
class llama_context_params(Structure):
378+
"""Parameters for llama_context
379+
380+
Attributes:
381+
seed (int): RNG seed, -1 for random
382+
n_ctx (int): text context, 0 = from model
383+
n_batch (int): prompt processing maximum batch size
384+
n_threads (int): number of threads to use for generation
385+
n_threads_batch (int): number of threads to use for batch processing
386+
rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
387+
rope_freq_base (float): RoPE base frequency, 0 = from model
388+
rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
389+
yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
390+
yarn_attn_factor (float): YaRN magnitude scaling factor
391+
yarn_beta_fast (float): YaRN low correction dim
392+
yarn_beta_slow (float): YaRN high correction dim
393+
yarn_orig_ctx (int): YaRN original context size
394+
mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
395+
f16_kv (bool): use fp16 for KV cache, fp32 otherwise
396+
logits_all (bool): the llama_eval() call computes all logits, not just the last one
397+
embedding (bool): embedding mode only"""
356398
_fields_ = [
357399
("seed", c_uint32),
358400
("n_ctx", c_uint32),
@@ -398,6 +440,15 @@ class llama_context_params(Structure):
398440
# bool pure; // disable k-quant mixtures and quantize all tensors to the same type
399441
# } llama_model_quantize_params;
400442
class llama_model_quantize_params(Structure):
443+
"""Parameters for llama_model_quantize
444+
445+
Attributes:
446+
nthread (int): number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
447+
ftype (int): quantize to this llama_ftype
448+
allow_requantize (bool): allow quantizing non-f32/f16 tensors
449+
quantize_output_tensor (bool): quantize output.weight
450+
only_copy (bool): only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
451+
pure (bool): disable k-quant mixtures and quantize all tensors to the same type"""
401452
_fields_ = [
402453
("nthread", c_int),
403454
("ftype", c_int),
@@ -489,6 +540,7 @@ class llama_timings(Structure):
489540
# // Helpers for getting default parameters
490541
# LLAMA_API struct llama_model_params llama_model_default_params(void);
491542
def llama_model_default_params() -> llama_model_params:
543+
"""Get default parameters for llama_model"""
492544
return _lib.llama_model_default_params()
493545

494546

@@ -498,6 +550,7 @@ def llama_model_default_params() -> llama_model_params:
498550

499551
# LLAMA_API struct llama_context_params llama_context_default_params(void);
500552
def llama_context_default_params() -> llama_context_params:
553+
"""Get default parameters for llama_context"""
501554
return _lib.llama_context_default_params()
502555

503556

@@ -507,6 +560,7 @@ def llama_context_default_params() -> llama_context_params:
507560

508561
# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
509562
def llama_model_quantize_default_params() -> llama_model_quantize_params:
563+
"""Get default parameters for llama_model_quantize"""
510564
return _lib.llama_model_quantize_default_params()
511565

512566

@@ -1668,6 +1722,7 @@ def llama_grammar_init(
16681722
n_rules: Union[c_size_t, int],
16691723
start_rule_index: Union[c_size_t, int],
16701724
) -> llama_grammar_p:
1725+
"""Initialize a grammar from a set of rules."""
16711726
return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
16721727

16731728

@@ -1681,6 +1736,7 @@ def llama_grammar_init(
16811736

16821737
# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
16831738
def llama_grammar_free(grammar: llama_grammar_p):
1739+
"""Free a grammar."""
16841740
return _lib.llama_grammar_free(grammar)
16851741

16861742

@@ -1690,6 +1746,7 @@ def llama_grammar_free(grammar: llama_grammar_p):
16901746

16911747
# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
16921748
def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
1749+
"""Copy a grammar."""
16931750
return _lib.llama_grammar_copy(grammar)
16941751

16951752

@@ -1939,6 +1996,11 @@ def llama_sample_temp(
19391996
candidates, # type: _Pointer[llama_token_data_array]
19401997
temp: Union[c_float, float],
19411998
):
1999+
"""Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509
2000+
2001+
Parameters:
2002+
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
2003+
temp: The temperature value to use for the sampling. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text."""
19422004
return _lib.llama_sample_temp(ctx, candidates, temp)
19432005

19442006

@@ -1960,6 +2022,7 @@ def llama_sample_temperature(
19602022
candidates, # type: _Pointer[llama_token_data_array]
19612023
temp: Union[c_float, float],
19622024
):
2025+
"""use llama_sample_temp instead"""
19632026
return _lib.llama_sample_temperature(ctx, candidates, temp)
19642027

19652028

@@ -1981,6 +2044,11 @@ def llama_sample_grammar(
19812044
candidates, # type: _Pointer[llama_token_data_array]
19822045
grammar, # type: llama_grammar_p
19832046
):
2047+
"""Apply constraints from grammar
2048+
2049+
Parameters:
2050+
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
2051+
grammar: A grammar object containing the rules and constraints to apply to the generated text."""
19842052
return _lib.llama_sample_grammar(ctx, candidates, grammar)
19852053

19862054

@@ -2013,6 +2081,14 @@ def llama_sample_token_mirostat(
20132081
m: Union[c_int, int],
20142082
mu, # type: _Pointer[c_float]
20152083
) -> int:
2084+
"""Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
2085+
2086+
Parameters:
2087+
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
2088+
tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
2089+
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
2090+
m: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
2091+
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal."""
20162092
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
20172093

20182094

@@ -2045,6 +2121,13 @@ def llama_sample_token_mirostat_v2(
20452121
eta: Union[c_float, float],
20462122
mu, # type: _Pointer[c_float]
20472123
) -> int:
2124+
"""Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
2125+
2126+
Parameters:
2127+
candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
2128+
tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
2129+
eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
2130+
mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal."""
20482131
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
20492132

20502133

@@ -2067,6 +2150,7 @@ def llama_sample_token_greedy(
20672150
ctx: llama_context_p,
20682151
candidates, # type: _Pointer[llama_token_data_array]
20692152
) -> int:
2153+
"""Selects the token with the highest probability."""
20702154
return _lib.llama_sample_token_greedy(ctx, candidates)
20712155

20722156

@@ -2085,6 +2169,7 @@ def llama_sample_token(
20852169
ctx: llama_context_p,
20862170
candidates, # type: _Pointer[llama_token_data_array]
20872171
) -> int:
2172+
"""Randomly selects a token from the candidates based on their probabilities."""
20882173
return _lib.llama_sample_token(ctx, candidates)
20892174

20902175

@@ -2105,6 +2190,7 @@ def llama_grammar_accept_token(
21052190
grammar: llama_grammar_p,
21062191
token: Union[llama_token, int],
21072192
) -> None:
2193+
"""Accepts the sampled token into the grammar"""
21082194
_lib.llama_grammar_accept_token(ctx, grammar, token)
21092195

21102196

@@ -2207,6 +2293,7 @@ def llama_beam_search(
22072293

22082294
# LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
22092295
def llama_get_timings(ctx: llama_context_p) -> llama_timings:
2296+
"""Get performance information"""
22102297
return _lib.llama_get_timings(ctx)
22112298

22122299

@@ -2216,6 +2303,7 @@ def llama_get_timings(ctx: llama_context_p) -> llama_timings:
22162303

22172304
# LLAMA_API void llama_print_timings(struct llama_context * ctx);
22182305
def llama_print_timings(ctx: llama_context_p):
2306+
"""Print performance information"""
22192307
_lib.llama_print_timings(ctx)
22202308

22212309

@@ -2225,6 +2313,7 @@ def llama_print_timings(ctx: llama_context_p):
22252313

22262314
# LLAMA_API void llama_reset_timings(struct llama_context * ctx);
22272315
def llama_reset_timings(ctx: llama_context_p):
2316+
"""Reset performance information"""
22282317
_lib.llama_reset_timings(ctx)
22292318

22302319

@@ -2235,6 +2324,7 @@ def llama_reset_timings(ctx: llama_context_p):
22352324
# Print system information
22362325
# LLAMA_API const char * llama_print_system_info(void);
22372326
def llama_print_system_info() -> bytes:
2327+
"""Print system information"""
22382328
return _lib.llama_print_system_info()
22392329

22402330

@@ -2249,6 +2339,9 @@ def llama_print_system_info() -> bytes:
22492339
def llama_log_set(
22502340
log_callback: "ctypes._FuncPointer", user_data: c_void_p # type: ignore
22512341
):
2342+
"""Set callback for all future logging events.
2343+
2344+
If this is not called, or NULL is supplied, everything is output on stderr."""
22522345
return _lib.llama_log_set(log_callback, user_data)
22532346

22542347

0 commit comments

Comments
 (0)
0