llama : support Jamba hybrid Transformer-Mamba models by compilade · Pull Request #7531 · ggml-org/llama.cpp · GitHub

llama : support Jamba hybrid Transformer-Mamba models #7531


Merged
merged 61 commits on Jul 9, 2025

Changes from 1 commit (of 61 commits)
271104c
wip: llama : separate recurrent states from the KV cache
compilade Apr 3, 2024
8db1e4d
llama : use std::find for seq_nodes in llama_rs_cache
compilade Apr 4, 2024
0028010
llama : state checkpoints for recurrent models
compilade Apr 8, 2024
0c8b3b2
llama : correctly handle more edge cases for the rs cache
compilade Apr 9, 2024
d66849f
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 10, 2024
a09db95
llama : rename many llama_kv_cache_* functions
compilade Apr 29, 2024
c460ff1
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 29, 2024
b6fafd1
llama : remove useless return value for some llama_cache_* functions
compilade Apr 29, 2024
b7ec12e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 12, 2024
3b57b55
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 22, 2024
7e13f19
llama : rethink recurrent state cell counts
compilade May 24, 2024
cbc743e
llama : support Jamba
compilade May 24, 2024
0fd13e9
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 24, 2024
61a88a1
llama : fix BERT inference without KV cache
compilade May 25, 2024
ea2e63e
convert-hf : check for unprocessed Jamba experts
compilade May 25, 2024
fc59407
convert-hf : support Mini-Jamba conversion
compilade May 25, 2024
181dadf
llama : fix Jamba quantization sanity checks
compilade May 28, 2024
3a414b0
llama : sequence-length-aware batch splitting
compilade May 28, 2024
4e4c41e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 28, 2024
3587a94
llama : use equal-sequence-length sub-batches for recurrent models
compilade Jun 1, 2024
5d3c7b9
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 1, 2024
72eea49
llama : fix batch split output count for embeddings
compilade Jun 1, 2024
18d1c14
llama : minimize swaps when reordering logits
compilade Jun 1, 2024
61200ef
llama : fix edge case finding batch seq_id of split recurrent cell
compilade Jun 1, 2024
eb589d5
llama : avoid copies for simple batch splits
compilade Jun 2, 2024
8fb57ac
llama : use im2col and mul_mat to perform convolution for Mamba
compilade Jun 3, 2024
17f6c1e
llama : fix .base() compilation error on Windows
compilade Jun 3, 2024
fee3c1d
llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
compilade Jun 3, 2024
6840ac0
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 8, 2024
372482d
llama : rename llama_cache to llama_past
compilade Jun 8, 2024
43d8d4b
examples : replace llama_kv_cache_seq_* with llama_past_seq_*
compilade Jun 10, 2024
ff794f5
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 12, 2024
33425a7
mamba : fix non-contiguous usage of ggml_silu
compilade Jun 12, 2024
10c3c41
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 30, 2024
9b38f8b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 4, 2024
bc320ef
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 1, 2024
fcb889c
llama : session saving and reloading for hybrid models
compilade Sep 2, 2024
a03e32a
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 2, 2024
9d3f44d
convert_hf : fix Jamba conversion
compilade Sep 2, 2024
5f62db7
llama : fix mixed signedness comparison
compilade Sep 2, 2024
375de5b
llama : use unused n_embd_k_gqa in k_shift
compilade Sep 2, 2024
4bb4b22
llama : begin renaming llama_past back to llama_kv_cache
compilade Sep 14, 2024
63ac36b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 14, 2024
124c222
Merge branch 'master' into compilade/refactor-kv-cache
compilade Oct 12, 2024
8006f3b
llama : remove implicit recurrent state rollbacks
compilade Nov 25, 2024
691698e
Merge branch 'master' into compilade/refactor-kv-cache
compilade Nov 25, 2024
e3fe612
llama : partially apply clang-format style
compilade Nov 25, 2024
2bcaf64
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
908e655
convert : fix jamba conv1d shape squeezing
compilade Jul 3, 2025
4682e21
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
20f8e43
graph : add back hybrid memory graph input
compilade Jul 3, 2025
07c252f
model : add Jamba to Mamba-specific hparams printing
compilade Jul 3, 2025
f716358
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 7, 2025
b0b280e
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 8, 2025
db5ff0c
jamba : remove redundant nullptr initializations
compilade Jul 8, 2025
2f39cd7
model : remove unnecessary prefix for tensor loading constants
compilade Jul 8, 2025
f7c7a92
model : use ggml_swiglu_split for Mamba
compilade Jul 8, 2025
a60a24b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 9, 2025
7f3955a
model : make falcon-h1 use shared mamba2 layer builder
compilade Jul 9, 2025
452207f
memory : avoid referring to KV in recurrent cache logs
compilade Jul 9, 2025
4d6a179
gguf-py : avoid adding duplicate tensor mappings for Jamba
compilade Jul 9, 2025
llama : use std::find for seq_nodes in llama_rs_cache
compilade committed Apr 4, 2024
commit 8db1e4d45fb27a5e76ac55559a008a425e00fbac
llama.cpp: 153 changes (61 additions, 92 deletions)
@@ -1962,11 +1962,12 @@ struct llama_rs_seq_node {
llama_seq_id seq_id = -1;
int32_t next_cell = -1;

// needed for automatic typecasting with .find()
// needed for automatic typecasting from a llama_seq_id
llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}

bool operator<(const llama_rs_seq_node & other) const {
return seq_id < other.seq_id;
// needed for more convenient std::find
bool operator==(const llama_rs_seq_node & other) const {
return seq_id == other.seq_id;
}

bool is_tail() const {
@@ -1989,48 +1990,18 @@ struct llama_rs_cell {
// seq_ids by insertion order, to simplify updating n_cells compared to a set
std::vector<llama_rs_seq_node> seq_nodes;

llama_rs_seq_node * get_node(const llama_seq_id & id) {
for (size_t i = 0; i < seq_nodes.size(); ++i) {
if (seq_nodes[i].seq_id == id) {
return &seq_nodes[i];
}
}
return nullptr;
}

void insert_node(const llama_rs_seq_node & node) {
llama_rs_seq_node * node_dest = get_node(node.seq_id);
if (node_dest == nullptr) {
auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node);
if (node_dest == seq_nodes.end()) {
seq_nodes.push_back(node);
} else {
// overwrite the pre-existing node with the same seq_id if it exists
*node_dest = node;
}
}

bool remove_node(llama_rs_seq_node * node_ptr) {
if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) {
size_t offset = node_ptr - seq_nodes.data();
if (offset % sizeof(llama_rs_seq_node) == 0) {
offset /= sizeof(llama_rs_seq_node);
if (offset < seq_nodes.size()) {
for (size_t i = offset + 1; i < seq_nodes.size(); ++i) {
seq_nodes[i - 1] = seq_nodes[i];
}
seq_nodes.resize(seq_nodes.size() - 1);
return true;
}
}
}
return false;
}

bool has_seq_id(const llama_seq_id & id) const {
for (size_t i = 0; i < seq_nodes.size(); ++i) {
if (seq_nodes[i].seq_id == id) {
return true;
}
}
return false;
return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end();
}

bool is_empty() const {
@@ -2132,67 +2103,65 @@ struct llama_rs_cache {
bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < size) {
llama_rs_cell & rs_cell = cells[i_cell];
auto * node_ptr = rs_cell.get_node(id); // search once
if (node_ptr != nullptr) {
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once
if (node_iter != rs_cell.seq_nodes.end()) {
if (rs_cell.seq_nodes.size() == 1) {
return clear_cell(i_cell);
} else {
// update tree
llama_rs_seq_node node = *node_ptr;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
cells[node.next_cell].prev = rs_cell.prev;
}
// else update tree
llama_rs_seq_node node = *node_iter;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
cells[node.next_cell].prev = rs_cell.prev;
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
if (node.is_tail()) {
seq.tail = rs_cell.prev;
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
llama_rs_cell & new_tail = cells[seq.tail];
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
new_tail.tail_rc += 1;
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
} else {
seq.shared = false;
}
GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1;
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2;
if (node.is_tail()) {
seq.tail = rs_cell.prev;
if (seq.tail >= 0 && (uint32_t) seq.tail < size) {
llama_rs_cell & new_tail = cells[seq.tail];
new_tail.insert_node(node.seq_id); // ensures next_cell == -1
new_tail.tail_rc += 1;
seq.shared = cells[seq.tail].seq_nodes.size() > 1;
} else {
seq.shared = false;
}
GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1;
if (node_iter == rs_cell.seq_nodes.begin()) {
// this seq_id was the first in the list
seq.n_cells -= 1;
if (seq.n_cells == 0) {
n_seqs -= 1;
}
if (node_ptr == rs_cell.seq_nodes.data()) {
// this seq_id was the first in the list
seq.n_cells -= 1;
if (seq.n_cells == 0) {
n_seqs -= 1;
}
// the next node is the new first one, so update its n_cells
// (will never be out-of-bounds because the size is > 1)
llama_rs_seq_node next_node = node_ptr[1];
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id];
next_seq.n_cells += 1;
if (next_seq.n_cells == 1) {
n_seqs += 1;
}
if (other_no_longer_shared) {
next_seq.shared = false;
}
} else {
GGML_ASSERT(false && "invalid seq_id");
// the next node is the new first one, so update its n_cells
// (will never be out-of-bounds because the size is > 1)
llama_rs_seq_node next_node = *(std::next(node_iter));
if ((uint32_t) next_node.seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id];
next_seq.n_cells += 1;
if (next_seq.n_cells == 1) {
n_seqs += 1;
}
} else if (other_no_longer_shared) {
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
seq_tails[first_node.seq_id].shared = false;
} else {
GGML_ASSERT(false && "invalid seq_id");
if (other_no_longer_shared) {
next_seq.shared = false;
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
} else if (other_no_longer_shared) {
llama_rs_seq_node first_node = rs_cell.seq_nodes[0];
if ((uint32_t) first_node.seq_id < seq_tails.size()) {
seq_tails[first_node.seq_id].shared = false;
} else {
GGML_ASSERT(false && "invalid seq_id");
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
const bool removed = rs_cell.remove_node(node_ptr);
GGML_ASSERT(removed);
} else {
GGML_ASSERT(false && "invalid seq_id");
}
rs_cell.seq_nodes.erase(node_iter);
}
}
return false;
@@ -2215,8 +2184,8 @@ struct llama_rs_cache {
if (prev >= 0 && (uint32_t) prev < size) {
// the targeted cell has a previous cell
llama_rs_cell & prev_cell = cells[prev];
llama_rs_seq_node * prev_node = prev_cell.get_node(id);
GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id);
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing
GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken
if (rs_cell.pos < 0) {
GGML_ASSERT(rs_cell.is_empty());
@@ -2267,7 +2236,7 @@ struct llama_rs_cache {
int32_t n_system_seqs = 0;
int32_t n_system_cells = 0;
for (size_t i = 0; i < seq_tails.size(); ++i) {
auto & seq = seq_tails[i];
const auto & seq = seq_tails[i];
if (seq.tail >= 0 && (size_t) seq.tail < size) {
if (seq.shared && seq.n_cells > 0) {
n_system_seqs += 1;
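
The change in this commit comes down to one C++ idiom: give the seq-node struct an implicit converting constructor from llama_seq_id and an operator== that compares only seq_id, so std::find can search a std::vector of nodes by plain id, and the hand-written get_node()/remove_node() loops and pointer arithmetic can be replaced with an iterator lookup plus erase. The sketch below is a minimal, standalone illustration of that pattern with simplified names (rs_seq_node is a stand-in, not the actual llama.cpp struct):

#include <algorithm>
#include <cstdint>
#include <vector>

using llama_seq_id = int32_t;

// Simplified stand-in for llama_rs_seq_node (illustration only).
struct rs_seq_node {
    llama_seq_id seq_id    = -1;
    int32_t      next_cell = -1;

    // Implicit conversion from a llama_seq_id, so std::find(..., id) compiles.
    rs_seq_node(llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {}

    // Equality on seq_id only; std::find uses this comparison.
    bool operator==(const rs_seq_node & other) const { return seq_id == other.seq_id; }
};

int main() {
    std::vector<rs_seq_node> seq_nodes = {{0, 3}, {2, -1}, {5, 7}};

    // Lookup by id, as has_seq_id()/insert_node() do after this commit.
    auto it = std::find(seq_nodes.begin(), seq_nodes.end(), llama_seq_id(2));

    // Erase by iterator, as remove_seq_from_cell() does instead of the old
    // remove_node() pointer arithmetic.
    if (it != seq_nodes.end()) {
        seq_nodes.erase(it);
    }
    return 0;
}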