8000 llama : fix KV shift for qwen2vl (#13870) · ggml-org/llama.cpp@763d06e · GitHub
[go: up one dir, main page]

Skip to content

Commit 763d06e

Browse files
authored
llama : fix KV shift for qwen2vl (#13870)
* llama : fix KV shift for qwen2vl * add ref to the PR
1 parent 1096133 commit 763d06e

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
455455
}
456456

457457
int64_t llm_graph_context::n_pos_per_embd() const {
458-
return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
458+
return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
459459
}
460460

461461
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {

src/llama-kv-cache.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,11 +757,19 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
757757
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
758758

759759
const auto & n_rot = hparams.n_rot;
760-
const auto & rope_type = hparams.rope_type;
760+
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
761+
// @ngxson : this is a workaround
762+
// for M-RoPE, we want to rotate the whole vector when doing KV shift
763+
// a normal RoPE should work, we just need to use the correct ordering
764+
// ref: https://github.com/ggml-org/llama.cpp/pull/13870
765+
? LLAMA_ROPE_TYPE_NEOX
766+
: hparams.rope_type;
761767

762768
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
763769
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
764-
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
770+
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
771+
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
772+
: cparams.yarn_attn_factor;
765773

766774
ggml_tensor * tmp;
767775

0 commit comments

Comments
 (0)
0