@@ -103,8 +103,8 @@ def _load_shared_library(lib_base_name: str):
 
 # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
-# define LLAMA_SESSION_VERSION 2
-LLAMA_SESSION_VERSION = 2
+# define LLAMA_SESSION_VERSION 3
+LLAMA_SESSION_VERSION = 3
 
 
 # struct llama_model;
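
The bump from session version 2 to 3 means state files written by older builds are rejected on load. As a minimal sketch (not part of this diff), assuming a session file begins with the magic and version as two little-endian uint32 values, and that LLAMA_FILE_MAGIC_GGSN is 0x6767736e ("ggsn") as in llama.cpp, a compatibility check could look like this; session_file_is_compatible is a hypothetical helper:

import struct

LLAMA_FILE_MAGIC_GGSN = 0x6767736E  # "ggsn" (assumed, mirrored from llama.cpp)
LLAMA_SESSION_VERSION = 3

def session_file_is_compatible(path: str) -> bool:
    # Assumed header layout: uint32 magic, then uint32 version.
    with open(path, "rb") as f:
        magic, version = struct.unpack("<II", f.read(8))
    return magic == LLAMA_FILE_MAGIC_GGSN and version == LLAMA_SESSION_VERSION
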
@@ -309,6 +309,35 @@ class llama_batch(Structure):
         ("all_seq_id", llama_seq_id),
     ]
 
+# enum llama_model_kv_override_type {
+#     LLAMA_KV_OVERRIDE_INT,
+#     LLAMA_KV_OVERRIDE_FLOAT,
+#     LLAMA_KV_OVERRIDE_BOOL,
+# };
+class llama_model_kv_override_type(Structure):
+    _fields_ = [
+        ("LLAMA_KV_OVERRIDE_INT", c_int),
+        ("LLAMA_KV_OVERRIDE_FLOAT", c_int),
+        ("LLAMA_KV_OVERRIDE_BOOL", c_int),
+    ]
+
+# struct llama_model_kv_override {
+#     char key[128];
+#     enum llama_model_kv_override_type tag;
+#     union {
+#         int64_t int_value;
+#         double float_value;
+#         bool bool_value;
+#     };
+# };
+class llama_model_kv_override(Structure):
+    _fields_ = [
+        ("key", ctypes.c_char * 128),
+        ("tag", llama_model_kv_override_type),
+        ("int_value", ctypes.c_int64),
+        ("float_value", c_double),
+        ("bool_value", c_bool),
+    ]
 
 # struct llama_model_params {
 #     int32_t n_gpu_layers; // number of layers to store in VRAM
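
In the C header, tag is an enum (an int) and the three value members share storage through an anonymous union; the binding above flattens both into plain side-by-side fields, which makes the Python struct larger than its C counterpart. A closer mapping, shown here only as a sketch (llama_model_kv_override_value and llama_model_kv_override_strict are hypothetical names, not part of this diff), would use c_int for the tag and ctypes.Union for the value:

import ctypes

class llama_model_kv_override_value(ctypes.Union):
    # The C union: all three members overlap in the same 8 bytes.
    _fields_ = [
        ("int_value", ctypes.c_int64),
        ("float_value", ctypes.c_double),
        ("bool_value", ctypes.c_bool),
    ]

class llama_model_kv_override_strict(ctypes.Structure):
    # _anonymous_ must be assigned before _fields_ so the union members
    # are reachable directly, e.g. override.int_value.
    _anonymous_ = ("value",)
    _fields_ = [
        ("key", ctypes.c_char * 128),
        ("tag", ctypes.c_int),  # enum llama_model_kv_override_type
        ("value", llama_model_kv_override_value),
    ]

Because the flattened version stores int_value, float_value, and bool_value consecutively, its sizeof differs from the C layout, which matters as soon as the C side walks an array of these entries with pointer arithmetic.
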
@@ -320,6 +349,8 @@ class llama_batch(Structure):
 #     // context pointer passed to the progress callback
 #     void * progress_callback_user_data;
 #
+#     // override key-value pairs of the model meta data
+#     const struct llama_model_kv_override * kv_overrides;
 #
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
 #     bool vocab_only; // only load the vocabulary, no weights
@@ -335,6 +366,7 @@ class llama_model_params(Structure):
         tensor_split (ctypes.Array[ctypes.c_float]): how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
         progress_callback (llama_progress_callback): called with a progress value between 0 and 1, pass NULL to disable
         progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
+        kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data
         vocab_only (bool): only load the vocabulary, no weights
         use_mmap (bool): use mmap if possible
         use_mlock (bool): force system to keep model in RAM"""
@@ -344,6 +376,7 @@ class llama_model_params(Structure):
         ("tensor_split", c_float_p),
         ("progress_callback", llama_progress_callback),
         ("progress_callback_user_data", c_void_p),
+        ("kv_overrides", POINTER(llama_model_kv_override)),
         ("vocab_only", c_bool),
         ("use_mmap", c_bool),
         ("use_mlock", c_bool),
@@ -367,12 +400,14 @@ class llama_model_params(Structure):
 #     float yarn_beta_slow; // YaRN high correction dim
 #     uint32_t yarn_orig_ctx; // YaRN original context size
 #
+#     enum ggml_type type_k; // data type for K cache
+#     enum ggml_type type_v; // data type for V cache
 #
 #     // Keep the booleans together to avoid misalignment during copy-by-value.
-#     bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
-#     bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
-#     bool logits_all; // the llama_eval() call computes all logits, not just the last one
-#     bool embedding;  // embedding mode only
+#     bool mul_mat_q;   // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
+#     bool logits_all;  // the llama_eval() call computes all logits, not just the last one
+#     bool embedding;   // embedding mode only
+#     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 # };
 class llama_context_params(Structure):
     """Parameters for llama_context
@@ -391,6 +426,8 @@ class llama_context_params(Structure):
         yarn_beta_fast (float): YaRN low correction dim
         yarn_beta_slow (float): YaRN high correction dim
         yarn_orig_ctx (int): YaRN original context size
+        type_k (int): data type for K cache
+        type_v (int): data type for V cache
         mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
         f16_kv (bool): use fp16 for KV cache, fp32 otherwise
         logits_all (bool): the llama_eval() call computes all logits, not just the last one
@@ -409,6 +446,8 @@ class llama_context_params(Structure):
         ("yarn_beta_fast", c_float),
         ("yarn_beta_slow", c_float),
         ("yarn_orig_ctx", c_uint32),
+        ("type_k", c_int),
+        ("type_v", c_int),
         ("mul_mat_q", c_bool),
         ("f16_kv", c_bool),
         ("logits_all", c_bool),