Hello,
I am currently studying attention patterns and long-context optimizations such as sparse and linear attention for transformers. Just a few weeks ago I read a paper about reusing attention weights (and scores) between layers. The results were quite decent, and the authors also provided empirical evidence that attention maps between layers (particularly successive layers) can be very similar.
Quoting from Section 3.1:

> [...] we compute the Jensen-Shannon (JS) divergence to measure how the attention weight distribution of a layer is different from another [Lin, 1991]. We choose the JS divergence because it is symmetric and bounded. For multi-head attention, we regard different heads as separate channels. We compute the JS score for each individual head and then average them for final output. Figure 2 shows that the system generates similar weights over layers.
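To make the quoted procedure concrete, here is a minimal NumPy sketch of the per-head JS computation as I understand it (the function names `js_divergence` and `layer_similarity` are my own, not the paper's code):

```python
import numpy as np

def js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence between two discrete distributions
    along the last axis (symmetric, bounded by ln 2)."""
    p, q = p + eps, q + eps
    m = 0.5 * (p + q)
    kl_pm = np.sum(p * np.log(p / m), axis=-1)
    kl_qm = np.sum(q * np.log(q / m), axis=-1)
    return 0.5 * (kl_pm + kl_qm)

def layer_similarity(attn_a, attn_b):
    """Average per-head JS divergence between the attention maps of two
    layers; attn_* has shape (heads, queries, keys), rows sum to 1."""
    per_head = js_divergence(attn_a, attn_b)   # shape (heads, queries)
    return per_head.mean()                     # average over heads (and queries)

# toy check: identical maps give divergence 0
rng = np.random.default_rng(0)
a = rng.random((2, 4, 4))
a /= a.sum(-1, keepdims=True)
print(layer_similarity(a, a))  # -> 0.0
```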
After reading it, I wondered whether this idea could be taken further and used to improve static sparse attention: reusing the pattern from one layer to compute a mask for the next layer instead of fully sharing the weights. This is how I stumbled upon your repository and paper. I like the way you implemented this, and the way you first amplify the map is quite elaborate. The way you then used the map was also surprising and very interesting to me. Have you tried or considered also using the extracted map to save computation in the following layers, in addition to your approach with extrapolation?
Additionally, I have a small remark. You wrote in your GitHub repository:
A'_ℓ,h = (Q_ℓ,h K_ℓ,h^T / √d_k) ⊙ M̃_ℓ,h
and in your paper:

> This application occurs before the softmax operation within the attention mechanism, effectively limiting computations to the unmasked connections.
To me this reads a little confusingly. At first I thought you were ONLY masking the attention logits with -inf AFTER computing them, and I was confused about how that saves so much memory. Looking at your code, however, it seems you avoid multiplying parts of Q and K altogether by using an efficient sparse multiplication, which indeed saves far more memory and compute; this was not very clear to me from the text. Also, since this uses Triton JIT rather than native CUDA, might there be even more throughput improvements possible in the future?
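For illustration, here is a minimal NumPy sketch of the distinction I mean: a naive version that computes the full score matrix and masks it afterwards, versus a block-sparse version that never multiplies the masked Q/K blocks. The function names are mine, and a real kernel (such as your Triton one) would fuse and parallelize this:

```python
import numpy as np

def dense_then_mask(q, k, v, mask):
    """Naive reading: compute the full QK^T, then fill masked logits
    with -inf. Compute and memory scale with the full score matrix."""
    logits = q @ k.T / np.sqrt(q.shape[-1])
    logits = np.where(mask, logits, -np.inf)
    w = np.exp(logits - logits.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    return w @ v

def block_sparse_attention(q, k, v, block_mask, bs):
    """Sparse reading: per query block, only the key/value blocks whose
    entry in block_mask is True are ever multiplied, so masked blocks
    cost nothing (the principle behind a block-sparse kernel)."""
    n, d = q.shape
    nb = n // bs
    out = np.zeros_like(q)
    for i in range(nb):
        qi = q[i * bs:(i + 1) * bs]
        active = [j for j in range(nb) if block_mask[i, j]]
        k_act = np.concatenate([k[j * bs:(j + 1) * bs] for j in active])
        v_act = np.concatenate([v[j * bs:(j + 1) * bs] for j in active])
        logits = qi @ k_act.T / np.sqrt(d)   # masked blocks never computed
        w = np.exp(logits - logits.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)
        out[i * bs:(i + 1) * bs] = w @ v_act
    return out
```

Both versions produce the same output when the dense mask is the block mask expanded to element level, but only the second avoids the masked work.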
Best Regards
Tony