@@ -79,6 +79,7 @@ def __init__(
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
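For reference, a minimal sketch of how the new `pooling_type` parameter could be used to build an embedding model with mean pooling. The model path is a placeholder, not part of this change, and `LLAMA_POOLING_TYPE_MEAN` is assumed to be among the pooling constants the bindings expose:

```python
import llama_cpp
from llama_cpp import Llama

# Hypothetical usage of the new pooling_type parameter; the GGUF path is a placeholder.
llm = Llama(
    model_path="./models/embedding-model.gguf",  # placeholder, supply your own model
    embedding=True,  # required before calling llm.embed()
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_MEAN,  # pool token embeddings into one vector per sequence
)
vector = llm.embed("Hello, world!")  # one pooled embedding of length llm.n_embd()
```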
@@ -104,6 +105,9 @@ def __init__(
         draft_model: Optional[LlamaDraftModel] = None,
         # Tokenizer Override
         tokenizer: Optional[BaseLlamaTokenizer] = None,
+        # KV cache quantization
+        type_k: Optional[int] = None,
+        type_v: Optional[int] = None,
         # Misc
         verbose: bool = True,
         # Extra Params
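A hedged sketch of the new KV-cache quantization knobs. `GGML_TYPE_Q8_0` is one of the ggml type constants exposed by the `llama_cpp` bindings; whether a quantized type is actually accepted for the V cache depends on the underlying llama.cpp build:

```python
import llama_cpp
from llama_cpp import Llama

# Sketch only: quantize the KV cache to 8-bit to reduce memory versus the
# default f16. The path is a placeholder.
llm = Llama(
    model_path="./models/model.gguf",  # placeholder
    type_k=llama_cpp.GGML_TYPE_Q8_0,  # K cache quantized to q8_0
    type_v=llama_cpp.GGML_TYPE_Q8_0,  # V cache; support varies by llama.cpp build
)
```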
@@ -151,6 +155,7 @@ def __init__(
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+            pooling_type: Pooling type, from `enum llama_pooling_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -170,6 +175,8 @@ def __init__(
             draft_model: Optional draft model to use for speculative decoding.
             tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
             verbose: Print verbose output to stderr.
+            type_k: KV cache data type for K (default: f16)
+            type_v: KV cache data type for V (default: f16)

         Raises:
             ValueError: If the model path does not exist.
@@ -271,6 +278,7 @@ def __init__(
             if rope_scaling_type is not None
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
+        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
@@ -293,9 +301,13 @@ def __init__(
         self.context_params.logits_all = (
             logits_all if draft_model is None else True
         )  # Must be set to True for speculative decoding
-        self.context_params.embedding = embedding
+        self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-
+        # KV cache quantization
+        if type_k is not None:
+            self.context_params.type_k = type_k
+        if type_v is not None:
+            self.context_params.type_v = type_v
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size

@@ -787,7 +799,7 @@ def embed(
         n_embd = self.n_embd()
         n_batch = self.n_batch

-        if self.context_params.embedding == False:
+        if self.context_params.embeddings == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
             )
@@ -814,9 +826,12 @@ def decode_batch(n_seq: int):

             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(
+                ptr = llama_cpp.llama_get_embeddings_seq(
                     self._ctx.ctx, i
-                )[:n_embd]
+                )
+                if not ptr:
+                    raise RuntimeError("Failed to get embeddings from sequence; pooling type is not set")
+                embedding: List[float] = ptr[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
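The switch from `llama_get_embeddings_ith` to `llama_get_embeddings_seq` means `embed()` now fetches one pooled vector per input sequence rather than per-token embeddings, which is why a pooling type must be set on the context. A usage sketch, continuing the hypothetical `llm` from the earlier pooling example:

```python
# Each input string yields one pooled, L2-normalized vector.
texts = ["first document", "second document"]
vectors = llm.embed(texts, normalize=True)  # llm: embedding model from the earlier sketch
assert len(vectors) == len(texts)
```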
@@ -1647,6 +1662,7 @@ def create_chat_completion(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
             seed=seed,
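The new line forwards the chat-level `logprobs`/`top_logprobs` pair into the underlying completion call. A hedged usage sketch; the parameter names follow the OpenAI-style chat API this method mirrors:

```python
# Sketch: request per-token log probabilities through the chat API.
response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Name a color."}],
    logprobs=True,    # enable logprobs in the response
    top_logprobs=3,   # alternatives per token, forwarded as shown above
    max_tokens=8,
)
```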
@@ -1717,6 +1733,7 @@ def __getstate__(self):
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
+            pooling_type=self.context_params.pooling_type,
             rope_freq_base=self.context_params.rope_freq_base,
             rope_freq_scale=self.context_params.rope_freq_scale,
             yarn_ext_factor=self.context_params.yarn_ext_factor,
@@ -1725,7 +1742,8 @@ def __getstate__(self):
             yarn_beta_slow=self.context_params.yarn_beta_slow,
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
             logits_all=self.context_params.logits_all,
-            embedding=self.context_params.embedding,
+            embedding=self.context_params.embeddings,
+            offload_kqv=self.context_params.offload_kqv,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
@@ -1737,51 +1755,17 @@ def __getstate__(self):
             # Chat Format Params
             chat_format=self.chat_format,
             chat_handler=self.chat_handler,
+            # Speculative Decoding
+            draft_model=self.draft_model,
+            # KV cache quantization
+            type_k=self.context_params.type_k,
+            type_v=self.context_params.type_v,
             # Misc
             verbose=self.verbose,
         )

     def __setstate__(self, state):
-        self.__init__(
-            model_path=state["model_path"],
-            # Model Params
-            n_gpu_layers=state["n_gpu_layers"],
-            split_mode=state["split_mode"],
-            main_gpu=state["main_gpu"],
-            tensor_split=state["tensor_split"],
-            vocab_only=state["vocab_only"],
-            use_mmap=state["use_mmap"],
-            use_mlock=state["use_mlock"],
-            kv_overrides=state["kv_overrides"],
-            # Context Params
-            seed=state["seed"],
-            n_ctx=state["n_ctx"],
-            n_batch=state["n_batch"],
-            n_threads=state["n_threads"],
-            n_threads_batch=state["n_threads_batch"],
-            rope_freq_base=state["rope_freq_base"],
-            rope_freq_scale=state["rope_freq_scale"],
-            rope_scaling_type=state["rope_scaling_type"],
-            yarn_ext_factor=state["yarn_ext_factor"],
-            yarn_attn_factor=state["yarn_attn_factor"],
-            yarn_beta_fast=state["yarn_beta_fast"],
-            yarn_beta_slow=state["yarn_beta_slow"],
-            yarn_orig_ctx=state["yarn_orig_ctx"],
-            logits_all=state["logits_all"],
-            embedding=state["embedding"],
-            # Sampling Params
-            last_n_tokens_size=state["last_n_tokens_size"],
-            # LoRA Params
-            lora_base=state["lora_base"],
-            lora_path=state["lora_path"],
-            # Backend Params
-            numa=state["numa"],
-            # Chat Format Params
-            chat_format=state["chat_format"],
-            chat_handler=state["chat_handler"],
-            # Misc
-            verbose=state["verbose"],
-        )
+        self.__init__(**state)

     def save_state(self) -> LlamaState:
         assert self._ctx.ctx is not None
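Since `__getstate__` now captures every constructor argument (including the new fields) and `__setstate__` simply re-invokes `__init__`, pickling a model serializes its configuration rather than its weights; unpickling reloads the model from `model_path`. A small sketch:

```python
import pickle

payload = pickle.dumps(llm)       # serializes constructor kwargs only
restored = pickle.loads(payload)  # re-runs __init__(**state), reloading from disk
```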