coderonion/llama-cpp-python · commit d318cc8
fix: Set default pooling_type to mean, check for null pointer.

1 parent: dd0ee56
2 files changed: +8 -3 lines

llama_cpp/llama.py (8 additions, 2 deletions)
@@ -79,6 +79,7 @@ def __init__(
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_MEAN,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
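For context, LLAMA_POOLING_TYPE_MEAN is one of the constants llama_cpp mirrors from `enum llama_pooling_type` in llama.h. The values below are assumptions about the enum around this revision, not part of this diff; verify against your installed llama_cpp:

# Assumed values, mirroring enum llama_pooling_type in llama.h at this point in time:
# LLAMA_POOLING_TYPE_UNSPECIFIED = -1  # defer to the model's own default
# LLAMA_POOLING_TYPE_NONE        = 0   # per-token embeddings, no sequence pooling
# LLAMA_POOLING_TYPE_MEAN        = 1   # mean-pool token embeddings (new default here)
# LLAMA_POOLING_TYPE_CLS         = 2   # take the CLS token's embedding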
@@ -151,6 +152,7 @@ def __init__(
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+            pooling_type: Pooling type, from `enum llama_pooling_type`.
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -271,6 +273,7 @@ def __init__(
             if rope_scaling_type is not None
             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
         )
+        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
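Together, the three hunks above thread the new `pooling_type` argument from the constructor into `llama_cpp.llama_context_params`, so sequence-level mean pooling is active by default. A minimal usage sketch; the model path is a placeholder and `embedding=True` is the pre-existing flag for embedding extraction, not something introduced by this diff:

from llama_cpp import Llama
import llama_cpp

llm = Llama(
    model_path="./models/example.gguf",  # placeholder: any embedding-capable GGUF model
    embedding=True,                      # enable embedding extraction
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_MEAN,  # explicit here, though now the default
)

# One pooled vector per input string.
vectors = llm.embed(["hello world", "goodbye world"])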
@@ -814,9 +817,12 @@ def decode_batch(n_seq: int):
 
             # store embeddings
             for i in range(n_seq):
-                embedding: List[float] = llama_cpp.llama_get_embeddings_seq(
+                ptr = llama_cpp.llama_get_embeddings_seq(
                     self._ctx.ctx, i
-                )[:n_embd]
+                )
+                if not ptr:
+                    raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
+                embedding: List[float] = ptr[:n_embd]
                 if normalize:
                     norm = float(np.linalg.norm(embedding))
                     embedding = [v / norm for v in embedding]
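The motivation for the pointer check: llama_get_embeddings_seq returns a C float pointer that is NULL when the context produced no sequence-level pooled embedding (e.g. with pooling type NONE), and slicing a NULL ctypes pointer fails with an opaque error rather than a descriptive one. A small standalone sketch of the ctypes behavior being guarded against; the exact pointer type is an assumption about the binding:

import ctypes

FloatPtr = ctypes.POINTER(ctypes.c_float)  # assumed return type of llama_get_embeddings_seq
null_ptr = FloatPtr()                      # a NULL float*

print(bool(null_ptr))  # False, so `if not ptr:` detects the failure cheaply
# Slicing instead, e.g. null_ptr[:8], raises "ValueError: NULL pointer access" deep
# inside ctypes, which is why the code above raises a descriptive RuntimeError first.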

llama_cpp/llama_cpp.py (0 additions, 1 deletion)
@@ -579,7 +579,6 @@ class llama_model_params(ctypes.Structure):
579579
# bool embeddings; // if true, extract embeddings (together with logits)
580580
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
581581

582-
583582
# // Abort callback
584583
# // if it returns true, execution of llama_decode() will be aborted
585584
# // currently works only with CPU execution
