kv-cache : add SWA support by ggerganov · Pull Request #13194 · ggml-org/llama.cpp · GitHub

kv-cache : add SWA support #13194


Merged
merged 16 commits on May 20, 2025
Changes from 1 commit
kv-cache : prepare for SWA
ggml-ci
ggerganov committed May 17, 2025
commit c011e4edb059a15dfacea080d48b6a0ad7b5942e
295 changes: 61 additions & 234 deletions src/llama-graph.cpp
@@ -9,33 +9,6 @@
#include <cmath>
#include <cstring>

static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;

if (bidirectional) {
n_buckets >>= 1;
}

const int64_t max_exact = n_buckets >> 1;

int32_t relative_position = x - y;
int32_t relative_bucket = 0;

if (bidirectional) {
relative_bucket += (relative_position > 0) * n_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}

int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

return relative_bucket;
}

void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) {
const int64_t n_tokens = ubatch->n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

int32_t * data = (int32_t *) pos_bucket->data;

const int64_t n_kv = kv_self->n;

for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_kv; ++i) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
}
}
}
kv_self->set_input_pos_bucket(pos_bucket, ubatch);
}
}

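The position-bucket input that used to be filled inline here is delegated to the KV cache. A rough sketch of the work now done by llama_kv_cache_unified::set_input_pos_bucket, mirroring the loop removed above (the helper name and parameters below are made up for illustration; the real definition lives in the KV-cache source):

    // illustrative sketch: one T5-style relative-position bucket id per (token, kv cell) pair,
    // where cell_pos[i] stands for the position stored in kv cell i
    static void set_input_pos_bucket_sketch(int32_t * dst, const llama_pos * cell_pos, int64_t n_kv,
                                            const llama_ubatch * ubatch, uint32_t n_rel_attn_bkts) {
        const int64_t n_tokens = ubatch->n_tokens;

        for (int64_t j = 0; j < n_tokens; ++j) {
            for (int64_t i = 0; i < n_kv; ++i) {
                dst[j*n_kv + i] = llama_relative_position_bucket(cell_pos[i], ubatch->pos[j], n_rel_attn_bkts, false);
            }
        }
    }
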
@@ -403,99 +361,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask || self_kq_mask_swa) {
const int64_t n_kv = kv_self->n;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;

float * data = nullptr;
float * data_swa = nullptr;

if (self_kq_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
data = (float *) self_kq_mask->data;
}

if (self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
data_swa = (float *) self_kq_mask_swa->data;
}

// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
// Causal mask:
// xxx-------
// xxxx------
// xxxxx-----
// Non-causal mask:
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (int h = 0; h < 1; ++h) {
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];

for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
float f;
// mask the token if:
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(kv_self->cells[i].pos - pos);
} else {
f = 0.0f;
}
}

if (data) {
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}

// may need to cut off old tokens for sliding window
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
if (data_swa) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
}

// mask padded tokens
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
if (self_kq_mask) {
kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

// mask padded tokens
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
}
if (self_kq_mask_swa) {
kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
}

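The mask construction removed above now lives behind llama_kv_cache_unified::set_input_kq_mask and set_input_kq_mask_swa; the visibility rules themselves do not change in this commit. As a small illustrative sketch (a standalone helper with made-up names, not the actual cache code), the sliding-window / chunked condition that decides whether a cached cell gets masked out for a given token is:

    // illustrative sketch of the SWA / chunked-attention rule from the removed loop
    static bool swa_masks_cell(llama_pos cell_pos, llama_pos tok_pos, uint32_t n_swa, uint32_t n_attn_chunk) {
        if (n_attn_chunk > 0) {
            // chunked attention: the cell must lie in the same chunk as the token
            const llama_pos chunk_start = (tok_pos / (llama_pos) n_attn_chunk) * (llama_pos) n_attn_chunk;
            return cell_pos < chunk_start;
        }
        // sliding window: cells n_swa or more positions behind the token are masked
        return tok_pos - cell_pos >= (llama_pos) n_swa;
    }

For example, with n_swa = 4 and a token at position 10, cells at positions 0..6 are masked and 7..10 stay visible.
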
@@ -1153,7 +1024,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {

auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);

const auto n_kv = kv_self->n;
const auto n_kv = kv_self->get_n();

auto & cur = inp->pos_bucket;

@@ -1188,16 +1059,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla,
bool v_trans,
float kq_scale) const {
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
//const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

//const int64_t n_head = hparams.n_head(il);
//const int64_t n_head_kv = hparams.n_head_kv(il);
const bool v_trans = v->nb[1] > v->nb[2];

//const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
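With the explicit v_trans argument dropped, build_attn_mha now infers whether V is stored transposed purely from the tensor's strides: a contiguous layout has nb[1] <= nb[2], while the transposed V-cache views have a larger stride along dim 1 than dim 2. A minimal standalone sketch of that check (toy dimensions, not the actual cache views):

    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * v      = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 32, 8); // contiguous
        ggml_tensor * v_perm = ggml_permute(ctx, v, 0, 2, 1, 3);                    // dims 1 and 2 swapped

        const bool v_trans      = v->nb[1]      > v->nb[2];      // false: contiguous layout
        const bool v_perm_trans = v_perm->nb[1] > v_perm->nb[2]; // true:  "transposed" layout

        ggml_free(ctx);
        return (v_trans == false && v_perm_trans == true) ? 0 : 1;
    }
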
@@ -1336,17 +1203,11 @@ ggml_tensor * llm_graph_context::build_attn(

const auto & kq_mask = inp->get_kq_mask();

ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
//cb(q, "q", il);

ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);

ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1369,17 +1230,21 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);

const auto n_kv = kv_self->n;
{
const auto n_kv = kv_self->get_n();

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}

if (hparams.n_swa_pattern > 1) {
GGML_ASSERT(hparams.n_swa > 0);

const auto n_kv = kv_self->get_n();

inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
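The second mask is only created when hparams.n_swa_pattern > 1, i.e. when some layers use a sliding window while others keep full attention. Roughly, the pattern says that every n_swa_pattern-th layer is a full-attention layer and the rest are SWA layers; a sketch of the idea (not the exact llama_hparams::is_swa implementation):

    // illustrative sketch of pattern-based per-layer SWA selection
    // e.g. n_swa_pattern == 2 (Gemma 2): SWA, full, SWA, full, ...
    //      n_swa_pattern == 6 (Gemma 3): 5x SWA followed by 1 full-attention layer, repeating
    static bool layer_uses_swa(uint32_t il, uint32_t n_swa_pattern) {
        if (n_swa_pattern <= 1) {
            return false; // all layers use full attention
        }
        return il % n_swa_pattern < (n_swa_pattern - 1);
    }
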
@@ -1409,81 +1274,22 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, v_cur);

const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto & n_ctx = cparams.n_ctx;

const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

const auto n_tokens = q_cur->ne[2];

const bool v_trans = !cparams.flash_attn;

// store to KV cache
{
const auto kv_head = kv_self->head;

GGML_ASSERT(kv_self->size == n_ctx);

ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
//cb(k_cache_view, "k_cache_view", il);

// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);

ggml_tensor * v_cache_view = nullptr;

if (!v_trans) {
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
(kv_head)*ggml_element_size(kv_self->v_l[il]));

v_cur = ggml_transpose(ctx0, v_cur);
}
//cb(v_cache_view, "v_cache_view", il);

ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
}

const bool is_swa = hparams.is_swa(il);

const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

const auto n_kv = kv_self->n;
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_self->get_k(ctx0, il);
ggml_tensor * v = kv_self->get_v(ctx0, il);

const int64_t n_head_kv = hparams.n_head_kv(il);

const auto & n_embd_head_k = hparams.n_embd_head_k;
const auto & n_embd_head_v = hparams.n_embd_head_v;

ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
//cb(q, "q", il);

ggml_tensor * k =
ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
0);
//cb(k, "k", il);

ggml_tensor * v = !v_trans ?
ggml_view_3d(ctx0, kv_self->v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
0) :
ggml_view_3d(ctx0, kv_self->v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self->v_l[il])*n_ctx,
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
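Storing the new K/V rows is likewise delegated: the inline ggml_view_1d/ggml_view_2d + ggml_cpy sequence removed above becomes kv_self->cpy_k() / cpy_v(), and the read-side views become get_k() / get_v(). A rough sketch of a cpy_k-style helper, mirroring the removed write path (hypothetical free function, not the actual member definition):

    // copy n_tokens new K rows into a layer's K cache starting at the current cache head;
    // note: the RoPE-ed version of K is what gets stored
    static ggml_tensor * cpy_k_sketch(ggml_context * ctx, ggml_tensor * k_cache, ggml_tensor * k_cur,
                                      int64_t n_embd_k_gqa, int64_t n_tokens, int64_t kv_head) {
        ggml_tensor * k_view = ggml_view_1d(ctx, k_cache, n_tokens*n_embd_k_gqa,
                ggml_row_size(k_cache->type, n_embd_k_gqa)*kv_head);

        return ggml_cpy(ctx, k_cur, k_view);
    }
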
@@ -1534,17 +1340,11 @@ ggml_tensor * llm_graph_context::build_attn(

const auto & kq_mask = inp->get_kq_mask_cross();

ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
//cb(q, "q", il);

ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);

ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1712,3 +1512,30 @@ void llm_graph_context::build_pooling(

ggml_build_forward_expand(gf, cur);
}

int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;

if (bidirectional) {
n_buckets >>= 1;
}

const int64_t max_exact = n_buckets >> 1;

int32_t relative_position = x - y;
int32_t relative_bucket = 0;

if (bidirectional) {
relative_bucket += (relative_position > 0) * n_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}

int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

return relative_bucket;
}
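A quick worked example of the bucketing above (values computed by hand from the formula, not taken from a test):

    // n_buckets = 32, bidirectional = false  =>  max_exact = 16, max_distance = 128
    //   x = 5, y = 20  -> relative_position = 15;  15 < 16, so bucket = 15                (exact range)
    //   x = 5, y = 200 -> relative_position = 195; floor(16 + log(195/16.0)*16/log(8.0)) = 35,
    //                     clamped to n_buckets - 1, so bucket = 31                        (log-spaced range)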