@@ -79,6 +79,7 @@ def __init__(
7979 n_threads : Optional [int ] = None ,
8080 n_threads_batch : Optional [int ] = None ,
8181 rope_scaling_type : Optional [int ] = llama_cpp .LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED ,
82+ pooling_type : int = llama_cpp .LLAMA_POOLING_TYPE_UNSPECIFIED ,
8283 rope_freq_base : float = 0.0 ,
8384 rope_freq_scale : float = 0.0 ,
8485 yarn_ext_factor : float = - 1.0 ,
@@ -104,6 +105,9 @@ def __init__(
104105 draft_model : Optional [LlamaDraftModel ] = None ,
105106 # Tokenizer Override
106107 tokenizer : Optional [BaseLlamaTokenizer ] = None ,
108+ # KV cache quantization
109+ type_k : Optional [int ] = None ,
110+ type_v : Optional [int ] = None ,
107111 # Misc
108112 verbose : bool = True ,
109113 # Extra Params
@@ -151,6 +155,7 @@ def __init__(
151155 n_threads: Number of threads to use for generation
152156 n_threads_batch: Number of threads to use for batch processing
153157 rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
158+ pooling_type: Pooling type, from `enum llama_pooling_type`.
154159 rope_freq_base: RoPE base frequency, 0 = from model
155160 rope_freq_scale: RoPE frequency scaling factor, 0 = from model
156161 yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -170,6 +175,8 @@ def __init__(
170175 draft_model: Optional draft model to use for speculative decoding.
171176 tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
172177 verbose: Print verbose output to stderr.
178+ type_k: KV cache data type for K (default: f16)
179+ type_v: KV cache data type for V (default: f16)
173180
174181 Raises:
175182 ValueError: If the model path does not exist.
@@ -271,6 +278,7 @@ def __init__(
271278 if rope_scaling_type is not None
272279 else llama_cpp .LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
273280 )
281+ self .context_params .pooling_type = pooling_type
274282 self .context_params .rope_freq_base = (
275283 rope_freq_base if rope_freq_base != 0.0 else 0
276284 )
@@ -293,9 +301,13 @@ def __init__(
293301 self .context_params .logits_all = (
294302 logits_all if draft_model is None else True
295303 ) # Must be set to True for speculative decoding
296- self .context_params .embedding = embedding
304+ self .context_params .embeddings = embedding # TODO: Rename to embeddings
297305 self .context_params .offload_kqv = offload_kqv
298-
306+ # KV cache quantization
307+ if type_k is not None :
308+ self .context_params .type_k = type_k
309+ if type_v is not None :
310+ self .context_params .type_v = type_v
299311 # Sampling Params
300312 self .last_n_tokens_size = last_n_tokens_size
301313
@@ -787,7 +799,7 @@ def embed(
787799 n_embd = self .n_embd ()
788800 n_batch = self .n_batch
789801
790- if self .context_params .embedding == False :
802+ if self .context_params .embeddings == False :
791803 raise RuntimeError (
792804 "Llama model must be created with embedding=True to call this method"
793805 )
@@ -814,9 +826,12 @@ def decode_batch(n_seq: int):
814826
815827 # store embeddings
816828 for i in range (n_seq ):
817- embedding : List [ float ] = llama_cpp .llama_get_embeddings_ith (
829+ ptr = llama_cpp .llama_get_embeddings_seq (
818830 self ._ctx .ctx , i
819- )[:n_embd ]
831+ )
832+ if not ptr :
833+ raise RuntimeError ("Failed to get embeddings from sequence pooling type is not set" )
834+ embedding : List [float ] = ptr [:n_embd ]
820835 if normalize :
821836 norm = float (np .linalg .norm (embedding ))
822837 embedding = [v / norm for v in embedding ]
@@ -1647,6 +1662,7 @@ def create_chat_completion(
16471662 top_k = top_k ,
16481663 min_p = min_p ,
16491664 typical_p = typical_p ,
1665+ logprobs = top_logprobs if logprobs else None ,
16501666 stream = stream ,
16511667 stop = stop ,
16521668 seed = seed ,
@@ -1717,6 +1733,7 @@ def __getstate__(self):
17171733 n_threads = self .context_params .n_threads ,
17181734 n_threads_batch = self .context_params .n_threads_batch ,
17191735 rope_scaling_type = self .context_params .rope_scaling_type ,
1736+ pooling_type = self .context_params .pooling_type ,
17201737 rope_freq_base = self .context_params .rope_freq_base ,
17211738 rope_freq_scale = self .context_params .rope_freq_scale ,
17221739 yarn_ext_factor = self .context_params .yarn_ext_factor ,
@@ -1725,7 +1742,8 @@ def __getstate__(self):
17251742 yarn_beta_slow = self .context_params .yarn_beta_slow ,
17261743 yarn_orig_ctx = self .context_params .yarn_orig_ctx ,
17271744 logits_all = self .context_params .logits_all ,
1728- embedding = self .context_params .embedding ,
1745+ embedding = self .context_params .embeddings ,
1746+ offload_kqv = self .context_params .offload_kqv ,
17291747 # Sampling Params
17301748 last_n_tokens_size = self .last_n_tokens_size ,
17311749 # LoRA Params
@@ -1737,51 +1755,17 @@ def __getstate__(self):
17371755 # Chat Format Params
17381756 chat_format = self .chat_format ,
17391757 chat_handler = self .chat_handler ,
1758+ # Speculative Decoding
1759+ draft_model = self .draft_model ,
1760+ # KV cache quantization
1761+ type_k = self .context_params .type_k ,
1762+ type_v = self .context_params .type_v ,
17401763 # Misc
17411764 verbose = self .verbose ,
17421765 )
17431766
17441767 def __setstate__ (self , state ):
1745- self .__init__ (
1746- model_path = state ["model_path" ],
1747- # Model Params
1748- n_gpu_layers = state ["n_gpu_layers" ],
1749- split_mode = state ["split_mode" ],
1750- main_gpu = state ["main_gpu" ],
1751- tensor_split = state ["tensor_split" ],
1752- vocab_only = state ["vocab_only" ],
1753- use_mmap = state ["use_mmap" ],
1754- use_mlock = state ["use_mlock" ],
1755- kv_overrides = state ["kv_overrides" ],
1756- # Context Params
1757- seed = state ["seed" ],
1758- n_ctx = state ["n_ctx" ],
1759- n_batch = state ["n_batch" ],
1760- n_threads = state ["n_threads" ],
1761- n_threads_batch = state ["n_threads_batch" ],
1762- rope_freq_base = state ["rope_freq_base" ],
1763- rope_freq_scale = state ["rope_freq_scale" ],
1764- rope_scaling_type = state ["rope_scaling_type" ],
1765- yarn_ext_factor = state ["yarn_ext_factor" ],
1766- yarn_attn_factor = state ["yarn_attn_factor" ],
1767- yarn_beta_fast = state ["yarn_beta_fast" ],
1768- yarn_beta_slow = state ["yarn_beta_slow" ],
1769- yarn_orig_ctx = state ["yarn_orig_ctx" ],
1770- logits_all = state ["logits_all" ],
1771- embedding = state ["embedding" ],
1772- # Sampling Params
1773- last_n_tokens_size = state ["last_n_tokens_size" ],
1774- # LoRA Params
1775- lora_base = state ["lora_base" ],
1776- lora_path = state ["lora_path" ],
1777- # Backend Params
1778- numa = state ["numa" ],
1779- # Chat Format Params
1780- chat_format = state ["chat_format" ],
1781- chat_handler = state ["chat_handler" ],
1782- # Misc
1783- verbose = state ["verbose" ],
1784- )
1768+ self .__init__ (** state )
17851769
17861770 def save_state (self ) -> LlamaState :
17871771 assert self ._ctx .ctx is not None
0 commit comments