kv-cache : simplify the interface (#13660) · ochafik/llama.cpp@797f2ac · GitHub
Commit 797f2ac

kv-cache : simplify the interface (ggml-org#13660)
* kv-cache : simplify the interface

  ggml-ci

* context : revert llama_batch_allocr position change

  ggml-ci
1 parent b44890d · commit 797f2ac

9 files changed: +89 additions, −153 deletions

examples/simple-chat/simple-chat.cpp
Lines changed: 2 additions & 2 deletions

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
-        const bool is_first = llama_kv_self_used_cells(ctx) == 0;
+        const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;
 
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);

@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
         while (true) {
             // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_kv_self_used_cells(ctx);
+            int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
             if (n_ctx_used + batch.n_tokens > n_ctx) {
                 printf("\033[0m\n");
                 fprintf(stderr, "context size exceeded\n");
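The new query is per-sequence, so a single-sequence app like simple-chat can use the maximum position of sequence 0 as a stand-in for the number of context slots in use. A minimal sketch of the check, assuming the llama.h API from this commit (the helper name is ours, not part of the diff):

    #include "llama.h"

    // Hypothetical helper mirroring the simple-chat check above: can another
    // batch of n_tokens be decoded without exceeding the context size?
    static bool batch_fits(llama_context * ctx, int32_t n_tokens) {
        const int32_t n_ctx      = llama_n_ctx(ctx);
        const int32_t n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
        return n_ctx_used + n_tokens <= n_ctx;
    }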

include/llama.h
Lines changed: 4 additions & 2 deletions

@@ -610,10 +610,12 @@ extern "C" {
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");
 
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
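Both functions remain callable for now; DEPRECATED only attaches a compile-time warning that points callers at llama_kv_self_seq_pos_max(). As a sketch, such a wrapper macro is typically defined along these lines (an assumption about its shape; the authoritative definition lives in llama.h, not in this diff):

    #ifdef __GNUC__
    #    define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
    #elif defined(_MSC_VER)
    #    define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
    #else
    #    define DEPRECATED(func, hint) func
    #endif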

src/llama-batch.cpp
Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,6 @@
 #include "llama-batch.h"
 
+#include <cassert>
 #include <cstring>
 #include <algorithm>
 

@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
     batch = in_batch;
     GGML_ASSERT(batch.n_tokens > 0);
     if (!batch.pos) {
+        assert(p0 >= 0);
         pos.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = i + p0;
+            pos[i] = p0 + i;
         }
         batch.pos = pos.data();
     }
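The allocator change is behavior-preserving (p0 + i vs. i + p0), but it now asserts that the fallback start position is valid. When a batch arrives without positions, tokens are numbered consecutively from p0, which decode() seeds with seq_pos_max(0) + 1 (see llama-context.cpp below). A self-contained sketch of that numbering, with illustrative values:

    #include <cstdint>
    #include <vector>

    typedef int32_t llama_pos; // matches the llama.h typedef

    int main() {
        // assume the cache already holds positions 0..9 for sequence 0,
        // so decode() would pass p0 = seq_pos_max(0) + 1 = 10
        const llama_pos p0       = 10;
        const int32_t   n_tokens = 3;

        std::vector<llama_pos> pos(n_tokens);
        for (int32_t i = 0; i < n_tokens; i++) {
            pos[i] = p0 + i; // the new tokens land at positions 10, 11, 12
        }
        return 0;
    }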

src/llama-context.cpp
Lines changed: 35 additions & 4 deletions

@@ -857,11 +857,17 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 

@@ -2292,22 +2298,47 @@ int32_t llama_apply_adapter_cvec(
 // kv cache
 //
 
+// deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_n_tokens();
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_used_cells();
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
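The deprecated shims no longer read cell metadata; they approximate both counts as the sum of per-sequence position spans. The same arithmetic can be written against the public API; a hedged sketch, assuming llama_kv_self_seq_pos_min() is exported alongside llama_kv_self_seq_pos_max() and returns -1 for an empty sequence:

    #include "llama.h"

    // Sketch: token count as the sum of per-sequence position spans, mirroring
    // the fallback above. Note it counts positions, not cells, so gaps punched
    // by llama_kv_self_seq_rm() would still be included in the total.
    static int32_t kv_tokens_approx(llama_context * ctx, int32_t n_seq_max) {
        int32_t res = 0;
        for (llama_seq_id s = 0; s < n_seq_max; s++) {
            const llama_pos p0 = llama_kv_self_seq_pos_min(ctx, s);
            const llama_pos p1 = llama_kv_self_seq_pos_max(ctx, s);
            if (p0 >= 0) {
                res += (p1 - p0) + 1;
            }
        }
        return res;
    }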

src/llama-kv-cache.cpp
Lines changed: 24 additions & 87 deletions
@@ -30,13 +30,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
-        uint32_t padding,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
         uint32_t n_swa,
-        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
-    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
+        llama_swa_type swa_type) :
+    model(model), hparams(model.hparams), v_trans(v_trans),
+    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
-    this->type_k = type_k;
-    this->type_v = type_v;
+    GGML_ASSERT(kv_size % n_pad == 0);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;

@@ -129,8 +130,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const size_t memory_size_k = size_k_bytes();
     const size_t memory_size_v = size_v_bytes();
 
-    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
             ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
             ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
 }
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 void llama_kv_cache_unified::defrag_sched(float thold) {
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {

@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
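Only the field name changes here (padding becomes n_pad); the heuristic still attends a window equal to the highest used cell rounded up to the padding granularity, clamped to the cache size. A worked sketch of the arithmetic, with pad_to standing in for ggml's GGML_PAD macro (assumed to round up to a multiple of n) and an illustrative cell_max value:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // stand-in for GGML_PAD: round x up to the next multiple of n (assumption)
    static uint32_t pad_to(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const uint32_t size     = 4096; // total cells in the cache
        const uint32_t n_pad    = 32;   // padding granularity
        const uint32_t cell_max = 70;   // illustrative result of cell_max()

        // 70 rounds up to 96, so only the first 96 cells are attended
        const uint32_t n = std::min(size, std::max(n_pad, pad_to(cell_max, n_pad)));
        assert(n == 96);
        return 0;
    }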
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     return true;
 }
 
-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
-}
-
 bool llama_kv_cache_unified::get_can_shift() const {
     return true;
 }

@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     }
 }
 
-llama_pos llama_kv_cache_unified::get_pos_max() const {
-    llama_pos pos_max = -1;
-
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
 

@@ -1501,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         llama_seq_id seq_id;
         io.read_to(&seq_id, sizeof(seq_id));
 
-        // TODO: llama_kv_cache_unified should have a notion of max sequences
-        //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-        if (seq_id < 0) {
-            //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-            LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+        if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+            LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
             return false;
         }
 

@@ -1655,17 +1629,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         ggml_type type_v,
         bool v_trans,
         bool offload,
-        uint32_t kv_size,
         bool swa_full,
+        uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_batch,
-        uint32_t padding) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
     if (swa_full) {

@@ -1680,14 +1654,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, padding,
+            v_trans, offload, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, padding,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
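With swa_full disabled, the SWA cache holds just the sliding window for every sequence plus one batch of headroom, rounded up to n_pad and capped at the base size. A worked sketch with illustrative numbers (pad_to again stands in for GGML_PAD, assumed to round up to a multiple of n):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    static uint32_t pad_to(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n; // assumed GGML_PAD behavior
    }

    int main() {
        const uint32_t n_swa     = 1024; // sliding window width (illustrative)
        const uint32_t n_seq_max = 2;    // configured parallel sequences
        const uint32_t n_batch   = 512;  // maximum batch size
        const uint32_t n_pad     = 32;   // cache padding granularity
        const uint32_t size_base = 8192; // full-size base cache

        // 1024*2 + 512 = 2560, already a multiple of 32 and below size_base,
        // so the SWA cache gets 2560 cells instead of the full 8192
        const uint32_t size_swa = std::min(size_base, pad_to(n_swa*n_seq_max + n_batch, n_pad));
        assert(size_swa == 2560);
        return 0;
    }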

@@ -1810,18 +1784,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
     return res;
 }
 
-int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
-    return kv_base->get_n_tokens();
-}
-
-int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
-    return kv_base->get_used_cells();
-}
-
-llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
-    return kv_base->get_pos_max();
-}
-
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }

@@ -1853,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         ggml_type type_k,
         ggml_type type_v,
         bool offload,
-        uint32_t kv_size) : hparams(model.hparams) {
+        uint32_t kv_size,
+        uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
 
     head = 0;
     size = kv_size;
     used = 0;
 
-    this->type_k = type_k;
-    this->type_v = type_v;
-
     cells.clear();
     cells.resize(kv_size);
 

@@ -2203,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    GGML_UNUSED(ctx);
     return false;
 }
 

@@ -2265,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
         if (seq_id < 0 || (uint32_t) seq_id >= size) {
             // too big seq_id
             // TODO: would it be possible to resize the cache instead?
-            LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+            LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
             return false;
         }
         if (j > 0) {

@@ -2408,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }
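The recurrent cache now carries its own n_seq_max, and the slot-lookup error reports that configured limit instead of the cell count. On the API side the limit comes from the context parameters; a minimal sketch (the value 4 is illustrative, and its mapping to the CLI's --parallel flag follows the hint in the error message):

    #include "llama.h"

    int main() {
        llama_context_params cparams = llama_context_default_params();
        // accept sequence ids 0..3; roughly what "--parallel 4" configures
        // in the bundled tools, per the hint in the error message above
        cparams.n_seq_max = 4;
        // ... creating a context with these params would then size the
        //     KV cache for up to 4 sequences ...
        return 0;
    }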
