@@ -470,6 +470,7 @@ class llama_model_params(Structure):
470
470
# bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
471
471
# bool embedding; // embedding mode only
472
472
# bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
473
+ # bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
473
474
# };
474
475
class llama_context_params (Structure ):
475
476
"""Parameters for llama_context
@@ -496,6 +497,7 @@ class llama_context_params(Structure):
496
497
logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
497
498
embedding (bool): embedding mode only
498
499
offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
500
+ do_pooling (bool): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
499
501
"""
500
502
501
503
_fields_ = [
@@ -520,6 +522,7 @@ class llama_context_params(Structure):
520
522
("logits_all" , c_bool ),
521
523
("embedding" , c_bool ),
522
524
("offload_kqv" , c_bool ),
525
+ ("do_pooling" , c_bool ),
523
526
]
524
527
525
528
@@ -1699,6 +1702,21 @@ def llama_get_embeddings(
1699
1702
_lib .llama_get_embeddings .restype = c_float_p
1700
1703
1701
1704
1705
+ # // Get the embeddings for the ith sequence
1706
+ # // llama_get_embeddings(ctx) + i*n_embd
1707
+ # LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
1708
def llama_get_embeddings_ith(
    ctx: llama_context_p, i: Union[c_int32, int]
):  # type: (...) -> Array[float] # type: ignore
    """Get the embeddings for the ith sequence.

    Thin wrapper that forwards directly to the C function
    ``llama_get_embeddings_ith``. Per the upstream header comment, the
    result corresponds to ``llama_get_embeddings(ctx) + i*n_embd``.

    Args:
        ctx: The llama context handle, passed through unchanged.
        i: Index of the sequence (``int32_t`` in the C declaration).

    Returns:
        Whatever the underlying library call returns for ``(ctx, i)``;
        the C declaration gives the result type as ``float *``.
    """
    ith_embeddings = _lib.llama_get_embeddings_ith(ctx, i)
    return ith_embeddings
1714
+
1715
+
1716
# ctypes prototype mirroring the C declaration:
#   float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
_lib.llama_get_embeddings_ith.restype = c_float_p
_lib.llama_get_embeddings_ith.argtypes = [
    llama_context_p,
    c_int32,
]
1718
+
1719
+
1702
1720
# //
1703
1721
# // Vocab
1704
1722
# //
0 commit comments