kv-cache : add SWA support by ggerganov · Pull Request #13194 · ggml-org/llama.cpp
kv-cache : add SWA support #13194

Merged · 16 commits · May 20, 2025
Changes from 1 commit
kv-cache : simplify SWA logic
ggml-ci
ggerganov committed May 17, 2025
commit e743246b288f3748df30ba8e9da2783e264c00d6
6 changes: 3 additions & 3 deletions src/llama-graph.cpp
@@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

if (self_kq_mask_swa) {
kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn, true);
kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
}

7 changes: 4 additions & 3 deletions src/llama-hparams.h
@@ -15,8 +15,9 @@ enum llama_expert_gating_func_type {
};

enum llama_swa_type {
LLAMA_SWA_TYPE_STANDARD = 0,
LLAMA_SWA_TYPE_CHUNKED = 1,
LLAMA_SWA_TYPE_NONE = 0,
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
};

struct llama_hparams_posnet {
@@ -100,7 +101,7 @@ struct llama_hparams {
std::array<int, 4> rope_sections;

// Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_STANDARD;
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
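The header change above adds an explicit LLAMA_SWA_TYPE_NONE and makes it the default, so "no sliding window" no longer has to be inferred from n_swa == 0 alone. As a quick illustration of what the two non-trivial types mean, here is a standalone sketch (local stand-in types, not the llama.cpp headers): with n_swa = 4 and a query at position p1 = 10, the standard window keeps key positions 7..10, while the chunked window keeps only positions 8..10, i.e. from the start of the current chunk.

// Standalone illustration of the standard vs chunked SWA mask predicates.
// swa_type is a local stand-in for llama_swa_type; this is not llama.cpp code.
#include <cstdint>
#include <cstdio>

enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

static bool is_masked(swa_type type, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (type) {
        case SWA_NONE:     return false;                     // no extra masking
        case SWA_STANDARD: return p1 - p0 >= n_swa;          // keep the trailing window
        case SWA_CHUNKED:  return p0 < (p1 / n_swa) * n_swa; // keep the current chunk only
    }
    return false;
}

int main() {
    const int32_t n_swa = 4, p1 = 10;
    for (int32_t p0 = 0; p0 <= p1; ++p0) {
        std::printf("p0=%2d standard:%s chunked:%s\n", p0,
                    is_masked(SWA_STANDARD, n_swa, p0, p1) ? "masked" : "kept",
                    is_masked(SWA_CHUNKED,  n_swa, p0, p1) ? "masked" : "kept");
    }
    return 0;
}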
95 changes: 57 additions & 38 deletions src/llama-kv-cache.cpp
@@ -30,7 +30,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
uint32_t padding,
uint32_t n_swa,
llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

this->type_k = type_k;
@@ -594,8 +596,8 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons
// note: v->nb[1] > v->nb[2]
return ggml_view_3d(ctx, v,
n, hparams.n_head_kv(il), hparams.n_embd_head_v,
ggml_element_size(v)*v->ne[1]*hparams.n_embd_head_v, // v->nb[1]
ggml_element_size(v)*v->ne[1], // v->nb[2]
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
0);
}

@@ -640,7 +642,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
return ggml_cpy(ctx, v_cur, v_view);
}

void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
@@ -667,41 +669,28 @@ void llama_kv_cache_unified::set_input_kq_mask(const llama_ub
const llama_seq_id seq_id = ubatch->seq_id[s][0];

for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];

for (int i = 0; i < n_kv; ++i) {
float f;
// mask the token if:
if (!cells[i].has_seq_id(seq_id) // not the correct sequence
|| (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(cells[i].pos - pos);
} else {
f = 0.0f;
}
}
const llama_pos p0 = cells[i].pos;

bool masked = false;

// mask the token if not the same sequence
masked = masked || (!cells[i].has_seq_id(seq_id));

// mask future tokens
masked = masked || (causal_attn && p0 > p1);

if (swa) {
// may need to cut off old tokens for sliding window
switch (hparams.swa_type) {
case LLAMA_SWA_TYPE_STANDARD:
{
if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
f = -INFINITY;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (pos / hparams.n_swa) * hparams.n_swa;

if (cells[i].pos < pos_chunk_start) {
f = -INFINITY;
}
} break;
}
// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));

float f = 0.0f;

if (masked) {
f = -INFINITY;
} else if (hparams.use_alibi) {
f = -std::abs(p0 - p1);
}

data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
@@ -1191,6 +1180,30 @@ uint32_t llama_kv_cache_unified::cell_max() const {
return 0;
}

bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;

if (p0 < pos_chunk_start) {
return true;
}
} break;
}

return false;
}

void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
@@ -1586,11 +1599,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);

kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, kv_size_base, padding,
0, LLAMA_SWA_TYPE_NONE);

LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);

kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, kv_size_swa, padding,
hparams.n_swa, hparams.swa_type);
}

void llama_kv_cache_unified_iswa::clear() {
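Taken together, the refactor in set_input_kq_mask now answers two separate questions per KV cell: whether the cell is masked at all (wrong sequence, future position under causal attention, or outside the sliding window via is_masked_swa), and, only if it is kept, which bias it receives (0, or -|p0 - p1| when ALiBi is used). The following is a self-contained sketch of that per-token logic; the cell struct and the hard-coded inputs are simplified stand-ins, not the real KV-cache state.

// Self-contained sketch of the refactored mask computation for a single query token.
// All types and values below are illustrative stand-ins for the real cache state.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

struct cell { int32_t pos; int32_t seq_id; };

static bool is_masked_swa(swa_type type, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (type) {
        case SWA_NONE:     return false;
        case SWA_STANDARD: return p1 - p0 >= n_swa;
        case SWA_CHUNKED:  return p0 < (p1 / n_swa) * n_swa;
    }
    return false;
}

int main() {
    const std::vector<cell> cells = { {0, 0}, {1, 0}, {2, 0}, {3, 1}, {4, 0}, {5, 0} };

    const int32_t  p1        = 5;            // query token position
    const int32_t  seq_id    = 0;            // query token sequence
    const bool     causal    = true;
    const bool     use_alibi = false;
    const swa_type type      = SWA_STANDARD;
    const int32_t  n_swa     = 3;

    for (const cell & c : cells) {
        const int32_t p0 = c.pos;

        bool masked = false;
        masked = masked || (c.seq_id != seq_id);               // not the same sequence
        masked = masked || (causal && p0 > p1);                // future token
        masked = masked || is_masked_swa(type, n_swa, p0, p1); // outside the sliding window

        float f = 0.0f;
        if (masked) {
            f = -INFINITY;
        } else if (use_alibi) {
            f = -std::fabs((float) (p0 - p1));
        }

        std::printf("p0=%d seq=%d -> f=%f\n", p0, c.seq_id, f);
    }
    return 0;
}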
13 changes: 11 additions & 2 deletions src/llama-kv-cache.h
@@ -102,7 +102,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding);
uint32_t padding,
uint32_t n_swa,
llama_swa_type swa_type);

~llama_kv_cache_unified() = default;

@@ -169,7 +171,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;

void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

@@ -223,6 +225,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;

// SWA
uint32_t n_swa = 0;

llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;

@@ -264,6 +271,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
size_t size_k_bytes() const;
size_t size_v_bytes() const;

bool is_masked_swa(llama_pos p0, llama_pos p1) const;

ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
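Since n_swa and swa_type are now constructor parameters and members of llama_kv_cache_unified, the iSWA wrapper can give each of its two inner caches a different configuration and no longer has to tell set_input_kq_mask which mode it is in. A rough structural sketch of that arrangement, using hypothetical simplified types rather than the real classes:

// Rough structural sketch (hypothetical simplified types, not the real llama.cpp classes)
// of why the per-call `swa` flag could be dropped: each inner cache is configured once,
// at construction time, and consults its own members afterwards.
#include <cstdint>
#include <memory>

enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

struct kv_cache_unified_sketch {
    uint32_t n_swa;   // window size, 0 when unused
    swa_type type;    // how (or whether) to apply the window

    kv_cache_unified_sketch(uint32_t n_swa, swa_type type) : n_swa(n_swa), type(type) {}
};

struct kv_cache_iswa_sketch {
    std::unique_ptr<kv_cache_unified_sketch> kv_base; // serves the full-attention layers
    std::unique_ptr<kv_cache_unified_sketch> kv_swa;  // serves the sliding-window layers

    kv_cache_iswa_sketch(uint32_t n_swa, swa_type type)
        : kv_base(std::make_unique<kv_cache_unified_sketch>(0,     SWA_NONE)),
          kv_swa (std::make_unique<kv_cache_unified_sketch>(n_swa, type)) {}
};

int main() {
    kv_cache_iswa_sketch cache(4096, SWA_STANDARD); // e.g. a 4096-token standard window
    (void) cache;
    return 0;
}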
29 changes: 21 additions & 8 deletions src/llama-model.cpp
@@ -572,7 +572,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);

hparams.swa_type = (llama_swa_type) LLAMA_SWA_TYPE_CHUNKED;
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full

@@ -858,18 +858,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);

hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;

hparams.n_swa = 2047;
} else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
// default value for Phi-3-mini-128k-instruct
LLAMA_LOG_WARN("%s: assuming n_swa = n_ctx_train for Phi-3-mini-128k-instruct\n", __func__);
LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-mini-128k-instruct\n", __func__);

hparams.swa_type = LLAMA_SWA_TYPE_NONE;

hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa_pattern = 1;
} else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
// default value for Phi-3-medium-128k-instruct
LLAMA_LOG_WARN("%s: assuming n_swa = n_ctx_train for Phi-3-medium-128k-instruct\n", __func__);
LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-medium-128k-instruct\n", __func__);

hparams.swa_type = LLAMA_SWA_TYPE_NONE;

hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa_pattern = 1;
}

@@ -879,9 +885,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}

if (hparams.n_swa > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, setting to 0\n", __func__, hparams.n_swa, hparams.n_ctx_train);
LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, disabling SWA\n", __func__, hparams.n_swa, hparams.n_ctx_train);

hparams.n_swa = hparams.n_ctx_train;
hparams.swa_type = LLAMA_SWA_TYPE_NONE;

hparams.n_swa = hparams.n_ctx_train;
hparams.n_swa_pattern = 1;
}
} break;
@@ -952,6 +960,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_GEMMA2:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096; // default value of gemma 2
hparams.n_swa_pattern = 2;
hparams.attn_soft_cap = true;
@@ -970,6 +979,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_GEMMA3:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 6;

hparams.rope_freq_base_train_swa = 10000.0f;
@@ -1054,6 +1064,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_COHERE2:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 4;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
@@ -13228,7 +13239,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
padding);
padding,
hparams.n_swa,
hparams.swa_type);
}
}
}
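The per-architecture defaults above pair a window size and type with an n_swa_pattern describing how sliding-window and full-attention layers interleave (the Llama 4 hunk comments it as "3 chunked - 1 full" for pattern 4). The helper that maps the pattern onto concrete layers is not part of this diff, so the sketch below is an assumption based on those comments: every n_swa_pattern-th layer keeps full attention, the rest use the window, and pattern 1 means no SWA layers at all.

// Hedged sketch of how n_swa_pattern could map onto layers. The real helper is outside
// this diff; the rule below is inferred from the comments in the hunks above.
#include <cstdint>
#include <cstdio>

static bool layer_uses_swa(uint32_t il, uint32_t n_swa_pattern) {
    // every n_swa_pattern-th layer is a full-attention layer; pattern 1 disables SWA
    return n_swa_pattern > 1 && (il % n_swa_pattern) < (n_swa_pattern - 1);
}

int main() {
    for (uint32_t il = 0; il < 8; ++il) {
        std::printf("layer %u: pattern=4 -> %s, pattern=2 -> %s, pattern=1 -> %s\n", il,
                    layer_uses_swa(il, 4) ? "swa" : "full",
                    layer_uses_swa(il, 2) ? "swa" : "full",
                    layer_uses_swa(il, 1) ? "swa" : "full");
    }
    return 0;
}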