llama : allow custom list of swa_layers by ngxson · Pull Request #13726 · ggml-org/llama.cpp
Merged · 1 commit · May 23, 2025
22 changes: 21 additions & 1 deletion src/llama-hparams.cpp
@@ -2,6 +2,26 @@

#include "ggml.h"

+llama_hparams::llama_hparams() {
+    swa_layers.fill(false);
+}
+
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}

uint32_t llama_hparams::n_head(uint32_t il) const {
if (il < n_layer) {
return n_head_arr[il];
@@ -72,7 +92,7 @@ uint32_t llama_hparams::n_embd_v_s() const {

bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
-        return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
+        return swa_layers[il];
}

GGML_ABORT("fatal error");
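For reference, the behaviour of the new set_swa_pattern() helper is easy to reproduce in isolation. The sketch below is a toy standalone version (the fixed layer count and the main() harness are illustrative only, not code from this PR); with n_pattern == 3 it marks layers 0 and 1 as SWA and layer 2 as dense, repeating across the stack, matching the il == 0..6 example in the header comment.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t n_layer = 7;              // toy layer count, for illustration only
    std::array<bool, n_layer> swa_layers{};      // false == dense by default

    // same rule as llama_hparams::set_swa_pattern(): every n-th layer is dense
    const uint32_t n_pattern = 3;
    for (uint32_t il = 0; il < n_layer; ++il) {
        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("il == %u: %s\n", il, swa_layers[il] ? "swa" : "dense");
    }
    // expected output: layers 0,1 swa; 2 dense; 3,4 swa; 5 dense; 6 swa
    return 0;
}
```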
39 changes: 25 additions & 14 deletions src/llama-hparams.h
@@ -102,20 +102,12 @@ struct llama_hparams {

// Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

-    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
-    uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
-                                // by default n == 1, all layers are dense
-                                // note that if n_swa_pattern == 0, all layers are SWA
-                                // example: n_swa_pattern = 3
-                                //   il == 0: swa
-                                //   il == 1: swa
-                                //   il == 2: dense
-                                //   il == 3: swa
-                                //   il == 4: swa
-                                //   il == 5: dense
-                                //   il == 6: swa
-                                //   etc ...
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;

// for State Space Models
uint32_t ssm_d_conv = 0;
@@ -153,6 +145,25 @@ struct llama_hparams {
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;

+    llama_hparams();
+
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;

uint32_t n_head(uint32_t il = 0) const;

uint32_t n_head_kv(uint32_t il = 0) const;
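Since swa_layers is now an ordinary public array, callers are no longer tied to the regular stride produced by set_swa_pattern() and can mark an arbitrary set of layers as SWA, which is the point of the PR title. Below is a minimal sketch of such a custom list, assuming access to the internal llama-hparams.h header; the helper name, layer indices and hyperparameter values are made up for illustration and are not part of this PR.

```cpp
#include <cstdint>

#include "ggml.h"            // for GGML_ASSERT
#include "llama-hparams.h"   // internal header, visible inside src/ of llama.cpp

// Hypothetical helper: mark an irregular set of layers as SWA directly,
// instead of deriving them from a regular pattern.
static void set_custom_swa_layers(llama_hparams & hparams) {
    hparams.n_layer  = 8;                        // illustrative layer count
    hparams.n_swa    = 4096;                     // illustrative window size
    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;

    // custom list: only layers 1, 2, 5 and 6 use sliding-window attention
    for (uint32_t il : {1u, 2u, 5u, 6u}) {
        hparams.swa_layers[il] = true;
    }

    GGML_ASSERT( hparams.is_swa_any());          // at least one SWA layer is set
    GGML_ASSERT(!hparams.is_swa(0));             // layer 0 stays dense
}
```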
16 changes: 8 additions & 8 deletions src/llama-model.cpp
@@ -574,7 +574,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {

hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
+                hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full

switch (hparams.n_expert) {
case 16: type = LLM_TYPE_17B_16E; break;
@@ -863,7 +863,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;

hparams.n_swa = 0;
-                    hparams.n_swa_pattern = 1;
+                    hparams.set_swa_pattern(1);
}
} break;
case LLM_ARCH_PHIMOE:
@@ -935,7 +935,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096; // default value of gemma 2
-                hparams.n_swa_pattern = 2;
+                hparams.set_swa_pattern(2);
hparams.attn_soft_cap = true;

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -953,7 +953,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_GEMMA3:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 6;
+                hparams.set_swa_pattern(6);

hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1038,7 +1038,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_COHERE2:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 4;
+                hparams.set_swa_pattern(4);

ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@@ -4320,7 +4320,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -13216,7 +13216,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.n_swa_pattern != 1);
+                    GGML_ASSERT(hparams.is_swa_any());

res = new llama_kv_cache_unified_iswa(
*this,
@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_batch,
padding);
} else {
-                    GGML_ASSERT(hparams.n_swa_pattern == 1);
+                    GGML_ASSERT(!hparams.is_swa_any());

res = new llama_kv_cache_unified(
*this,
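The create_memory() changes at the end of the diff enforce an invariant between swa_type and the per-layer list: if the model uses SWA anywhere, the iSWA-aware KV cache must be selected, otherwise the plain unified cache. Switching the assertions from n_swa_pattern checks to is_swa_any() keeps that invariant meaningful even when the SWA layers no longer follow a regular pattern. The sketch below restates the selection with stand-in types; it is not the actual llama.cpp API (the real constructors take the model, cparams and padding arguments, omitted here).

```cpp
#include <cassert>
#include <memory>

// Stand-in types for the two KV-cache implementations chosen in create_memory().
struct kv_cache_unified      {};  // single unified cache, no SWA layers
struct kv_cache_unified_iswa {};  // separate SWA / non-SWA cache streams

enum swa_type_t { SWA_TYPE_NONE, SWA_TYPE_STANDARD, SWA_TYPE_CHUNKED };

struct hparams_sketch {
    swa_type_t swa_type = SWA_TYPE_NONE;
    bool       swa_any  = false;   // stand-in for llama_hparams::is_swa_any()
};

// The per-layer list must agree with swa_type: SWA anywhere -> iSWA cache.
std::shared_ptr<void> create_memory_sketch(const hparams_sketch & hp) {
    if (hp.swa_type != SWA_TYPE_NONE) {
        assert(hp.swa_any);        // at least one layer must be marked SWA
        return std::make_shared<kv_cache_unified_iswa>();
    }
    assert(!hp.swa_any);           // no layer may be marked SWA
    return std::make_shared<kv_cache_unified>();
}
```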