llama : rework embeddings logic by ggerganov · Pull Request #14208 · ggml-org/llama.cpp · GitHub

llama : rework embeddings logic #14208


Merged · merged 6 commits on Jun 16, 2025
Changes from 1 commit
llama : rework embeddings logic
ggml-ci
ggerganov committed Jun 16, 2025
commit e8ddfa30e65627557a6cc41bf277deffa4518459
8 changes: 5 additions & 3 deletions examples/gritlm/gritlm.cpp
@@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
common_batch_add(batch, inputs[j], j, { 0 }, true);
}

// clear previous kv_cache values (irrelevant for embeddings)
llama_memory_clear(llama_get_memory(ctx), true);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);

// run model
@@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_token eos_token = llama_vocab_eos(vocab);

llama_memory_clear(llama_get_memory(ctx), true);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
llama_model_params mparams = common_model_params_to_llama(params);
llama_context_params cparams = common_context_params_to_llama(params);

cparams.embeddings = true;

llama_backend_init();

llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}

llama_set_embeddings(ctx, false);

// ### Generation ###
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
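For context, the example now enables embeddings once at context creation (`cparams.embeddings = true`) and only calls `llama_set_embeddings(ctx, false)` when switching to generation. A minimal sketch of the resulting mode switching, using only calls that appear in this diff (the function name and the elided batch handling are illustrative, not part of the PR):

```cpp
#include "llama.h"

// Sketch: a context created with cparams.embeddings = true is switched
// between the embedding phase and the generation phase.
static void embed_then_generate(llama_context * ctx) {
    // embedding phase: previous KV state is irrelevant, attention is non-causal
    llama_memory_clear(llama_get_memory(ctx), true);
    llama_set_causal_attn(ctx, false);
    // ... build a batch with all tokens marked as outputs and llama_decode() it ...

    // generation phase: turn embeddings off and restore causal attention
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);
    // ... llama_decode() now produces logits for sampling ...
}
```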
13 changes: 8 additions & 5 deletions include/llama.h
Expand Up @@ -254,16 +254,19 @@ extern "C" {
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
// (if set to NULL:
// - if embeddings: all tokens are output
// - if not: only the last token is output
// )
//
typedef struct llama_batch {
int32_t n_tokens;

llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
} llama_batch;
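As a caller-side illustration of the new default described in the comment above (a hedged sketch; `ctx` and the `tokens` vector are assumed to exist):

```cpp
// llama_batch_get_one() leaves batch.logits == NULL, so the default applies:
// - embeddings context (cparams.embeddings == true): every token is an output
// - otherwise: only the last token is an output
llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
llama_decode(ctx, batch);
```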

@@ -961,8 +964,8 @@ extern "C" {
// Get the number of threads used for prompt and batch processing (multiple token).
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);

// Set whether the model is in embeddings mode or not
// If true, embeddings will be returned but logits will not
// Set whether the context outputs embeddings or not
// Note: set to true only if the context was created with llama_context_params.embeddings = true
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);

// Set whether to use causal attention or not
30 changes: 26 additions & 4 deletions src/llama-batch.cpp
@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
bool llama_batch_allocr::init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory) {
const llama_memory_i * memory,
bool embd_all) {
clear();

batch = batch_inp;
@@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
}

if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
if (embd_all) {
// return the output for all tokens
output.resize(batch.n_tokens, true);
} else {
// return the output only for the last token
output.resize(batch.n_tokens, false);
output[output.size() - 1] = true;
}

batch.logits = output.data();
} else if (embd_all) {
bool warn = false;

for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
warn = true;
}
}

if (warn) {
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);

output.resize(batch.n_tokens, true);
batch.logits = output.data();
}
}

//
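To illustrate the override branch added above, a hypothetical caller-side sketch (the `tokens` buffer and `n_tokens` count are assumptions, not part of this change):

```cpp
// Only the last token is marked as an output here ...
llama_batch batch = llama_batch_init(n_tokens, /*embd*/ 0, /*n_seq_max*/ 1);
for (int32_t i = 0; i < n_tokens; ++i) {
    common_batch_add(batch, tokens[i], i, { 0 }, /*output*/ i == n_tokens - 1);
}

// ... but if the context computes embeddings for all tokens, the allocator
// logs the warning above and overrides the mask so every token is an output.
llama_decode(ctx, batch);
llama_batch_free(batch);
```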
3 changes: 2 additions & 1 deletion src/llama-batch.h
@@ -88,7 +88,8 @@ class llama_batch_allocr {
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory);
const llama_memory_i * memory,
bool embd_all);

const llama_batch & get_batch() const;

32 changes: 17 additions & 15 deletions src/llama-context.cpp
@@ -81,6 +81,12 @@ llama_context::llama_context(
}
}

if (!cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
LLAMA_LOG_WARN("%s: pooling_type is set to %d but embeddings is set to false - disabling pooling\n", __func__, cparams.pooling_type);

cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
}

if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
cparams.causal_attn = hparams.causal_attn;
} else {
@@ -728,7 +734,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
}

// note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@@ -894,7 +900,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}

if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
// when computing embeddings, all tokens are output
const bool embd_all = cparams.embeddings;

if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@@ -911,12 +920,9 @@ int llama_context::decode(const llama_batch & batch_inp)

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

const uint32_t n_outputs_all = batch_allocr->get_n_outputs();

if (embd_pooled) {
if (embd_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +951,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
llama_memory_state_ptr mstate;

while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
if (!mstate) {
return -2;
}
@@ -1058,7 +1064,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}

auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;

if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1228,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;

// TODO: use a per-batch flag for logits presence instead
bool has_logits = !cparams.embeddings;
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
bool has_logits = true;
bool has_embd = cparams.embeddings;

// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2049,11 @@ void llama_context::opt_epoch_iter(

n_queued_tokens += n_tokens_all;

// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

embd_seq.clear();

uint32_t n_outputs_all = n_tokens_all;

auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
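A sketch of the configuration that the reworked decode path expects, assuming a loaded `model` (the mean pooling choice is only an example):

```cpp
// Pooled sequence embeddings require embeddings mode; otherwise the new
// constructor check above disables pooling with a warning.
llama_context_params cparams = llama_context_default_params();
cparams.embeddings   = true;                     // decode outputs embeddings for all tokens
cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // pool them per sequence

llama_context * ctx = llama_init_from_model(model, cparams);

// after llama_decode(), the pooled embedding of sequence 0 can be read with:
const float * embd = llama_get_embeddings_seq(ctx, 0);
```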
8 changes: 3 additions & 5 deletions src/llama-kv-cache-recurrent.cpp
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}

llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);

llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

std::vector<llama_ubatch> ubatches;

while (sbatch.n_tokens > 0) {
llama_ubatch ubatch;

if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch);
} else {
ubatch = sbatch.split_equal(n_ubatch);
2 changes: 1 addition & 1 deletion src/llama-kv-cache-recurrent.h
@@ -32,7 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;

llama_memory_state_ptr init_full() override;

4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);

// first try simple split
do {
2 changes: 1 addition & 1 deletion src/llama-kv-cache-unified-iswa.h
@@ -34,7 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;

llama_memory_state_ptr init_full() override;

4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified.cpp
@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) {
GGML_UNUSED(embd_pooled);
bool embd_all) {
GGML_UNUSED(embd_all);

do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
2 changes: 1 addition & 1 deletion src/llama-kv-cache-unified.h
@@ -59,7 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;

llama_memory_state_ptr init_full() override;

3 changes: 2 additions & 1 deletion src/llama-memory.h
@@ -70,10 +70,11 @@ struct llama_memory_i {
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
// TODO: remove embd_all argument
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) = 0;
bool embd_all) = 0;

// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
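For reference, the call sites updated in this PR drive this interface roughly like so (a condensed sketch of the loop in `llama_context::decode`):

```cpp
// split the batch into ubatches and verify the memory can host them
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    // the batch cannot be processed with the current cache state
    return -2;
}
```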