Update llama.cpp · abetlen/llama-cpp-python@f7cdf78 · GitHub

Commit f7cdf78 (parent: 68fb71b)

Update llama.cpp

2 files changed: 19 additions, 1 deletion

llama_cpp/llama_cpp.py (18 additions, 0 deletions)
@@ -470,6 +470,7 @@ class llama_model_params(Structure):
 # bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
 # bool embedding;   // embedding mode only
 # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+# bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
 # };
 class llama_context_params(Structure):
     """Parameters for llama_context
@@ -496,6 +497,7 @@ class llama_context_params(Structure):
         logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         embedding (bool): embedding mode only
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
+        do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
     """

     _fields_ = [
@@ -520,6 +522,7 @@ class llama_context_params(Structure):
         ("logits_all", c_bool),
         ("embedding", c_bool),
         ("offload_kqv", c_bool),
+        ("do_pooling", c_bool),
     ]


@@ -1699,6 +1702,21 @@ def llama_get_embeddings(
 _lib.llama_get_embeddings.restype = c_float_p


+# // Get the embeddings for the ith sequence
+# // llama_get_embeddings(ctx) + i*n_embd
+# LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+def llama_get_embeddings_ith(
+    ctx: llama_context_p, i: Union[c_int32, int]
+):  # type: (...) -> Array[float] # type: ignore
+    """Get the embeddings for the ith sequence
+    llama_get_embeddings(ctx) + i*n_embd"""
+    return _lib.llama_get_embeddings_ith(ctx, i)
+
+
+_lib.llama_get_embeddings_ith.argtypes = [llama_context_p, c_int32]
+_lib.llama_get_embeddings_ith.restype = c_float_p
+
+
 # //
 # // Vocab
 # //
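
For context, a minimal sketch (not part of the commit) of how the new binding might be used through the low-level API. It assumes a `model` already loaded via llama_load_model_from_file and a decoded batch containing a hypothetical `n_seqs` sequences; the remaining names are taken from llama_cpp/llama_cpp.py at this revision.

import llama_cpp

params = llama_cpp.llama_context_default_params()
params.embedding = True   # embedding mode only
params.do_pooling = True  # pool (sum) embedding results by sequence id

ctx = llama_cpp.llama_new_context_with_model(model, params)  # model: assumed loaded

# ... build and llama_decode() a batch containing n_seqs sequences here ...

n_embd = llama_cpp.llama_n_embd(model)  # width of one embedding vector

for i in range(n_seqs):
    # Per the docstring, this returns llama_get_embeddings(ctx) + i*n_embd,
    # i.e. a float* to the pooled embedding of the i-th sequence.
    ptr = llama_cpp.llama_get_embeddings_ith(ctx, i)
    embedding = [ptr[j] for j in range(n_embd)]

Since the result is just the base pointer offset by i*n_embd, the per-sequence embeddings are laid out contiguously in the buffer returned by llama_get_embeddings(ctx); llama_get_embeddings_ith only spares the caller that pointer arithmetic.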

vendor/llama.cpp (1 addition, 1 deletion: submodule commit updated)
