kv-cache : simplify #13746
Conversation
Force-pushed from d23f887 to 8323e23
Force-pushed from c1434b8 to 1eec34a
This PR should not cause any performance changes and the numerical results should be mostly the same (with some small exceptions due to the new `find_slot()` logic). Would appreciate some testing and reports for regressions. Thanks.
I re-ran the ppl test from #13194 (comment). master at aa50ba4:
This PR:
Some results changed very slightly, so I'm not sure if this is expected.
Yes, I think this difference is expected for SWA models (note that Phi currently has SWA disabled, so no difference there). It's caused by the different order in which we place the data in memory, due to the new `find_slot()` logic.
Yes, that's right, I added
Edit: except for
I re-ran the test and the ppl stays the same as in my last comment. Btw, just thinking: is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?
The `./bin/llama-perplexity -hf bartowski/gemma-2-9b-it-GGUF:Q4_K_M -f ./wikitext-2-raw/wiki.test.raw -c 16384 -fa --chunks 2 --swa-full`

Maybe your reference value on
Can you clarify?
I can't run the ppl right now, but if you get the correct result, then yes, I think it could be a problem on my side.
Currently, AFAIU, the ppl test simply evaluates the text chunk by chunk, only going forward. For example, if I have 3 chunks 1-2-3, then they will be evaluated in the order 1-2-3. But what we also want to test is, for example:
So I expect the ppl to be the same as just doing 1-2-3.
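For illustration, such a mode could look roughly like the sketch below. This is not existing `llama-perplexity` code; `llama_kv_self_seq_rm()` is assumed to be the "KV remove API" referred to above, and the chunk bookkeeping is simplified.

```cpp
#include "llama.h"

#include <cstdio>
#include <vector>

// hypothetical ppl test mode: evaluate chunks 1-2-3, then remove the last chunk
// from the KV cache, evaluate it again, and check that its ppl contribution matches
static void ppl_with_kv_remove(llama_context * ctx, std::vector<llama_batch> & chunks, int n_chunk) {
    // forward pass: 1-2-3
    for (auto & chunk : chunks) {
        if (llama_decode(ctx, chunk) != 0) {
            fprintf(stderr, "decode failed\n");
            return;
        }
    }

    // drop the cells of the last chunk (positions [p0, inf)) for sequence 0
    const llama_pos p0 = (llama_pos) ((chunks.size() - 1)*n_chunk);
    llama_kv_self_seq_rm(ctx, 0, p0, -1);

    // re-evaluate the last chunk and compare its logits/ppl against the first pass
    if (llama_decode(ctx, chunks.back()) != 0) {
        fprintf(stderr, "decode failed\n");
    }
}
```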
Force-pushed from 3ef770f to 0b73da5
How does this recover from a failed call to
There are some tricky scenarios in which we could have overwritten some of the data in the cache by the time the error occurs (i.e. we processed the first few ubatches, but not all of them yet). Before (i.e. on

I think that on a compute error, the KV cache should be assumed to be in an undefined state and the application should take the necessary steps to recover (i.e. by clearing it and reprocessing the context that is currently needed). Later on, this reprocessing will become seamless, when we start storing the necessary tokens/embeddings information and add the logic for auto-reprocessing whatever is currently missing from the cache.
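On the application side, the recovery could look roughly like the sketch below (the exact name `llama_kv_self_clear()` for the clearing API and the caller-provided `context_batch` are assumptions for illustration):

```cpp
// sketch of the recovery described above: on a compute error, treat the cache
// as undefined, clear it, reprocess the needed context, then retry the batch
static bool decode_with_recovery(llama_context * ctx, llama_batch batch, llama_batch context_batch) {
    if (llama_decode(ctx, batch) == 0) {
        return true;
    }

    // the cache may be partially overwritten - start from a clean state
    llama_kv_self_clear(ctx);

    return llama_decode(ctx, context_batch) == 0 &&
           llama_decode(ctx, batch)         == 0;
}
```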
I am mostly concerned about the abort callback functionality. Errors in the backend are likely to be unrecoverable, but I am not sure whether the abort functionality makes sense if it leaves the cache in a bad state.
I admit that I had completely forgotten about the abort callback. Let me see if we can do something about this.
Force-pushed from 0b73da5 to 2252eef
Drafting for now, as I want to do some more testing and think about the abort mechanism.
```cpp
std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
    ubatches.push_back(sbatch.split_simple(n_ubatch));
```
I think I'm reading this right that it's not possible to use `split_equal` with the unified cache after this refactor. If so, I think this will cause problems with the eventual hybrid cache, where `split_equal` is required for recurrent child caches (cc @compilade since you pointed that out to me earlier).
It is possible to use `split_equal` with the unified cache - it is just not used here because there is no reason for it. The decision of how to split the batch into ubatches is implemented per cache type. If the hybrid cache requires `split_equal`, then its `init()` method should use that. This code here will never be called for the hybrid cache.
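For example, a hybrid cache could make that decision internally. This is only a hypothetical sketch of the idea (the hybrid cache does not exist in this PR), using the existing `llama_sbatch`/`llama_ubatch` types from `src/llama-batch.h`:

```cpp
#include <vector>

// hypothetical: inside a hybrid cache implementation, the cache itself decides
// how to split the batch - the caller never needs to know which split is used
static std::vector<llama_ubatch> hybrid_split(llama_sbatch & sbatch, uint32_t n_ubatch) {
    std::vector<llama_ubatch> ubatches;

    // recurrent child caches require equal splits, so use split_equal
    // instead of the split_simple used by the unified cache above
    while (sbatch.n_tokens > 0) {
        ubatches.push_back(sbatch.split_equal(n_ubatch));
    }

    return ubatches;
}
```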
```cpp
}

std::vector<llama_ubatch> ubatches;

while (sbatch.n_tokens > 0) {
```
In trying to update my hybrid implementation to match this PR, I think the logic around creating `ubatches` is going to make things tricky. I've attempted to implement it in the abstract, where the parent cache doesn't know anything about the child caches and never downcasts them unless asked to by an owner that does know about the child types. The challenge crops up here with the splitting logic, since recurrent requires `split_equal` and, with this change, unified requires `split_simple`, so the ubatches themselves are not guaranteed to be identical across child caches.

Reading what you've got here for `iswa`, it looks like you're essentially unrolling the logic in `llama_kv_cache_unified::init` and performing it pairwise across the two child caches, based on the explicit knowledge that they're both `llama_kv_cache_unified` types and therefore expose `prepare`. I could do something similar in `llama_kv_cache_hybrid::init`, but not without a lot of conditional logic and dynamic casting to call the appropriate flavor of batch splitting and `prepare`. The two possible solutions I can imagine that would avoid this would be:

1. Add an argument to the abstract definition of `init` to allow the caller to specify the split type
2. Make `prepare` virtual in `llama_kv_cache` and update the implementation in `llama_kv_cache_recurrent` to also return a vector of heads (though I'm not clear what that would mean for the recurrent cache).

I think my personal preference would be (1), which I may try to do on my branch to see how it works. The other alternative would be to scrap the idea of keeping `llama_kv_cache_hybrid` abstract and instead explicitly have it own two child caches, one `unified` and one `recurrent`. I'd love to avoid this to enable arbitrary future hybrid styles like mixes of SWA, unified, recurrent, etc. all within one model.
> - Add an argument to the abstract definition of init to allow the caller to specify the split type
> - Make prepare virtual in llama_kv_cache and update the implementation in llama_kv_cache_recurrent to also return a vector of heads (though I'm not clear what that would mean for the recurrent cache).

I've considered these options and they don't work, for the same reason: the produced decoding state needs to keep cache-specific information, like the head positions and potentially other data for other types of caches. Abstracting this at the `llama_kv_cache` level will simply move the same logic that we currently have into the `llama_memory_decode_state`.

Additionally, I think the caller should not know about sbatches. With the current design, we `init()` with a generic batch and receive a set of ubatches. The entire logic for how to produce these ubatches is contained inside the KV cache implementation.

> The other alternative would be to scrap the idea of keeping llama_kv_cache_hybrid abstract and instead explicitly have it own two child caches, one unified and one recurrent.

This should work.
Force-pushed from be635a7 to 7dc61c2
Force-pushed from 7dc61c2 to a3ebf0a
@slaren With the current proposal, we have the following invariant that should always be true: for each sequence id

This remains true even after aborting the processing of a batch. This way, after the abort, the user code can query the context about the min/max pos for each sequence, decide which tokens from the input batch weren't processed, and take the respective action.

To achieve that, with a3ebf0a we now call

The same logic is also applied when a compute error occurs, although in that case there is no guarantee that the rest of the state would be healthy after such errors. Let me know if this sounds good.
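To make the invariant concrete, user code could do something like the sketch below after an aborted `llama_decode()` call (assuming `llama_kv_self_seq_pos_max()` is the min/max pos query mentioned above):

```cpp
// sketch: after an aborted decode, everything with pos <= pos_max for a sequence
// is guaranteed to be in the cache; tokens after that need to be resubmitted
static llama_pos first_unprocessed_pos(llama_context * ctx, llama_seq_id seq_id) {
    return llama_kv_self_seq_pos_max(ctx, seq_id) + 1;
}
```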
The solution seems good.

I still think that `n` and `head` should not be part of the KV state. It leads to very confusing code: for example, the new `process` function does not receive or do any KV allocation, so how does it even work? Well, it doesn't really work by itself - it depends on the state of the object, which must be set up before calling this function. Making functions as close as possible to pure functions that do not depend on any external state greatly improves code readability.
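As a generic illustration of the point being made here (the types below are made up, not the actual llama.cpp interface):

```cpp
#include <cstdint>

struct ubatch_info { uint32_t n_tokens; };

struct cache_stateful {
    uint32_t head = 0; // hidden dependency: must be set up by an earlier call

    // hard to tell from the call site what this depends on
    void process(const ubatch_info & ub) { head += ub.n_tokens; }
};

struct cache_explicit {
    // the allocation decision is an explicit input, so the function is
    // (nearly) pure and easier to reason about in isolation
    uint32_t process(const ubatch_info & ub, uint32_t head) const { return head + ub.n_tokens; }
};
```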
I tried to address this concern in the latest commit. Still need to do the same update for the recurrent cache. Let me know if this seems better now.
Force-pushed from 1942427 to a592c13
It's definitely clearer now, but the fundamental problem is still the same,
Yes, I see it now - it might not be too difficult to do. Will try to do that.
```cpp
// TODO: improve to accept cells that are masked by the SWA
if (!cells.is_empty(head + i)) {
    const llama_pos    pos    = ubatch.pos[i];
    const llama_seq_id seq_id = ubatch.seq_id[i][0];
```
In local testing, trying to use `split_equal` with the unified portion of the hybrid cache, this line is causing problems. Specifically, I think it conflicts with the logic in `llama_sbatch::add_seq_to_ubatch` (here), where `ubatch.seq_id` is only populated with a non-null value at position `ubatch.n_seqs`. I'll keep digging to see if there's a simple solution, but wanted to flag this in case it's an easy fix on your end.
This is probably not the right solution, but this "fixed" the issue on my Granite 4 branch:

```diff
-const llama_seq_id seq_id = ubatch.seq_id[i][0];
+const llama_seq_id seq_id = ubatch.seq_id[i] == nullptr ? ubatch.seq_id[0][0] : ubatch.seq_id[i][0];
```
The intended way to traverse `ubatch.seq_id` since the splits were introduced in #8526 is by using `ubatch.n_seqs`, not `ubatch.n_tokens`. In simple splits, `ubatch.n_seqs` is equal to `ubatch.n_tokens`. Fixing this loop (and also the one in `apply_ubatch`) should make it work properly with equal splits too.
Line 141 in a592c13:

```cpp
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
```
See also the comments explaining the sizes of the arrays in `ubatch`:

Lines 10 to 24 in a592c13:

```cpp
struct llama_ubatch {
    bool equal_seqs;
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence
    uint32_t n_seqs;

    llama_token  *  token;    // [n_tokens]
    float        *  embd;     // [n_embd, n_tokens]
    llama_pos    *  pos;      // [n_tokens]
    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
    int8_t       *  output;   // [n_tokens]
};
```
Correction: using `ubatch.n_seqs` for traversal only applies to `ubatch.n_seq_id` and `ubatch.seq_id` (in case anyone here relies on the notifications and missed the edit in my previous comment).
Ok, that's really helpful. Since this loop is indexing into both `pos` and `seq_id`, which in this case have different lengths, I'm not quite following the relationship that should be used to extract the `seq_id` for a given `pos` element. I think if `ubatch.n_seqs > 1` here, that would automatically disqualify all of the other logic around reusing full cells?
> I'm not quite following the relationship that should be used to extract `seq_id` for the given `pos` element.

Usually, this is relatively simple, since the number of tokens per sequence is known in `ubatch.n_seq_tokens` (for simple splits, this is always `1`). In fact, here it could probably be possible to use

```diff
-const llama_seq_id seq_id = ubatch.seq_id[i][0];
+const llama_seq_id seq_id = ubatch.seq_id[i / ubatch.n_seq_tokens][0];
```

although there is another approach without divisions, but with nested loops, which would change the indexing of `ubatch.pos[i]` to `ubatch.pos[s * ubatch.n_seq_tokens + j]`, where `s` is in `[0, ubatch.n_seqs)` and `j` in `[0, ubatch.n_seq_tokens)`.
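A sketch of that nested-loop variant (only the indexing is the point here; the loop body stands for whatever the existing loop already does with `i`, `pos`, and `seq_id`):

```cpp
// nested-loop traversal without divisions: seq_id is indexed by sequence,
// pos by (sequence, token-within-sequence)
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
    const llama_seq_id seq_id = ubatch.seq_id[s][0];

    for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
        const uint32_t i = s*ubatch.n_seq_tokens + j;

        const llama_pos pos = ubatch.pos[i];

        // ... same per-cell checks as before, using (i, pos, seq_id) ...
    }
}
```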
I can confirm that with this fix, the model safely produces output on my Granite 4 branch (no comment on cache correctness though!)
```cpp
const llama_seq_id seq_id = ubatch.seq_id[i][0];

// can we use this cell? either:
//   - the cell is empty
//   - the cell is occupied only by one sequence:
//     - mask causally, if the sequence is the same as the one we are inserting
//     - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i);

if (!can_use && cells.seq_count(head_cur + i) == 1) {
```
> that would automatically disqualify all of the other logic around reusing full cells?

Assuming this is correct, I think this would be the correct approach?

```diff
-const llama_seq_id seq_id = ubatch.seq_id[i][0];
-
 // can we use this cell? either:
 //   - the cell is empty
 //   - the cell is occupied only by one sequence:
 //     - mask causally, if the sequence is the same as the one we are inserting
 //     - mask SWA, using current max pos for that sequence in the cache
 // always insert in the cell with minimum pos
 bool can_use = cells.is_empty(head_cur + i);

-if (!can_use && cells.seq_count(head_cur + i) == 1) {
+if (!can_use && cells.seq_count(head_cur + i) == 1 && ubatch.n_seqs == 1) {
+    const llama_seq_id seq_id = ubatch.seq_id[0][0];
```
That diff is gross, but it just adds an extra conditional to the outer check that checks whether `ubatch.n_seqs == 1`, and then always uses `ubatch.seq_id[0][0]`.
@gabe-l-hart It should not be necessary to limit this branch to when `ubatch.n_seqs` is `1`. This almost never happens for simple splits anyway, except when `n_ubatch` is `1`.

See #13746 (comment).
Race condition! Thanks thanks
cont #13706 (comment), #13194
Main goal here is to simplify the abstract interface of `struct llama_kv_cache`.

Overview
Changes to the internal `struct llama_kv_cache` abstract interface:

- `llama_kv_cache::commit()`
- `llama_kv_cache::restore()`
- `llama_kv_cache::sbatch_init()`
- `llama_kv_cache::ubatch_next()`
- `llama_kv_cache::find_slot()`
- `llama_kv_cache_guard`
This new interface changes the logic in `llama_decode()` to first make sure that we can fit the input batch into the cache, and only after that do we start to process the ubatches. This check correctly takes into account SWA masking and also makes sure that the cache will not be modified before we start the actual computation.

Note: the latter is not yet true for the recurrent cache - see comments in the code.
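Roughly, the new flow looks like the pseudocode-style sketch below; the names `dstate` and `next_ubatch()`, the exact `init()` arguments, and the return conventions are illustrative assumptions, not the exact interface:

```cpp
// sketch of the reworked llama_decode() flow: check that the whole batch fits
// (including SWA masking) before any cache cells are modified
auto dstate = kv_self->init(batch, n_ubatch, embd_pooled, logits_all);
if (!dstate) {
    // the batch does not fit - the cache has not been touched
    return 1;
}

// the cache has already split the batch into ubatches and reserved slots for
// all of them; now process each one
while (dstate->next_ubatch()) {
    // build the graph, apply the ubatch to the cache, compute ...
}
```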
Another important update in this PR is that the `find_slot()` logic for unified caches is now improved. Before, we looked for a slot (i.e. a set of contiguous cells) that is empty in order to place the ubatch in it. We now allow the slot to contain data from the same or another sequence which is masked (either by causality or by SWA):

llama.cpp/src/llama-kv-cache.cpp, lines 574 to 621 in 2252eef
This change is needed for the next PR, which will optimize the SWA cache to use just `n_swa + n_ubatch` cells, and it also has some other nice properties. For example, we no longer have to explicitly prune tokens on successful batch processing, which simplifies the logic significantly and allows us to re-enable speculative decoding for SWA models (this will also be done in the next PR).

The worst-case graph reserve logic is also refactored and simplified significantly.
There are also some changes to `llama-batch`, but these are mainly to patch things up so that we are able to push the KV cache refactor first. So no need to review the `llama-batch` changes in deep detail - the code there will be reworked soon.
in deep details - the code there will be reworked soon.With this refactor, I think the
struct llama_kv_cache
interface is getting close to finalized. I still don't like thellama_kv_cache::set_full()
mechanism and will try to find a way to avoid it. I am also hesitating if thellama_kv_cache::update(llama_context)
call is really necessary - it could probably be absorbed in thellama_kv_cache::init()
call, but then the logic there might get too overloaded, so not sure.TODO
Next PRs

- Auto-split batches in `llama_decode`, so that user code does not have to do it (llama : auto-batch #13845)
- Use just `n_swa + n_ubatch` cells for the SWA cache (llama : use n_swa + n_ubatch cells for SWA cache #13833)