@@ -148,6 +148,12 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa
ctypes.c_bool, ctypes.c_void_p, ctypes.c_bool, ctypes.c_void_p
)

+ # // Abort callback
+ # // If not NULL, called before ggml computation
+ # // If it returns true, the computation is aborted
+ # typedef bool (*ggml_abort_callback)(void * data);
+ ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p)
+
# llama.h bindings

_lib.llama_max_devices.argtypes = []
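Since `ggml_abort_callback` is an ordinary `ctypes.CFUNCTYPE` type, a Python callable can be wrapped in it directly. A minimal sketch, assuming only the typedef added above; the `_never_abort` name and its always-`False` body are illustrative:

```python
import ctypes

# Same shape as the binding above: bool (*ggml_abort_callback)(void * data)
ggml_abort_callback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_void_p)

def _never_abort(data):
    # Return True to abort the ggml computation, False to let it continue.
    return False

# Keep a reference to the wrapped callback for as long as native code may
# call it; if it is garbage collected, the next invocation crashes.
abort_cb = ggml_abort_callback(_never_abort)
```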
@@ -560,10 +566,16 @@ class llama_model_params(ctypes.Structure):
# enum ggml_type type_v; // data type for V cache

# // Keep the booleans together to avoid misalignment during copy-by-value.
- # bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+ # bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
# bool embedding; // embedding mode only
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
# bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+
+ # // Abort callback
+ # // if it returns true, execution of llama_decode() will be aborted
+ # // currently works only with CPU execution
+ # ggml_abort_callback abort_callback;
+ # void * abort_callback_data;
# };
class llama_context_params(ctypes.Structure):
"""Parameters for llama_context
@@ -591,6 +603,8 @@ class llama_context_params(ctypes.Structure):
embedding (bool): embedding mode only
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+ abort_callback (ggml_abort_callback): abort callback; if it returns true, execution of llama_decode() will be aborted
+ abort_callback_data (ctypes.c_void_p): data for abort_callback
"""

_fields_ = [
@@ -616,6 +630,8 @@ class llama_context_params(ctypes.Structure):
("embedding", ctypes.c_bool),
("offload_kqv", ctypes.c_bool),
("do_pooling", ctypes.c_bool),
+ ("abort_callback", ggml_abort_callback),
+ ("abort_callback_data", ctypes.c_void_p),
]
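With the two new struct fields in place, populating them might look like the following sketch. It assumes `llama_context_default_params()` from these bindings; the event-polling callback is illustrative:

```python
import threading

import llama_cpp

stop_event = threading.Event()

@llama_cpp.ggml_abort_callback  # wraps the function in the CFUNCTYPE type
def _abort_on_event(data):
    # True aborts llama_decode(); per the struct comment, CPU execution only.
    return stop_event.is_set()

params = llama_cpp.llama_context_default_params()
params.abort_callback = _abort_on_event
params.abort_callback_data = None  # NULL: this callback needs no user data
```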
@@ -1703,8 +1719,24 @@ def llama_set_n_threads(
"""
...
+ # // Set abort callback
+ # LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
+ @ctypes_function(
+ "llama_set_abort_callback",
+ [llama_context_p_ctypes, ggml_abort_callback, ctypes.c_void_p],
+ None,
+ )
+ def llama_set_abort_callback(
+ ctx: llama_context_p,
+ abort_callback: Callable[[ctypes.c_void_p], bool],
+ abort_callback_data: ctypes.c_void_p,
+ /,
+ ):
+ """Set abort callback"""
+ ...
+
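Registering the callback on an existing context might look like this sketch; `ctx` is assumed to come from `llama_cpp.llama_new_context_with_model(model, params)`, and the flag-based callback is illustrative:

```python
import llama_cpp

stop_requested = False  # toggled from another thread, e.g. a UI handler

@llama_cpp.ggml_abort_callback
def _abort_cb(data):
    # Polled during llama_decode(); True requests an abort.
    return stop_requested

# ctx is assumed to be an existing llama_context.
llama_cpp.llama_set_abort_callback(ctx, _abort_cb, None)
```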

- # // Token logits obtained from the last call to llama_eval()
+ # // Token logits obtained from the last call to llama_decode()
# // The logits for the last token are stored in the last row
# // Logits for which llama_batch.logits[i] == 0 are undefined
# // Rows: n_tokens provided with llama_batch
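As a sketch of what these comments imply for callers, reading the last row of logits after a decode might look like the following; `model`, `ctx`, and `n_tokens` are assumed to exist from an earlier `llama_decode` call:

```python
import llama_cpp

# Assumes llama_decode(ctx, batch) has just returned 0 for a batch of
# n_tokens tokens, with batch.logits[n_tokens - 1] set to 1.
n_vocab = llama_cpp.llama_n_vocab(model)
logits = llama_cpp.llama_get_logits(ctx)  # float*, rows: n_tokens x n_vocab
last_row = [logits[(n_tokens - 1) * n_vocab + i] for i in range(n_vocab)]
```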