Ever wondered why the time to first token in LLMs is high but subsequent tokens are super fast?
In this post, I dive into the details of the KV-Caching used in Mistral, a topic I initially found quite daunting. However, as I dug deeper, it became a fascinating subject, especially once it explained why the time to first token (TTFT) in these language models is often high, a pattern I had noticed across numerous API calls 🙂.
I’ll cover:
- What exactly is KV-Caching?
- The concept of the rolling cache buffer
- The prefill and decode stages
- Formulating attention masks with the help of the xFormers library
Consider our input token sequence as x1, x2, x3 … xt, and suppose we are computing the output at time step t. To find the attention output (at each transformer layer), we need the dot product of the current token’s query vector with the key vectors of the current and past tokens. After normalizing via softmax, these become the attention weights over the value vectors. Here are two key observations:
- Single Token Decoding: Decoding happens one token at a time. We are only interested in the self-attention output for the current token, so we only need its query vector, not the query vectors of other tokens.
- Precomputed Keys and Values: We need the dot products with the keys of past tokens, which were already computed when calculating the self-attention output of the token at time step t−1. The same goes for the value vectors.
The sizes of the key quantities are as follows:
- Token embedding vectors: dim
- Dimension of query, key, value heads: head_dim
- Number of query heads: n_heads
- Number of key and value heads: n_kv_heads
- Number of transformer layers: n_layers
(Note: Mistral uses grouped-query attention, where for each token 4 of its query heads attend to the same key-value pair. With n_heads=32, we have n_kv_heads=32/4=8.)
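To make these quantities concrete, here is a minimal PyTorch sketch of the grouped-query attention shapes for a single token (head_dim=128 and the random weight matrices are my own illustrative assumptions, not Mistral's actual weights):

```python
import torch

dim, head_dim = 4096, 128            # head_dim assumed for illustration
n_heads, n_kv_heads = 32, 8          # 32 / 8 = 4 query heads share each KV head

x_t = torch.randn(1, dim)            # embedding of the current token

# Standard projection shapes assumed
Wq = torch.randn(dim, n_heads * head_dim)
Wk = torch.randn(dim, n_kv_heads * head_dim)
Wv = torch.randn(dim, n_kv_heads * head_dim)

q = (x_t @ Wq).view(n_heads, head_dim)       # [32, 128]
k = (x_t @ Wk).view(n_kv_heads, head_dim)    # [8, 128]
v = (x_t @ Wv).view(n_kv_heads, head_dim)    # [8, 128]

# At attention time, each KV head is shared by 4 query heads,
# so k and v are effectively repeated 4x along the head axis.
k_rep = k.repeat_interleave(n_heads // n_kv_heads, dim=0)    # [32, 128]
v_rep = v.repeat_interleave(n_heads // n_kv_heads, dim=0)    # [32, 128]
print(q.shape, k_rep.shape, v_rep.shape)
```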
In the unoptimized implementation:
Assuming a single transformer layer, at each time step we calculate the query for the current token, and the key and value vectors for both the current and past tokens. This process involves three matrix multiplications.
a. Query Calculation (Q): for the current token only
b. Key Calculation (K): for all t tokens seen so far
c. Value Calculation (V): for all t tokens seen so far
Once we have the query, key, and value vectors, we can then compute the attention output as softmax(QKᵀ / √head_dim) · V.
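As a rough sketch of this unoptimized path (single layer, single head, random stand-in weights; my own illustration), note that K and V are recomputed for all t tokens at every decoding step:

```python
import torch

def decode_step_unoptimized(x, Wq, Wk, Wv):
    """x: [t, dim] embeddings of all tokens seen so far (single-head sketch)."""
    head_dim = Wq.shape[1]
    q_t = x[-1:] @ Wq                 # (a) query for the current token only: [1, head_dim]
    K = x @ Wk                        # (b) keys recomputed for ALL t tokens:  [t, head_dim]
    V = x @ Wv                        # (c) values recomputed for ALL t tokens: [t, head_dim]
    scores = (q_t @ K.T) / head_dim ** 0.5       # [1, t]
    return torch.softmax(scores, dim=-1) @ V     # attention output for the current token

dim, head_dim = 4096, 128
Wq, Wk, Wv = (torch.randn(dim, head_dim) for _ in range(3))
out = decode_step_unoptimized(torch.randn(10, dim), Wq, Wk, Wv)   # t = 10
```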
In the optimized implementation:
However, as mentioned in observation 2 above, the keys and values of tokens up to time step t−1 have already been computed when determining the output at time step t−1. This means we can avoid redundant computation by storing the keys and values of tokens up to time step t−1.
Note: Mistral uses a sliding window attention mechanism, so we only attend to a fixed number of previous tokens. More details on this are covered later.
What this means is that during decoding, we compute the key and value vectors only for the current token and not for the previous ones. So, operations (b) and (c) above are performed for just one token instead of t tokens. Specifically:
Key Calculation (K): for the current token only
Value Calculation (V): for the current token only
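A matching sketch of the optimized step (same assumptions as above; the plain Python lists below are a stand-in for the real rolling cache described later):

```python
import torch

def decode_step_cached(x_t, k_cache, v_cache, Wq, Wk, Wv):
    """x_t: [1, dim] embedding of the current token; caches hold keys/values of past tokens."""
    head_dim = Wq.shape[1]
    q_t = x_t @ Wq                     # query for the current token
    k_cache.append(x_t @ Wk)           # key computed for the current token only
    v_cache.append(x_t @ Wv)           # value computed for the current token only
    K = torch.cat(k_cache, dim=0)      # past keys come straight from the cache
    V = torch.cat(v_cache, dim=0)
    scores = (q_t @ K.T) / head_dim ** 0.5
    return torch.softmax(scores, dim=-1) @ V
```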
FLOPS Saved
At every step of decoding, we save 2*(t-1)*n_kv_heads*dim² FLOPS. For a sequence of length T, this translates to savings of 2*(T*(T-1)/2)*n_kv_heads*dim² FLOPS.
Considering we have assumed a single transformer layer, and knowing that Mistral uses 32 transformer layers, the savings are multiplied by 32. That is significant!
For a typical sequence length of 10,000 tokens, with n_kv_heads=8 and dim=4096, we get about 4.294e+17 FLOPS saved (10000*10000*8*4096*4096*32).
An Nvidia A100 GPU delivers roughly 312e+12 FLOPS, meaning we’d save around 23 minutes in generating this sequence of 10,000 tokens!
Note: This is a simplified calculation to give an idea of the benefits, which are indeed substantial. Actual improvements will depend on various factors such as the maximum feasible cache size, GPU memory, parallelization across multiple GPUs, etc.
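For reference, the back-of-the-envelope arithmetic above can be reproduced in a few lines of Python (same simplifying assumptions as the formula):

```python
T, n_kv_heads, dim, n_layers = 10_000, 8, 4096, 32

# Savings formula from above, multiplied by the number of layers
flops_saved = 2 * (T * (T - 1) / 2) * n_kv_heads * dim**2 * n_layers
print(f"{flops_saved:.3e}")                    # ~4.29e+17 FLOPS

a100_flops_per_sec = 312e12                    # A100 peak throughput used above
print(flops_saved / a100_flops_per_sec / 60)   # ~23 minutes
```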
Now that we understand the KV cache, I’ll discuss how we leverage it during output generation!
First, let’s establish some terminology used by Mistral:
- Sliding Window Attention (SWA): Mistral uses SWA, meaning each token attends to itself and the previous W−1 tokens, where W is the window size.
- KV Cache Size: We set our KV cache to size W. This means we can store W key vectors and W value vectors in the cache. This ensures we have the necessary context to compute the self-attention output for the next token.
- Chunk Size: We also process user prompt sequences W tokens at a time (more on this in the next section on prefill). This chunk size limits GPU memory usage. Self-attention requires K, Q, and V to be on the GPU, and these grow with the input size, making it impractical to process the entire input sequence in a single batch.
Note: Each transformer layer in Mistral has its own separate KV cache.
At first, it might seem (it did to me!) that calculating and caching only the keys and values of the last W−1 tokens in the input sequence would be sufficient to generate the first output token. However, that’s not the case! This is because Mistral has more than one transformer layer. To compute the second-layer output for our next token, we need the first-layer outputs of the last W−1 tokens, which in turn depend on the last 2W−1 input tokens (similar to the receptive field in CNNs!).
Mistral uses a window size of W = 4096 tokens.
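As a rough back-of-the-envelope check of this receptive-field effect (my own arithmetic, not a figure from the post): each layer lets information flow back up to W−1 additional positions, so across Mistral's 32 layers the theoretical attention span is roughly n_layers × W tokens:

```python
W, n_layers = 4096, 32

# Each layer extends the lookback by up to W - 1 positions
theoretical_span = n_layers * (W - 1) + 1
print(theoretical_span)   # ~131k tokens of theoretical attention span
```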
The input to these models usually starts with user-provided tokens (the famous user prompt 😊), followed by the generation of output tokens. The stage where we populate the KV cache with the keys and values from the user prompt, so we can use them when generating output tokens, is called the prefill stage. This is the key reason why the time to first token (TTFT) is often high.
To understand the workings of the prefill stage, let’s walk through an example:
Imagine we have 3 sequences in our inference batch with user prompt token lengths of 4, 1, and 3 respectively. Suppose we have a window size W=3, and we want to generate the next 5 tokens for each sequence.
Given:
- seqlens = [4,1,3]
- sliding_window_size = cache_size = 3
- chunk_size = 2 (for illustration purposes; ideally this would also be = W = 3, as mentioned before)
In the prefill stage, since we already have all the input tokens, we can process them in parallel. With a chunk_size of 2, we require two iterations, as explained below.
We have a chunk size of 2, so we’ll process the first 2 tokens from each sequence. This means the sequence lengths under consideration for this step are [2,1,2].
To batch the three sequences, one approach is to pad the shorter sequences to match the longest sequence. However, if the sequences vary drastically in length, padding results in a lot of wasted memory. Hence, this approach is generally not used.
The preferred approach is to concatenate all the sequences in the batch into a single larger sequence. We then create a suitable attention mask so that tokens attend only to those within the same sequence.
This means our input shape is: [2+1+2, dim] = [5, dim]
We compute our Q, K, and V vectors for this input by multiplying with the matrices Wq, Wk, and Wv. Assuming the number of heads = 1 for simplicity, the outputs will have the following shapes:
a. Q: [5, head_dim]
b. K: [5, head_dim]
c. V: [5, head_dim]
Next, we add rotary positional encodings to our Q and K vectors.
With these preparations, we’re ready to calculate the self-attention output!
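Here is a small sketch of this first-chunk setup with random stand-in embeddings (head count = 1 as in the text; the rotary-encoding step is noted in a comment rather than implemented):

```python
import torch

dim, head_dim = 4096, 128                     # head_dim assumed for illustration

# First chunk: 2, 1 and 2 tokens from the three prompts, concatenated into one sequence
chunks = [torch.randn(2, dim), torch.randn(1, dim), torch.randn(2, dim)]
x = torch.cat(chunks, dim=0)                  # [5, dim]

Wq, Wk, Wv = (torch.randn(dim, head_dim) for _ in range(3))
q, k, v = x @ Wq, x @ Wk, x @ Wv              # each [5, head_dim]
print(q.shape, k.shape, v.shape)
# Rotary positional encodings would be applied to q and k at this point.
```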
Step 1: Retrieve from KV-Cache and Compute Attention
Since this is the first chunk, we look at the KV cache and find it empty, with no vectors stored in it. This means there are no previous tokens to attend to, only the current token itself. Consequently, the number of key-value vectors (kv_seqlen) matches the number of query vectors (q_seqlen) in each sequence.
To handle this, we create our mask using the BlockDiagonalCausalMask from the xFormers library like so:
mask = BlockDiagonalCausalMask.from_seqlens(q_seqlen=[2,1,2], kv_seqlen=[2,1,2]).make_local_attention(window_size=3)
The attention mask can be visualized using
mask.materialize(shape=(5,5)).exp()
# The 'shape' argument is obtained as follows: the first dimension is the total number of query vectors and the second dimension is the total number of key/value vectors
and the output is
[[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 1., 1.]]
Let’s understand how we got this mask and why it makes sense. Focus on q_seqlen = [2,1,2] and kv_seqlen = [2,1,2].
The first sequence has 2 query vectors and 2 key-value (kv) vectors. The attention mask for this sequence is the 2×2 matrix in the top left:
[[1,0],
[1,1]]
The second element in the first row is 0 because this is a causal mask, and we do not want the first token to attend to the second token (which lies in the future).
The second sequence has just 1 query and 1 kv vector, represented by the central 1×1 matrix. The third sequence, similar to the first, has an identical 2×2 matrix in the bottom right.
Notice that the attention masks for the sequences are logically concatenated along the diagonal.
Setting the window size to 3 in our mask creation ensures that we only consider up to 3 tokens for attention per sequence.
This mask is applied to the output of the matrix product of Q and K.T. Thus, dot products of queries and keys from different sequences are nullified by the 0s in the combined attention matrix, preserving causality.
Note: Under the hood, xFormers doesn't even compute the dot products that would be nullified by the 0s in the attention mask.
The BlockDiagonalCausalMask in xFormers starts filling 1s from the top left of each block, which is exactly what we need for our first prefill.
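For completeness, here is a sketch of how such a block-diagonal mask is typically handed to xFormers' attention kernel (my own illustration with random tensors, following xFormers' [batch, seq, heads, head_dim] convention; it needs a GPU to run):

```python
import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

head_dim = 128
mask = BlockDiagonalCausalMask.from_seqlens(
    q_seqlen=[2, 1, 2], kv_seqlen=[2, 1, 2]
).make_local_attention(window_size=3)

# The three prompts are packed into one sequence of total length 5;
# the mask prevents attention across sequence boundaries.
q = torch.randn(1, 5, 1, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(1, 5, 1, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(1, 5, 1, head_dim, device="cuda", dtype=torch.float16)

out = memory_efficient_attention(q, k, v, attn_bias=mask)   # [1, 5, 1, head_dim]
```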
Step 2: Cache Update
Next, we update the cache with the computed keys and values. Our cache is initialized to size W×batch_size = 3×3 = 9 positions, one block of W positions per sequence, with one such buffer for keys and one for values. This is a rolling cache, meaning tokens in the first sequence will use cache positions [0, 1, 2, 0, 1, 2, …], tokens in the second sequence will use cache positions [3, 4, 5, 3, 4, 5, …], and tokens in the third sequence will use cache positions [6, 7, 8, 6, 7, 8, …].
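A minimal sketch of this rolling-buffer indexing (my own toy version; the real cache is a preallocated tensor per layer):

```python
W = 3

def cache_slot(seq_idx: int, token_pos: int) -> int:
    """Token at position token_pos of sequence seq_idx lands in slot seq_idx*W + (token_pos % W)."""
    return seq_idx * W + token_pos % W

print([cache_slot(0, p) for p in range(6)])   # [0, 1, 2, 0, 1, 2]
print([cache_slot(1, p) for p in range(6)])   # [3, 4, 5, 3, 4, 5]
print([cache_slot(2, p) for p in range(6)])   # [6, 7, 8, 6, 7, 8]
```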
So, our KV cache after the first iteration (after processing 2, 1, and 2 tokens from the respective sequences) looks like this:
We now move on to the remaining part of our sequences. The remaining tokens to process for each sequence are [2, 0, 1]. In the Mistral code, this stage is called the 'subsequent prefill' stage.
Step 1: Retrieve from KV-Cache and Compute Attention
As in iteration 1, we first look at the KV cache, but now we find entries in it. We retrieve the entries and perform an unroll/unrotate step on them to restore the correct sequence order. Why do we do this?
Remember, this is a rolling cache. If we had processed, say, 5 tokens, the keys and values for the 4th and 5th tokens would occupy the first two cache positions, followed by those of the 3rd token. After unrolling, we'd have the keys and values of the 3rd, 4th, and 5th tokens, in that order. However, in this case, since we haven't processed more than 3 tokens, the current cache order already matches the token order.
Note: The reason we need to unrotate is that during the prefill stage, we process multiple tokens per sequence and need to identify which queries should attend to which keys in the sequence. In contrast, during the decode stage (described in the following section), we process just one token of a sequence at a time. In that case, unrotation isn't necessary because this single token attends to all entries in the cache.
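Here is a tiny sketch of what unrotating a full rolling buffer means (my own illustration; `next_pos` is the slot the next token would overwrite, which also holds the oldest cached entry):

```python
import torch

def unrotate(buffer: torch.Tensor, next_pos: int) -> torch.Tensor:
    """Reorder a full rolling buffer of shape [W, head_dim] into true token order."""
    return torch.cat([buffer[next_pos:], buffer[:next_pos]], dim=0)

# With W = 3, after caching 5 tokens the buffer holds [token4, token5, token3]
# and next_pos = 2, so unrotating yields [token3, token4, token5] again.
```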
At this point, the number of query vectors for each sequence is [2, 0, 1]. The number of key vectors is calculated as the number of query vectors plus the number of valid entries in the cache:
kv_seqlen = [2+2, 0+1, 1+2] = [4, 1, 3]
We create the mask using the make_local_attention_from_bottomright() method of the BlockDiagonalMask class from xFormers:
BlockDiagonalMask.from_seqlens(
q_seqlen=[2,0,1],
kv_seqlen=[4,1,3],
).make_local_attention_from_bottomright(window_size=3)
This mask looks like:
Similar to the logic explained in iteration 1, we have three matrices concatenated diagonally, where the rows represent the number of queries and the columns represent the number of keys in each sequence.
Here, we need to use make_local_attention_from_bottomright() instead of make_local_attention(), as we want to start from the bottom right in each block.
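As before, we can sanity-check this by materializing the mask; under these seqlens (3 query rows against 4+1+3 = 8 key columns), the exponentiated mask should come out roughly as in the comment below (my own reconstruction):

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

mask = BlockDiagonalMask.from_seqlens(
    q_seqlen=[2, 0, 1],
    kv_seqlen=[4, 1, 3],
).make_local_attention_from_bottomright(window_size=3)

print(mask.materialize(shape=(3, 8)).exp())
# Expected: queries aligned to the bottom right of each block, window of 3 keys
# [[1., 1., 1., 0.,   0.,   0., 0., 0.],
#  [0., 1., 1., 1.,   0.,   0., 0., 0.],
#  [0., 0., 0., 0.,   0.,   1., 1., 1.]]
```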
Step 2: Cache Update
We store the computed keys and values into the cache in a rolling fashion, similar to iteration 1. Our updated cache then looks like this:
After the prefill stage, we move on to the decode stage, where we begin generating our output tokens one at a time.
Unlike the prefill stage, where Step 1 involves reading cache entries and computing attention and Step 2 involves updating the cache with the new entries, in the decode stage we reverse these steps. First, we update the cache with the new entries, and then we read all the entries (including the ones we just added) to compute self-attention.
This approach works neatly because decoding happens one token at a time, and we know that all entries in the cache are within our context window (of size W) and needed for self-attention.
Step 1: Cache Update
We compute the key and value vectors for the current input token and add them to the cache. The new tokens are #4, #1, and #3 for the three sequences. The updated cache looks like this:
Step 2: Retrieve from KV-Cache and Compute Attention
We now proceed to compute self-attention and the associated mask!
- We have one query for each sequence in the batch, so q_seqlen = [1, 1, 1].
- The number of keys is the number of valid entries in the cache, given by kv_seqlen = [3, 2, 3].
In the Mistral codebase, for simplicity, the attention mask shape is fixed to (W×batch_size, W×batch_size) = (9, 9).
We create our attention mask again with xFormers like so:
BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[1,1,1],
kv_padding=3,
kv_seqlen=[3,2,3]
)
This mask looks like:
We have 3 blocks of 1×3 matrices concatenated diagonally. Since we fixed our attention mask to 9×9 for simplicity, our initial attention score matrix (before applying the mask) considers dot products between all query slots (valid or not) and all keys. This is evident, for example, in sequence 2 above, where we place a 0 in the third entry of the block to invalidate that entry.
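As a final sanity check (my own sketch; note it materializes the mask at its natural 3×9 shape rather than the padded 9×9 shape mentioned above), the decode-stage mask can be inspected the same way:

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask

mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1, 1],
    kv_padding=3,          # each sequence owns a fixed block of W = 3 cache slots
    kv_seqlen=[3, 2, 3],   # number of valid entries in each block
)

print(mask.materialize(shape=(3, 9)).exp())
# Expected: one 1x3 block per sequence along the diagonal, with the unused
# cache slot of sequence 2 zeroed out
# [[1., 1., 1.,   0., 0., 0.,   0., 0., 0.],
#  [0., 0., 0.,   1., 1., 0.,   0., 0., 0.],
#  [0., 0., 0.,   0., 0., 0.,   1., 1., 1.]]
```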