8000 feat: Hybrid unified/recurrent cache by gabe-l-hart · Pull Request #13276 · ggml-org/llama.cpp · GitHub
[go: up one dir, main page]

Skip to content

feat: Hybrid unified/recurrent cache #13276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
0e9f0e0
tests: Initial unit tests for memory hierarchy
gabe-l-hart May 20, 2025
b656613
build: Add build step for test-memory on non-windows builds
gabe-l-hart May 20, 2025
9a15f27
fix(tests): Fix constructors in tests for signature changes after rebase
gabe-l-hart May 27, 2025
d74d76a
tests(wip): More robust test for unified cache
gabe-l-hart May 23, 2025
114f2ce
feat: Add can_seq_rm API to llama_kv_cache API
gabe-l-hart May 27, 2025
13efcf3
feat: Move layer_filter_cb up to llama_kv_cache
gabe-l-hart May 20, 2025
6625248
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
76771e8
feat: Initial implementation of llama_kv_cache_hybrid
gabe-l-hart May 20, 2025
f031fb8
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
298e147
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
f472373
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
404783b
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
53ef2d4
feat: Instantiate hybrid cache for hybrid models (currently none)
gabe-l-hart May 20, 2025
e6ff93a
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 20, 2025
226955b
fix: Fix indexing into k_l for recurrent cache with filter
gabe-l-hart May 20, 2025
1fb08da
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
04fe482
fix: Remove unused kv cache methods after rebase
gabe-l-hart May 23, 2025
db9a618
fix(tests): Fix constructors in tests for signature changes after rebase
gabe-l-hart May 23, 2025
8aee2e7
feat: Add split_equal to init(...) signature
gabe-l-hart May 27, 2025
a4cc4aa
fix: Overhaul hybrid cache for refactor part3 (::init interface)
gabe-l-hart May 27, 2025
58994f6
tests(wip): Comment out broken test for now and fix other constructor…
gabe-l-hart May 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Add layer filter to recurrent cache
Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
  • Loading branch information
gabe-l-hart committed May 27, 2025
commit 6625248043bc39c1e8a89c591eeb7a4ac4f5ac7d
18 changes: 12 additions & 6 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1951,12 +1951,13 @@ class llama_kv_cache_recurrent_decode_state_t : public llama_memory_decode_state
};

llama_kv_cache_recurrent::llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;

LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
Expand Down Expand Up @@ -1998,6 +1999,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
v_l.reserve(n_layer);

for (int i = 0; i < n_layer; i++) {
if (filter && !filter(i)) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is buggy because push_back is used below, so later when we index directly into the given per-layer tensor vectors, the final layers will be out-of-bounds reads.

LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
continue;
}

const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

Expand Down
13 changes: 7 additions & 6 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,13 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);

~llama_kv_cache_recurrent() = default;

Expand Down
1 change: 1 addition & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13208,6 +13208,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
{
res = new llama_kv_cache_recurrent(
*this,
nullptr,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
Expand Down
1 change: 1 addition & 0 deletions tests/test-memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ static void test_llama_kv_cache_recurrent_constructor() {
auto model = _make_model(LLM_ARCH_MAMBA);
llama_kv_cache_recurrent cache(
/* model */ *model,
/* filter */ nullptr,
/* type_k */ GGML_TYPE_F32,
/* type_v */ GGML_TYPE_F16,
/* offload */ false,
Expand Down
0