In Pursuit of Fast KV-Cached Attention for Apple Neural Engine
Building a memory-friendly KV Cache with static shapes • October 10, 2024
No one wants a slow LLM. Most LLMs run on GPUs and most methods to make them fast are tailored specifically to GPUs.
LLMs can also run on Apple Neural Engine (ANE), Apple's efficient ML processor that comes in every new iPhone and Mac. Existing GPU optimizations do not easily translate to the Neural Engine, which means you end up leaving speed on the table.
Today we'll unlock some speed by adapting a popular optimization known as KV caching to the Neural Engine.
Attention Crash Course
To follow along it's helpful to understand the basic mechanics of a transformer LLM.
If this is foreign to you, keep reading. Otherwise skip ahead.
An LLM processes a sequence of tokens (word chunks) and predicts what the next token will be. The process is known as a "forward pass", "LLM call" or "prediction". Repeated forward passes build up the words, sentences, and paragraphs of the LLM's response.
[Figure: This LLM predicts that "Fa" follows "Mi".]
Attention is a series of matrix multiplications that happens during the forward pass. It helps the model do a good job at predicting the next token.
There are three matrices involved in attention: Q, K and V. They all have the same number of columns which is not particularly interesting today. The number of rows for each is determined by the number of tokens we're processing.
Q has one row for each token where we want the LLM to make a next-token prediction. K has one row for each token we want the LLM to consider during its predictions. V is the same as K.
The simplest case is we input some fixed number of tokens into the LLM: "Do re mi". This is 3 tokens so Q, K, and V will all have 3 rows.
Q's 3 rows mean the LLM will predict 3 new tokens: what comes after "Do", and "re", and "mi". We already know that "re" comes after "Do", and "mi" comes after "re", so we'll ignore those predictions but the prediction for what comes after "mi" is new so we'll keep that.
[Figure: The LLM predicts 3 tokens here, but typically we ignore all but the last.]
K and V's 3 rows mean the LLM will consider all 3 tokens when predicting what comes next. So the LLM will make a prediction for what comes after "mi" based on "Do", and "re", and "mi", and their positions relative to each other.
[Figure: You usually wouldn't let an old token like "Do" look at new ones like "re", but it is technically possible.]
A more interesting case is where the LLM takes a smaller number of tokens as input, and also some K and V matrices that were computed in a prior forward pass. Following a similar example: the input token is now "fa" and we also pass along a partial K and V, each with three rows that correspond to "Do", "re", and "mi".
Q will now have 1 row, from "fa", and the LLM will only predict a new token to follow "fa".
K and V will have not 1 but 4 rows! The 3 for "Do", "re", "mi" that were passed in plus one new row that the LLM generates for "fa". This allows the LLM to make a well-informed prediction since it can still look at all 4 rows of K and V to see what came before "fa". Importantly, it produces exactly the same results as passing all 4 tokens as inputs to the LLM.
[Figure: V is also made up of 3 re-used rows, just like K.]
This process of reusing K and V is the KV caching that we want to implement today.
Now that we know how the shape of these matrices corresponds to our input tokens, we can touch on the actual computation for attention.
First we multiply Q by K. We have to transpose K (swap its rows and columns) for the matrix multiplication to work.
Next we take the result of this multiplication (with Q's number of rows and K's number of rows as columns) and apply a function called softmax. This doesn't change the matrix's shape. This matrix is multiplied by V in the second matrix multiplication which gives us a final matrix that has the same shape as Q originally.
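To make the shapes concrete, here is a minimal NumPy sketch of the two multiplications described above. The column count (`dim = 8`) and the random values are made up for illustration, and the sketch leaves out the pieces of attention this crash course skips (score scaling and masking). It also checks the KV caching claim from earlier: a 1-row Q for "fa" against all 4 rows of K and V gives the same answer as the last row of the full 4-token computation.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax; subtracting the row max keeps exp() from overflowing.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    # q: (q_rows, dim), k and v: (k_rows, dim)
    weights = softmax(q @ k.T)  # (q_rows, k_rows), shape unchanged by softmax
    return weights @ v          # (q_rows, dim), same shape as q

rng = np.random.default_rng(0)
dim = 8  # hypothetical column count, the article never pins this down

# One row per token: "Do", "re", "mi", "fa".
q_full = rng.standard_normal((4, dim))
k_full = rng.standard_normal((4, dim))
v_full = rng.standard_normal((4, dim))

# No cache: all 4 tokens go in, 4 predictions come out.
full = attention(q_full, k_full, v_full)

# With a cache: Q only has the "fa" row. Rows 0-2 of K and V stand in for
# the cached "Do", "re", "mi" entries and row 3 is the newly computed one.
cached = attention(q_full[3:], k_full, v_full)

assert full.shape == q_full.shape
assert np.allclose(full[3], cached[0])
```

Each row of Q is processed independently of the others, which is why dropping the first three rows of Q does not change the result for "fa".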
This final matrix then proceeds on through the rest of the LLM. There is more to attention, but this should be enough to follow along below. (If not, let me know on Twitter.)
Neural Engine Constraints
It is often convenient to vary the internal workings of an LLM on the fly. The Neural Engine does not allow this.
For a model to run on the ANE it must have input, output, and intermediate matrices that all have static shapes. They cannot change between calls to the model. The computation graph of the model must also be static. This means no conditional branching even if the intermediate tensors have the same shapes.
Both of these constraints can be slightly relaxed in some circumstances but we will stick with the rigid definition for simplicity.
Static Shaped Attention
We need to pick static sizes for the matrices Q, K, and V. The number of columns is predetermined and constant, so we only need to choose the number of rows. Let's start by focusing on just Q and K, the first attention multiplication, for simplicity.
We'll give K 512 rows. This means the LLM can look back at 512 recent tokens (word chunks) at most in order to predict the next token. This is usable and we can scale it up if needed.
Picking a size for Q is more interesting. The size of Q is equal to the number of input tokens. This size determines how many tokens we can add at once to K for future predictions (typically >1) and how many new tokens we want to predict (typically 1).
These correspond to the two stages of KV-cached LLM processing. Pre-fill: when the LLM ingests your prompt and builds up a cache. Generation: when the LLM responds.
Since we are restricted to static sizes we need to pick a Q that works for both pre-fill and generation. This means that a call to an ANE LLM always processes the same number of tokens and always takes the same amount of time, regardless of processing stage.
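One consequence is that the prompt has to be cut into fixed-size pieces before it reaches the model. Here is a rough sketch of one way to do that, using the 64-row Q this post settles on. The padding token id and the choice to pad on the left are my assumptions for illustration, and a real model would also need to mask the padding out.

```python
Q_LEN = 64   # rows in Q, i.e. tokens per forward pass
PAD_ID = 0   # hypothetical padding token id

def to_static_chunks(token_ids, q_len=Q_LEN, pad_id=PAD_ID):
    """Split a prompt into chunks of exactly q_len tokens.

    Padding goes on the left so the final real token is the last row of
    the last chunk. This is one possible convention, not the only one.
    """
    pad = (-len(token_ids)) % q_len
    padded = [pad_id] * pad + list(token_ids)
    return [padded[i:i + q_len] for i in range(0, len(padded), q_len)]

# A 150-token prompt becomes 3 forward passes of 64 tokens each.
chunks = to_static_chunks(list(range(1, 151)))
assert len(chunks) == 3
assert all(len(chunk) == Q_LEN for chunk in chunks)
```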
If we pick a small size for Q, generation will be fast but pre-fill will be slow since it has to make many calls to process every word in your prompt. But a big size for Q means that generation does a lot of wasted work. We only care about one new token each time but have to multiply a big Q times K.
[Figure: Neither of these is ideal.]
The extremes are no good, so we'll split the difference and give Q 64 rows. This means we can process 64 tokens in each forward pass. It will take at most 8 calls to process a full 512 token prompt (8*64=512). These 8 calls take the same amount of time as the first 8 tokens in the generation phase which seems like a reasonable balance. 64 is also a multiple of 8, which aligns with the ANE hardware.
[Figure: The Goldilocks zone.]
If you are planning to process longer prompts and generate fewer tokens, you might consider a larger Q. Similarly if your prompts will frequently be shorter, a smaller Q will buy some speed. Either way, be sure to benchmark. Performance is often non-linear.
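The call-count side of this tradeoff is simple arithmetic. The sketch below only counts forward passes, it says nothing about how long each pass takes (a bigger Q makes every pass more expensive, which is exactly why you need to benchmark). The function name and the 100-token response length are made up for illustration.

```python
import math

def forward_passes(prompt_tokens, new_tokens, q_len):
    """Count model calls for pre-fill and generation with a q_len-row Q."""
    prefill = math.ceil(prompt_tokens / q_len)  # q_len prompt tokens per call
    # One kept prediction per call. This ignores the small detail that the
    # last pre-fill call already yields the first generated token.
    generation = new_tokens
    return prefill, generation

# A full 512-token prompt followed by a 100-token response.
counts = {q_len: forward_passes(512, 100, q_len) for q_len in (1, 64, 512)}
assert counts[64] == (8, 100)   # the 8*64=512 case from above
assert counts[1] == (512, 100)  # tiny Q: pre-fill needs one call per token
assert counts[512] == (1, 100)  # huge Q: one pre-fill call, but every
                                # generation call multiplies a 512-row Q
```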
Fast Sliding Cache
To make pre-fill work, we need to be able to process 64 new tokens at a time. This means we need to compute from scratch the entire Q and also the newest 64 tokens of K on each pass through the LLM. Lucky for us, we don't need to compute the trailing 448 (512-64) tokens of K. We can reuse ones that were previously computed. These reused tokens are the K in "KV cache" and not computing them saves a whole lot of computation and time.
Typically a KV cache is implemented statefully: one long-lived K matrix that continually appends the newest token's entries each time the LLM is called.
This is a no go on ANE, so instead we use a sliding window approach. Each pass through the LLM takes the 64 new tokens and concatenates the 448 next-newest tokens to get the full 512 length K matrix.
These 448 next-newest token K matrices are passed as inputs to the LLM. This means the LLM only needs to do a single concatenation to get the full 512 K. Memory operations, like concat, are slow and only doing 1 is close to the minimum (of zero!).
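In NumPy terms, the inside-the-model part looks roughly like this. The column count (`DIM = 128`), the float16 dtype, and the oldest-first row ordering are assumptions for the sketch, and `full_k` is a hypothetical name.

```python
import numpy as np

K_LEN, Q_LEN = 512, 64
CACHE_LEN = K_LEN - Q_LEN  # 448 reused rows
DIM = 128                  # hypothetical column count

def full_k(k_cache, k_new):
    # Both inputs have fixed shapes, so the output shape is fixed too.
    assert k_cache.shape == (CACHE_LEN, DIM)
    assert k_new.shape == (Q_LEN, DIM)
    # The single concat: cached (older) rows first, freshly computed rows last.
    return np.concatenate([k_cache, k_new], axis=0)

k_cache = np.zeros((CACHE_LEN, DIM), dtype=np.float16)  # model input
k_new = np.ones((Q_LEN, DIM), dtype=np.float16)         # computed this pass
k = full_k(k_cache, k_new)
assert k.shape == (K_LEN, DIM)
```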
[Figure: Only 1 concat to get K!]
We do need to actually slide K though, so we return the 64 new K tokens from the model and use a secondary model to combine the old 448 and new 64 into an updated 448 K input.
We have to do this in between every LLM call during pre-fill since we want all 64 tokens to go into the cache immediately. But we only have to do it once every 64 LLM calls during generation: we can reuse the same 448 K until we have a full 64 new tokens.
[Figure: Secondary model to update the cache every 64 tokens. The oldest 64 tokens are discarded.]
The secondary cache sliding model lets us use the ANE when it would otherwise be idle. This is actually faster than using a single model even during pre-fill. It's significantly faster during generation.
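The secondary model's job boils down to a concat followed by a slice. Here is a sketch in NumPy, with each row tagged by its token position so the slide is easy to see. `slide_cache` is a hypothetical name and the tiny `DIM` is only there to keep the example small.

```python
import numpy as np

Q_LEN, CACHE_LEN, DIM = 64, 448, 4  # DIM is made up and kept tiny

def slide_cache(k_cache, k_new):
    # Append the 64 new rows, then drop the 64 oldest so the result is
    # 448 rows again and can be fed straight back in as the next input.
    return np.concatenate([k_cache, k_new], axis=0)[Q_LEN:]

# Tag every row with its token position (0..511) for illustration.
positions = np.arange(CACHE_LEN + Q_LEN, dtype=np.float32)[:, None]
rows = positions * np.ones((1, DIM), dtype=np.float32)
k_cache, k_new = rows[:CACHE_LEN], rows[CACHE_LEN:]

updated = slide_cache(k_cache, k_new)
assert updated.shape == (CACHE_LEN, DIM)
assert updated[0, 0] == 64    # tokens 0-63 were discarded
assert updated[-1, 0] == 511  # the newest token is now last in the cache
```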
If we really hate the idea of using two models, we can make our single model return a pre-slid K matrix. This works but is slow. You have to construct and return many K matrices that you don't actually need and since these are big matrices it takes time just to shuffle them around inside the model (remember, concat is slow).
This leaves us with a nicely optimized sliding K cache:
- We return only the 64 new tokens of K. This is the minimum we can return since we need all 64 during pre-fill.
- We only perform one K concat during attention. Concat is slow so less is good.
- We slide our K cache when the ANE is otherwise idle. We only do this once per 64 tokens during generation.
Avoiding Memory Operations
We've minimized how often we concat, but that's not the only memory operation in attention. Transposing a matrix (flipping it diagonally) is slow too and we have to transpose K in order to multiply it with Q.
We can lean on our K cache here to minimize this memory movement. Instead of waiting until we have the full 512 length K in hand, we can transpose just the new 64 length K which is smaller and transposes faster. Only then do we concat it with the 448 length K, which comes into the model already transposed, to get our full 512 K.
To make this work, we output the transposed 64 K and update our secondary model to work with transposed Ks.
This is basically a free speed up.
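If you want to convince yourself the reordering is safe, here is a quick NumPy check. Transposing then concatenating along columns gives exactly the same matrix as concatenating along rows then transposing. `DIM` and the random data are placeholders.

```python
import numpy as np

Q_LEN, CACHE_LEN, DIM = 64, 448, 128  # DIM is a hypothetical column count
rng = np.random.default_rng(0)
q = rng.standard_normal((Q_LEN, DIM))
k_cache = rng.standard_normal((CACHE_LEN, DIM))
k_new = rng.standard_normal((Q_LEN, DIM))

# Before: build the full K, then transpose all 512 rows.
k_t_slow = np.concatenate([k_cache, k_new], axis=0).T  # (DIM, 512)

# After: the cache arrives already transposed (we transpose here only to
# set up the example), so just the 64 new rows get transposed in the model.
k_cache_t = k_cache.T                                   # (DIM, 448) input
k_new_t = k_new.T                                       # (DIM, 64)
k_t_fast = np.concatenate([k_cache_t, k_new_t], axis=1) # (DIM, 512)

assert np.array_equal(k_t_slow, k_t_fast)
scores = q @ k_t_fast  # (64, 512), ready for softmax

# The secondary model slides along columns instead of rows.
next_cache_t = np.concatenate([k_cache_t, k_new_t], axis=1)[:, Q_LEN:]
assert next_cache_t.shape == (DIM, CACHE_LEN)
```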
Up to this point we have been talking about a single K cache matrix. A real transformer model has many K caches. For instance Llama 7B has 32. This is a lot of matrices to juggle so it is common to see the KV cache stored as a single tensor that contains all of them. On ANE this requires several concatenations that would be nice to avoid. To do so we take in and return each K cache individually. The extra bookkeeping is straightforward and worth it.
The Rest of Attention
The second matrix multiplication is much less interesting than the first but it is important so let's touch on it briefly. Our goal is to multiply the result of Q*K after softmax, called W, with V.
V is the same size as K so we can use the same sliding window cache approach. We don't need to do the transpose trick with V because of how the matrix shapes work out.
For convenience we can make our secondary model process both a K and a V at the same time.
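Putting the pieces together, one attention call plus the cache slide might look like the sketch below. This is my NumPy approximation of the data flow, not the actual ANE model: the function names and `DIM` are invented, and scaling and masking are left out as in the crash course. Note how W is (64, 512) and V is (512, DIM), so V lines up for the second multiplication without any transpose.

```python
import numpy as np

Q_LEN, CACHE_LEN, DIM = 64, 448, 128  # DIM is a hypothetical column count

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_step(q, k_new, v_new, k_cache_t, v_cache):
    """Main model: every input and output has a fixed shape."""
    k_new_t = k_new.T                                   # small (DIM, 64) transpose
    k_t = np.concatenate([k_cache_t, k_new_t], axis=1)  # (DIM, 512)
    v = np.concatenate([v_cache, v_new], axis=0)        # (512, DIM)
    w = softmax(q @ k_t)                                # (64, 512)
    return w @ v, k_new_t, v_new                        # output is (64, DIM)

def slide_kv(k_cache_t, v_cache, k_new_t, v_new):
    """Secondary model: slides K (along columns) and V (along rows) together."""
    k_cache_t = np.concatenate([k_cache_t, k_new_t], axis=1)[:, Q_LEN:]
    v_cache = np.concatenate([v_cache, v_new], axis=0)[Q_LEN:]
    return k_cache_t, v_cache

rng = np.random.default_rng(0)
q, k_new, v_new = (rng.standard_normal((Q_LEN, DIM)) for _ in range(3))
k_cache_t = rng.standard_normal((DIM, CACHE_LEN))
v_cache = rng.standard_normal((CACHE_LEN, DIM))

out, k_new_t, v_out = attention_step(q, k_new, v_new, k_cache_t, v_cache)
next_k_cache_t, next_v_cache = slide_kv(k_cache_t, v_cache, k_new_t, v_out)

assert out.shape == q.shape
assert next_k_cache_t.shape == k_cache_t.shape
assert next_v_cache.shape == v_cache.shape
```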
That's all there is to it. You now have all the pieces of a static shaped KV cache attention that works on Apple Neural Engine.
[Figure: The input/output widths are to scale, but the KV cache is much much deeper.]
You should see a non-trivial speed up compared to a cache-less model that processes the same number of tokens. For example I have a Llama 2 7B model that saw approximately a 4x speedup.
A (Slow) Single-Model Approach
I also want to touch on a couple things that didn't quite pan out. I'm hopeful there are opportunities to improve and maybe these will give someone an idea.
The purist in me hates using two models to juggle the KV cache. We can avoid it.
Instead of taking the 448 length K cache as an input to the model, we can take in the 7 separate 64 length chunks that make it up. Sliding our cache then just becomes a matter of removing the oldest chunk and adding the newest one.
[Figure: Removing the old chunk and adding the new chunk is zero-cost. But the concat is slow.]
This completely eliminates the need for a second model, but it means we have to concatenate all 8 chunks to get the full K.
Sadly this concat is slow. Very slow. So this approach is dead. Unless…
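For the curious, here is roughly what the bookkeeping looks like, sketched with a Python deque standing in for the 7 separate model inputs. The names and the tiny `DIM` are mine. The slide really is free (just swapping which arrays you pass in), and the 8-way concat is the part that hurts on ANE.

```python
from collections import deque

import numpy as np

Q_LEN, NUM_CACHED_CHUNKS, DIM = 64, 7, 4  # DIM is made up and kept tiny

# Seven cached 64-row chunks, oldest first. Chunk i is filled with the
# value i so we can see which ones survive a slide.
chunks = deque(
    (np.full((Q_LEN, DIM), i, dtype=np.float32) for i in range(NUM_CACHED_CHUNKS)),
    maxlen=NUM_CACHED_CHUNKS,
)

def full_k(chunks, k_new):
    # Inside the model: 8 inputs joined into one 512-row K (the slow part).
    return np.concatenate([*chunks, k_new], axis=0)

k_new = np.full((Q_LEN, DIM), 7, dtype=np.float32)
k = full_k(chunks, k_new)
assert k.shape == (512, DIM)

# Outside the model: sliding is just bookkeeping. A deque with maxlen
# evicts the oldest chunk automatically when a new one is appended.
chunks.append(k_new)
assert len(chunks) == NUM_CACHED_CHUNKS
```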
No Concat Attention (Spoiler: Also Slow)
Turns out you don't actually need to concat the full K before multiplying by Q.
You can multiply each K chunk by Q individually, then hang onto some extra statistics that allow you to compute the rest of attention.
[Figure: You can trade the final concat for 7x additions, but that's slow too.]
This is called the lazy softmax trick (link) and its main selling point is it reduces memory pressure caused by attention. That reduction is traded for, as you might guess, speed. So this is also slow.
Additionally, even if it were fast, we would need some creative solution to avoid concatenating and summing at the very end.
So I think this too is a dead end for now.
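In case it helps someone pick this up, here is a NumPy sketch of the general idea: process K and V one chunk at a time while carrying a running max and a running sum per Q row. This is a running-statistics formulation that I am assuming for illustration. It may not match the exact variant I tried on ANE, and it skips scaling and masking. The check at the end confirms it matches ordinary attention over the concatenated K and V.

```python
import numpy as np

def chunked_attention(q, k_chunks, v_chunks):
    """Attention over K/V chunks without ever concatenating them."""
    q_rows = q.shape[0]
    running_max = np.full((q_rows, 1), -np.inf)     # largest score seen so far
    running_sum = np.zeros((q_rows, 1))             # softmax denominator so far
    acc = np.zeros((q_rows, v_chunks[0].shape[1]))  # unnormalized output so far
    for k_c, v_c in zip(k_chunks, v_chunks):
        scores = q @ k_c.T                          # (q_rows, chunk_rows)
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        # Rescale what we accumulated under the old max. On the first chunk
        # this is exp(-inf) = 0, which is fine since there is nothing yet.
        scale = np.exp(running_max - new_max)
        p = np.exp(scores - new_max)
        running_sum = running_sum * scale + p.sum(axis=-1, keepdims=True)
        acc = acc * scale + p @ v_c
        running_max = new_max
    return acc / running_sum

rng = np.random.default_rng(0)
dim = 16  # hypothetical column count
q = rng.standard_normal((64, dim))
k_chunks = [rng.standard_normal((64, dim)) for _ in range(8)]
v_chunks = [rng.standard_normal((64, dim)) for _ in range(8)]

# Reference: ordinary attention over the concatenated 512-row K and V.
k, v = np.concatenate(k_chunks), np.concatenate(v_chunks)
scores = q @ k.T
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
reference = (weights / weights.sum(axis=-1, keepdims=True)) @ v

lazy = chunked_attention(q, k_chunks, v_chunks)
assert np.allclose(lazy, reference)
```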
New Hopes
There are a couple of places we can potentially squeeze more speed from:
- The newest version of iOS/macOS has a feature to enable a stateful KV cache. If you don't care about old OSes, this might be worth a look.
- The fact that we recompute 64 tokens each time means we could add some form of multi-token prediction basically for free. There is some research into models that predict many tokens instead of one. There are also speculative decoding methods that could work.
Tweet me or open an issue on GitHub if you have other ideas or questions!