llama : support Jamba hybrid Transformer-Mamba models by compilade · Pull Request #7531 · ggml-org/llama.cpp · GitHub

llama : support Jamba hybrid Transformer-Mamba models #7531

Merged
merged 61 commits into master from compilade/refactor-kv-cache on Jul 9, 2025
Commits (61)
271104c
wip: llama : separate recurrent states from the KV cache
compilade Apr 3, 2024
8db1e4d
llama : use std::find for seq_nodes in llama_rs_cache
compilade Apr 4, 2024
0028010
llama : state checkpoints for recurrent models
compilade Apr 8, 2024
0c8b3b2
llama : correctly handle more edge cases for the rs cache
compilade Apr 9, 2024
d66849f
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 10, 2024
a09db95
llama : rename many llama_kv_cache_* functions
compilade Apr 29, 2024
c460ff1
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 29, 2024
b6fafd1
llama : remove useless return value for some llama_cache_* functions
compilade Apr 29, 2024
b7ec12e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 12, 2024
3b57b55
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 22, 2024
7e13f19
llama : rethink recurrent state cell counts
compilade May 24, 2024
cbc743e
llama : support Jamba
compilade May 24, 2024
0fd13e9
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 24, 2024
61a88a1
llama : fix BERT inference without KV cache
compilade May 25, 2024
ea2e63e
convert-hf : check for unprocessed Jamba experts
compilade May 25, 2024
fc59407
convert-hf : support Mini-Jamba conversion
compilade May 25, 2024
181dadf
llama : fix Jamba quantization sanity checks
compilade May 28, 2024
3a414b0
llama : sequence-length-aware batch splitting
compilade May 28, 2024
4e4c41e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 28, 2024
3587a94
llama : use equal-sequence-length sub-batches for recurrent models
compilade Jun 1, 2024
5d3c7b9
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 1, 2024
72eea49
llama : fix batch split output count for embeddings
compilade Jun 1, 2024
18d1c14
llama : minimize swaps when reordering logits
compilade Jun 1, 2024
61200ef
llama : fix edge case finding batch seq_id of split recurrent cell
compilade Jun 1, 2024
eb589d5
llama : avoid copies for simple batch splits
compilade Jun 2, 2024
8fb57ac
llama : use im2col and mul_mat to perform convolution for Mamba
compilade Jun 3, 2024
17f6c1e
llama : fix .base() compilation error on Windows
compilade Jun 3, 2024
fee3c1d
llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
compilade Jun 3, 2024
6840ac0
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 8, 2024
372482d
llama : rename llama_cache to llama_past
compilade Jun 8, 2024
43d8d4b
examples : replace llama_kv_cache_seq_* with llama_past_seq_*
compilade Jun 10, 2024
ff794f5
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 12, 2024
33425a7
mamba : fix non-contiguous usage of ggml_silu
compilade Jun 12, 2024
10c3c41
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 30, 2024
9b38f8b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 4, 2024
bc320ef
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 1, 2024
fcb889c
llama : session saving and reloading for hybrid models
compilade Sep 2, 2024
a03e32a
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 2, 2024
9d3f44d
convert_hf : fix Jamba conversion
compilade Sep 2, 2024
5f62db7
llama : fix mixed signedness comparison
compilade Sep 2, 2024
375de5b
llama : use unused n_embd_k_gqa in k_shift
compilade Sep 2, 2024
4bb4b22
llama : begin renaming llama_past back to llama_kv_cache
compilade Sep 14, 2024
63ac36b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 14, 2024
124c222
Merge branch 'master' into compilade/refactor-kv-cache
compilade Oct 12, 2024
8006f3b
llama : remove implicit recurrent state rollbacks
compilade Nov 25, 2024
691698e
Merge branch 'master' into compilade/refactor-kv-cache
compilade Nov 25, 2024
e3fe612
llama : partially apply clang-format style
compilade Nov 25, 2024
2bcaf64
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
908e655
convert : fix jamba conv1d shape squeezing
compilade Jul 3, 2025
4682e21
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
20f8e43
graph : add back hybrid memory graph input
compilade Jul 3, 2025
07c252f
model : add Jamba to Mamba-specific hparams printing
compilade Jul 3, 2025
f716358
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 7, 2025
b0b280e
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 8, 2025
db5ff0c
jamba : remove redundant nullptr initializations
compilade Jul 8, 2025
2f39cd7
model : remove unnecessary prefix for tensor loading constants
compilade Jul 8, 2025
f7c7a92
model : use ggml_swiglu_split for Mamba
compilade Jul 8, 2025
a60a24b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 9, 2025
7f3955a
model : make falcon-h1 use shared mamba2 layer builder
compilade Jul 9, 2025
452207f
memory : avoid referring to KV in recurrent cache logs
compilade Jul 9, 2025
4d6a179
gguf-py : avoid adding duplicate tensor mappings for Jamba
compilade Jul 9, 2025
Changes from 1 commit

llama : remove useless return value for some llama_cache_* functions
compilade committed Apr 29, 2024
commit b6fafd174721c930e89b27df7de6ee776ace9ade

llama.cpp (47 changes: 12 additions & 35 deletions)
@@ -2887,7 +2887,6 @@ static bool llama_cache_init(
bool offload) {
const struct llama_hparams & hparams = model.hparams;


// TODO: per layer n_embd_*
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@@ -3010,6 +3009,8 @@ static bool llama_cache_find_slot(
const uint32_t rs_size = cache.rs.size;
const uint32_t n_tokens = batch.n_tokens;

// FIXME: on failure, leave all caches in a consistent state.

if (rs_size > 0) {
// For recurrent state architectures (like Mamba),
// each cache cell can store the state for a whole sequence.
@@ -3509,7 +3510,7 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id
}
}

static llama_pos llama_cache_seq_add(
static void llama_cache_seq_add(
struct llama_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
@@ -3519,8 +3520,6 @@ static llama_pos llama_cache_seq_add(
if (p0 < 0) { p0 = 0; }
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }

llama_pos n_past = p0;

if (cache.rs.size > 0) {
// for Mamba-like models, only the pos needs to be shifted
auto & seq = cache.rs.seq_tails[seq_id];
@@ -3541,9 +3540,6 @@ static llama_pos llama_cache_seq_add(
}
}
}
if (n_past <= rs_cell.pos) {
n_past = rs_cell.pos + 1;
}
}

// If we freed up a slot, set head to it so searching can start there.
@@ -3573,21 +3569,16 @@ static llama_pos llama_cache_seq_add(
}
}
}
if (n_past <= kv_cell.pos) {
n_past = kv_cell.pos + 1;
}
}
}

// If we freed up a slot, set head to it so searching can start there.
// Otherwise we just start the next search from the beginning.
cache.kv.head = new_head != cache.kv.size ? new_head : 0;
}

return n_past;
}

static llama_pos llama_cache_seq_div(
static void llama_cache_seq_div(
struct llama_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
@@ -3596,8 +3587,6 @@ static llama_pos llama_cache_seq_div(
if (p0 < 0) { p0 = 0; }
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }

llama_pos n_past = p0;

if (cache.rs.size > 0) {
// for Mamba-like models, only the pos needs to be changed
auto & seq = cache.rs.seq_tails[seq_id];
@@ -3609,9 +3598,6 @@ static llama_pos llama_cache_seq_div(
rs_cell.pos /= d;
}
cell_id = rs_cell.prev;
if (n_past <= rs_cell.pos) {
n_past = rs_cell.pos + 1;
}
}
}

@@ -3628,14 +3614,9 @@ static llama_pos llama_cache_seq_div(
kv_cell.delta += kv_cell.pos - p_old;
}
}
if (n_past <= kv_cell.pos) {
n_past = kv_cell.pos + 1;
}
}
}
}

return n_past;
}

static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) {
@@ -16935,27 +16916,23 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
llama_cache_seq_keep(ctx, seq_id);
}

llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; }
if (delta == 0) {
return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1;
}
void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
if (delta == 0) { return; }

return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta);
llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta);
}

// deprecated
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
llama_cache_seq_add(ctx, seq_id, p0, p1, delta);
}

llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; }
if (d == 1) {
return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1;
}
void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
if (d == 1) { return; }

return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d);
llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d);
}

// deprecated
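Net effect of this commit on llama.cpp: llama_cache_seq_add and llama_cache_seq_div no longer track an n_past value while walking the recurrent and KV cells, and their public wrappers return void. A caller that still wants the old return value can derive it afterwards; the sketch below is an assumed usage pattern (the helper name is hypothetical, and llama_cache_seq_pos_max is the replacement named in the deprecation note in llama.h):

    #include "llama.h"

    // Hypothetical helper, not part of this PR: divide the positions of one
    // sequence and report the n_past value the old API used to return.
    static llama_pos cache_seq_div_with_n_past(struct llama_context * ctx,
                                               llama_seq_id seq_id,
                                               llama_pos p0, llama_pos p1, int d) {
        llama_cache_seq_div(ctx, seq_id, p0, p1, d);      // returns void after this commit
        // llama_cache_seq_pos_max returns -1 when the seq_id has no cells, so +1 yields 0
        return llama_cache_seq_pos_max(ctx, seq_id) + 1;
    }
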
llama.h (14 changes: 7 additions & 7 deletions)
@@ -562,7 +562,8 @@ extern "C" {
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
// Returns n_past
// Returns n_past (one more than the largest remaining pos in the seq_id)
// which is only meaningful to handle for partial removals.
LLAMA_API llama_pos llama_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
@@ -579,7 +580,8 @@ extern "C" {
// Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
// Returns n_past
// Returns n_past (one more than the largest remaining pos in the destination seq_id)
// which is only meaningful to handle when partially copying.
LLAMA_API llama_pos llama_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
@@ -609,8 +611,7 @@ extern "C" {
// - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
// Returns n_past
LLAMA_API llama_pos llama_cache_seq_add(
LLAMA_API void llama_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@@ -630,8 +631,7 @@ extern "C" {
// - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
// Returns n_past
LLAMA_API llama_pos llama_cache_seq_div(
LLAMA_API void llama_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@@ -652,7 +652,7 @@ extern "C" {
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells");
"use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");

// Defragment the KV cache
// This will be applied:
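A usage note on the header comments above: llama_cache_seq_rm and llama_cache_seq_cp still return n_past, one more than the largest remaining pos in the sequence, which is only meaningful to handle for partial removals or copies. A minimal sketch of how a caller might consume it, assuming the header as shown (the helper name and the checkpoint remark are illustrative, not from this diff):

    #include "llama.h"

    // Sketch (assumed usage): truncate sequence 0 down to n_keep positions and
    // find out where decoding actually has to resume.
    static llama_pos truncate_seq(struct llama_context * ctx, llama_pos n_keep) {
        // p1 < 0 means "to the end", so this asks to remove everything from n_keep onwards
        const llama_pos n_past = llama_cache_seq_rm(ctx, 0, n_keep, -1);
        // n_past may end up smaller than n_keep if the cache could not keep everything
        // up to n_keep (e.g. no recurrent state checkpoint at that position)
        return n_past;
    }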