
llama : add reranking support #9510


Merged · 25 commits · Sep 28, 2024
Changes from 1 commit

Commits (25)
3453e62
py : add XLMRobertaForSequenceClassification [no ci]
ggerganov Sep 16, 2024
77723ed
py : fix scalar-tensor conversion [no ci]
ggerganov Sep 17, 2024
49f90de
py : fix position embeddings chop [no ci]
ggerganov Sep 17, 2024
dc0cdd8
llama : read new cls tensors [no ci]
ggerganov Sep 17, 2024
d0a7bf9
llama : add classification head (wip) [no ci]
ggerganov Sep 18, 2024
125a067
llama : add "rank" pooling type
ggerganov Sep 19, 2024
6235c62
server : add rerank endpoint
ggerganov Sep 19, 2024
6916ed1
llama : avoid ggml_repeat during classification
ggerganov Sep 23, 2024
62a45d1
rerank : cleanup + comments
ggerganov Sep 25, 2024
7bde9a0
server : accept /rerank endpoint in addition to /v1/rerank [no ci]
ggerganov Sep 25, 2024
c62a39d
embedding : parse special tokens
ggerganov Sep 25, 2024
866c011
jina : support v1 reranker
ggerganov Sep 25, 2024
84f56f3
vocab : minor style
ggerganov Sep 25, 2024
00b3376
server : initiate tests for later
ggerganov Sep 26, 2024
877a04c
server : add docs
ggerganov Sep 26, 2024
4d45775
llama : add comment [no ci]
ggerganov Sep 26, 2024
ca99a6c
llama : fix uninitialized tensors
ggerganov Sep 26, 2024
f19554f
ci : add rerank tests
ggerganov Sep 26, 2024
f27dd69
add reranking test
ngxson Sep 26, 2024
1ae8376
change test data
ngxson Sep 26, 2024
84b0af8
Update examples/server/server.cpp
ggerganov Sep 27, 2024
0d6f6a7
add `--reranking` argument
ngxson Sep 27, 2024
a4ac45f
update server docs
ngxson Sep 27, 2024
39167b6
llama : fix comment [no ci]
ggerganov Sep 28, 2024
aeac876
Merge branch 'master' into gg/rerank
ggerganov Sep 28, 2024
rerank : cleanup + comments
ggerganov committed Sep 25, 2024
commit 62a45d12ef4b42d5d5c0172e19ef41b17ba71a09
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
@@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
}
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) {
LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
}
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
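For illustration (not part of the diff): once the per-prompt scores are printed as above, ranking the candidate documents is just a sort by descending score. A minimal sketch, assuming the scores have already been copied into a std::vector<float>:

#include <algorithm>
#include <numeric>
#include <vector>

// sketch: order document indices by descending rerank score
static std::vector<int> rank_documents(const std::vector<float> & scores) {
    std::vector<int> order(scores.size());
    std::iota(order.begin(), order.end(), 0);       // 0, 1, 2, ...
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return scores[a] > scores[b]; });
    return order;                                    // order[0] is the most relevant document
}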
16 changes: 11 additions & 5 deletions examples/server/server.cpp
@@ -1419,7 +1419,7 @@ struct server_context {
queue_results.send(res);
}

void send_rank(const server_slot & slot, const llama_batch & batch) {
void send_rerank(const server_slot & slot, const llama_batch & batch) {
server_task_result res;
res.id = slot.id_task;
res.error = false;
@@ -1440,19 +1440,19 @@ struct server_context {

res.data = json {
{"index", slot.index},
{"rank", -1e6},
{"score", -1e6},
};

continue;
}

res.data = json {
{"index", slot.index},
{"rank", embd[0]},
{"score", embd[0]},
};
}

SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());

queue_results.send(res);
}
@@ -1493,6 +1493,9 @@ struct server_context {
else if (prompt.is_array()) {
std::vector<json> prompts = prompt;
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// prompts[0] is the question
// the rest are the answers/documents
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
for (size_t i = 1; i < prompts.size(); i++) {
json qd;
qd.push_back(prompts[0]);
@@ -1501,6 +1504,7 @@ struct server_context {
create_task(data, true, qd);
}
} else {
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
for (size_t i = 0; i < prompts.size(); i++) {
const auto & e = prompts[i];
if (e.is_string() || json_is_array_of_numbers(e)) {
@@ -1965,6 +1969,7 @@ struct server_context {
// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
// TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
@@ -2133,6 +2138,7 @@ struct server_context {
slot.n_prompt_tokens_processed = 0;
}

// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
@@ -2318,7 +2324,7 @@ struct server_context {
}

if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
send_rank(slot, batch_view);
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
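For illustration (not part of the diff), a minimal sketch of the task splitting described in the comments above: prompts[0] is the question and every remaining entry is a candidate document, so each (question, document) pair becomes its own task. The snippet uses nlohmann::json (the library behind the server's json alias) standalone; the prompt contents and the printf are stand-ins for create_task() and the real data.

#include <nlohmann/json.hpp>
#include <cstdio>

using json = nlohmann::json;

int main() {
    // illustrative input: element 0 is the question, the rest are candidate documents
    json prompts = json::array({
        "what is a panda?",
        "hi",
        "the giant panda is a bear species endemic to China",
    });

    for (size_t i = 1; i < prompts.size(); i++) {
        json qd = json::array();
        qd.push_back(prompts[0]); // the question is repeated in every task
        qd.push_back(prompts[i]); // one candidate document per task
        // in the server, each pair is handed to create_task() and scored independently
        printf("task %zu: %s\n", i - 1, qd.dump().c_str());
    }
    return 0;
}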
2 changes: 1 addition & 1 deletion examples/server/utils.hpp
@@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) {
for (const auto & rank : ranks) {
data.push_back(json{
{"index", i++},
{"relevance_score", json_value(rank, "rank", 0.0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});
}

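To show the effect of the rename above (illustrative, not part of the diff): each per-task result now carries its value under "score", and format_response_rerank re-exposes it to clients as "relevance_score". A standalone sketch of that mapping, with made-up scores and nlohmann's .value() standing in for the server's json_value() helper:

#include <nlohmann/json.hpp>
#include <cstdio>

using json = nlohmann::json;

int main() {
    // stand-in for the per-document results collected by the server ("score" field as in the diff)
    json ranks = json::array({
        {{"index", 0}, {"score", -6.55}},
        {{"index", 1}, {"score",  4.87}},
    });

    json data = json::array();
    int i = 0;
    for (const auto & rank : ranks) {
        data.push_back(json{
            {"index", i++},
            {"relevance_score", rank.value("score", 0.0)}, // nlohmann equivalent of json_value(rank, "score", 0.0)
        });
    }
    printf("%s\n", data.dump(2).c_str());
    return 0;
}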
11 changes: 6 additions & 5 deletions include/llama.h
@@ -192,7 +192,7 @@ extern "C" {
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
LLAMA_POOLING_TYPE_RANK = 4,
LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
};

enum llama_attention_type {
@@ -202,9 +202,9 @@ extern "C" {
};

enum llama_split_mode {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};

// TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
@@ -872,7 +872,8 @@ extern "C" {

// Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
// shape: [n_embd] (1-dimensional)
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

//
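For illustration (not part of the diff), a minimal sketch of a caller reading the score under the contract documented above. It assumes a llama_context created with pooling_type LLAMA_POOLING_TYPE_RANK in which sequence 0 already holds a decoded query/document pair:

#include "llama.h"
#include <cstdio>

// sketch: read the single-float rerank score for sequence 0
// assumes `ctx` uses LLAMA_POOLING_TYPE_RANK and llama_decode() has already run
static void print_rerank_score(struct llama_context * ctx) {
    const float * out = llama_get_embeddings_seq(ctx, 0);
    if (out == NULL) {
        fprintf(stderr, "no pooled output for sequence 0\n");
        return;
    }
    printf("relevance score: %8.3f\n", out[0]); // RANK pooling: a single float per sequence
}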
10 changes: 8 additions & 2 deletions src/llama.cpp
@@ -17009,7 +17009,7 @@ static int llama_decode_internal(
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rank score - a single float per sequence
// extract the rerank score - a single float per sequence
auto & embd_seq_out = lctx.embd_seq;

for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
@@ -17211,7 +17211,6 @@ static int llama_encode_internal(
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
case LLAMA_POOLING_TYPE_RANK:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
@@ -17228,6 +17227,13 @@
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
// wait for an encoder model that requires this pooling type in order to test it
// https://github.com/ggerganov/llama.cpp/pull/9510
GGML_ABORT("RANK pooling not implemented yet");
}
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");