KV-cache reuse, explained slowly
A walkthrough with diagrams — what gets shared between requests, what doesn't, and where it goes wrong.
KV-cache reuse is one of the highest-leverage optimizations for LLM serving. It's also one of the most confusing.
Let me walk through how it works, where it helps, and where it breaks.
The basic problem
When you run inference on a transformer:
Input: "The cat sat on the"
Tokens: [The, cat, sat, on, the]
For each position, causal attention looks at that position and ALL previous positions:
Position 0 (The): attends to [The]
Position 1 (cat): attends to [The, cat]
Position 2 (sat): attends to [The, cat, sat]
Position 3 (on): attends to [The, cat, sat, on]
Position 4 (the): attends to [The, cat, sat, on, the]
Position 5 (mat, the newly generated token): attends to [The, cat, sat, on, the, mat]
Each attention computation is expensive:
For each new token generation (without a cache):
1. Embed the new token
2. Run it through all transformer layers
3. At each layer, compute Q, K, and V for the new token
4. At each layer, compute K and V for ALL previous tokens as well
5. Multiply the new token's Q against every K, then use the resulting weights to mix every V
The wasteful part: we recompute K and V for all previous tokens every time we generate a new token.
Enter: KV-cache
Instead of recomputing, we cache it:
First token forward pass:
- Compute K, V for token 0 → store in cache
- Attention over [K₀] and [V₀]
- Output logits
Second token forward pass:
- Only compute K, V for token 1 (the new token)
- Append them to the cache, then use [K₀, K₁] and [V₀, V₁]
- Don't recompute K, V for token 0!
- Attention over all positions
- Output logits
Nth token forward pass:
- Only compute K, V for token N
- Use cached [K₀, K₁, ..., K_{N-1}, K_N] and [V₀, ..., V_N]
- Output logits
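To make the passes above concrete, here is a minimal sketch in plain Python: one attention head, one layer, and made-up 2-dimensional weights. The names `project`, `attend`, `step_with_cache`, and `step_without_cache` are illustrative assumptions, not any library's API. It checks that attending over cached K, V gives the same output as recomputing K, V for every token.

```python
import math

# Toy projection weights for a single attention head (hypothetical values).
W_Q = [[0.5, -0.2], [0.1, 0.3]]
W_K = [[0.4, 0.1], [-0.3, 0.2]]
W_V = [[1.0, 0.0], [0.5, -0.5]]

def project(x, w):
    """Multiply row vector x by matrix w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def attend(q, keys, values):
    """Scaled dot-product attention of one query over lists of keys and values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w / total * v[j] for w, v in zip(weights, values)) for j in range(dim)]

def step_without_cache(embeddings):
    """Recompute K, V for every token, then attend from the last token."""
    keys = [project(x, W_K) for x in embeddings]
    values = [project(x, W_V) for x in embeddings]
    return attend(project(embeddings[-1], W_Q), keys, values)

def step_with_cache(new_embedding, cache):
    """Compute K, V for the new token only and append them to the cache."""
    cache["K"].append(project(new_embedding, W_K))
    cache["V"].append(project(new_embedding, W_V))
    return attend(project(new_embedding, W_Q), cache["K"], cache["V"])

embeddings = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.25]]
cache = {"K": [], "V": []}
for n, x in enumerate(embeddings):
    cached_out = step_with_cache(x, cache)
    full_out = step_without_cache(embeddings[: n + 1])
    # Same numbers either way; the cache only removes repeated work.
    assert all(abs(a - b) < 1e-12 for a, b in zip(cached_out, full_out))
```

A real model keeps one such cache per layer and per head, but the append-then-attend pattern is the same.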
The savings are enormous:
Without KV-cache:
- For 100-token generation, we compute K/V ~5,000 times (100 * 50 avg context)
- Each token takes O(context_length²) in attention
With KV-cache:
- For 100-token generation, we compute K/V exactly 100 times
- Each token takes O(context_length) in attention
- 50x faster in theory, 5–10x faster in practice
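The counts above can be reproduced with a few lines of Python. This is a rough counting sketch, not a benchmark: `kv_projections` is a hypothetical helper that counts per-layer K/V projections, and it assumes an empty prompt unless `prompt_len` is given.

```python
def kv_projections(num_generated, use_cache, prompt_len=0):
    """Count per-layer K/V projections needed to generate num_generated tokens."""
    total = 0
    for step in range(1, num_generated + 1):
        context = prompt_len + step  # tokens visible at this step
        # With a cache only the newest token is projected;
        # without one, every visible token is projected again.
        total += 1 if use_cache else context
    if use_cache:
        total += prompt_len  # one-time prefill of the prompt
    return total

print(kv_projections(100, use_cache=False))  # 5050
print(kv_projections(100, use_cache=True))   # 100
```

The exact sum is 1 + 2 + ... + 100 = 5,050, which is where the "about 5,000" figure comes from.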
How it works in practice
Single request, greedy decoding
User input: "Translate to French: Hello"
Tokens: [Translate, to, French, :, Hello]
Step 1:
Feed all 5 tokens through model
- Compute K, V at each layer for all 5 positions
- Cache it: cache[0:5] = {K: [K₀, K₁, K₂, K₃, K₄], V: [V₀, V₁, V₂, V₃, V₄]}
- Get next token logits
Step 2: Sample token 6 from those logits (e.g., "Bonjour"); it sits at position 5
Feed only that token through model
- Compute K₅, V₅
- Append to cache: cache[0:6] = {K: [K₀,...,K₅], V: [V₀,...,V₅]}
- Attend over all 6 positions using cached K, V
- Get next token logits
Step 3: Sample token 7; it sits at position 6
Feed only that token through model
- Compute K₆, V₆
- Append to cache: cache[0:7]
- Attend over all 7 positions
- Get next token logits
...repeat until stop token
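The loop above has two phases: one pass over the whole prompt (prefill), then one pass per generated token (decode). Here is a sketch of that control flow. `toy_forward` is a stand-in I made up so the loop has something to run; it is not a real model, and the token ids and `STOP` value are arbitrary.

```python
STOP = 0  # hypothetical stop-token id

def toy_forward(new_tokens, cache):
    """Stand-in for a forward pass (hypothetical, not a real model).

    Appends one cache entry per new token and derives a next-token id
    from the cache contents.
    """
    for tok in new_tokens:
        cache.append(tok)  # real code would store K, V tensors per layer
    return (sum(cache) + 2 * len(cache)) % 7

def generate(prompt_tokens, max_new_tokens=16):
    cache = []
    fed_per_step = []
    # Prefill: the whole prompt goes through the model once.
    next_id = toy_forward(prompt_tokens, cache)
    fed_per_step.append(len(prompt_tokens))
    output = []
    while next_id != STOP and len(output) < max_new_tokens:
        output.append(next_id)
        # Decode: feed only the new token; earlier entries come from the cache.
        next_id = toy_forward([next_id], cache)
        fed_per_step.append(1)
    return output, len(cache), fed_per_step

output, cache_len, fed_per_step = generate([3, 1, 4, 1, 5])
```

After prefill, every step feeds exactly one token, and the cache always holds the prompt plus everything generated so far.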
Multiple requests (batching)
This is where it gets tricky.
Request A: "Hello" → Target: "Bonjour"
Request B: "Goodbye" → Target: "Au revoir"
Batch step 1 (prompt phase):
Process both prompts:
Cache_A[0:1] = K, V for "Hello"
Cache_B[0:1] = K, V for "Goodbye"
Batch step 2 (generation phase, token 1):
Process both requests:
Cache_A[0:2] = Cache_A[0:1] + new K, V for "Bonjour"
Cache_B[0:2] = Cache_B[0:1] + new K, V for "Au"
Batch step 3 (generation phase, token 2):
Cache_A: no longer grows; request A finished after "Bonjour" and its cache can be freed
Cache_B[0:3] = Cache_B[0:2] + new K, V for " revoir"
Each request has its own KV cache. Every active cache grows by one token per step, so requests of different lengths finish with caches of different sizes.
This is fine. It's the next part that breaks things.
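A sketch of the bookkeeping for the batch above, tracking only cache lengths instead of real tensors. `run_batch` and its input format are assumptions made for illustration.

```python
def run_batch(requests, max_steps=10):
    """Simulate cache growth; requests maps id -> (prompt_len, tokens_to_generate)."""
    # Prefill: each request starts with one cache entry per prompt token.
    cache_len = {rid: prompt_len for rid, (prompt_len, _) in requests.items()}
    remaining = {rid: n for rid, (_, n) in requests.items()}
    history = [dict(cache_len)]
    for _ in range(max_steps):
        active = [rid for rid, n in remaining.items() if n > 0]
        if not active:
            break
        for rid in active:
            cache_len[rid] += 1  # one new K, V entry, for this request only
            remaining[rid] -= 1
        history.append(dict(cache_len))
    return history

# Request A generates 1 token ("Bonjour"); request B generates 2 ("Au", " revoir").
history = run_batch({"A": (1, 1), "B": (1, 2)})
print(history)  # [{'A': 1, 'B': 1}, {'A': 2, 'B': 2}, {'A': 2, 'B': 3}]
```

Nothing is shared between the two caches; batching only means both are advanced in the same forward pass.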
Where KV-cache reuse fails
Problem 1: Beam search and speculative decoding
Generating with beam size 2:
Step 1: Generate 2 candidate tokens
beam_1_candidate: "Bonjour"
beam_2_candidate: "Hello"
Step 2: Both need to extend
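beam_1 needs Cache[0:6] to compute position 6
beam_2 needs Cache[0:6] to compute position 6
But they have different tokens at position 5! So they have different K₅, V₅.
You can't reuse the same cache. You need separate caches:
Cache_beam1[0:6] with "Bonjour"
Cache_beam2[0:6] with "Hello"
Solution: Split the cache per beam. A naive full copy multiplies cache memory by beam_size. The entries before the point where the beams diverge are identical, so only the diverging suffix has to be stored per beam.
Speculative decoding has the same shape of problem: K, V entries written for draft tokens that are later rejected must be dropped from the cache.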
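One way to split the cache per beam without copying everything is to share the common prefix and keep only the diverging entries per beam. The class below is a hypothetical sketch of that idea, with strings standing in for K, V tensors; it is not how any particular library implements beam search.

```python
class BeamCache:
    """Per-beam view of a KV cache: a shared prefix plus a private suffix."""

    def __init__(self, shared_prefix, suffix=None):
        self.shared_prefix = shared_prefix  # the same list object in every beam
        self.suffix = list(suffix or [])    # entries owned by this beam only

    def fork(self):
        """Create a sibling beam: the prefix is shared, the suffix is copied."""
        return BeamCache(self.shared_prefix, self.suffix)

    def append(self, kv_entry):
        self.suffix.append(kv_entry)

    def entries(self):
        return self.shared_prefix + self.suffix

prefix = [f"kv_{i}" for i in range(5)]  # K, V for the 5 prompt tokens
beam_1 = BeamCache(prefix)
beam_2 = beam_1.fork()
beam_1.append("kv_Bonjour")
beam_2.append("kv_Hello")

print(beam_1.entries()[-1])  # kv_Bonjour
print(beam_2.entries()[-1])  # kv_Hello
print(beam_1.shared_prefix is beam_2.shared_prefix)  # True
```

Each beam sees six entries, but the five prompt entries exist once in memory.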
Problem 2: Prefix sharing (multi-turn conversations)
Turn 1:
User: "What is 2+2?"
Assistant: "4"
Cache after turn 1: K, V for [What, is, 2, +, 2, ?, 4]
Turn 2:
User: "What is 3+3?"
Cache needed: K, V for [What, is, 2, +, 2, ?, 4, What, is, 3, +, 3, ?]
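Can we reuse the old cache? Only for an exact prefix match:
- Turn 2's prompt starts with the same 7 tokens at the same positions, so their K, V are unchanged
- Only the 6 new tokens (positions 7-12) need fresh K, V
- Attention positions matter: K[0] and V[0] from turn 1 encode "position 0"
- If the history is truncated, re-templated, or tokenized differently, the position indices shift and the old entries can't be reused
Solution: Prefix KV-cache (vLLM does this). The server keeps the cache after a request finishes, matches the longest identical token prefix of the next prompt, and computes K, V only for the rest. More complex but enables multi-turn memory.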
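The matching step behind a prefix cache can be sketched as a longest-common-prefix check on token lists. `reusable_prefix_len` is a hypothetical helper; real systems typically match at block granularity instead of token by token, so treat this as the idea, not vLLM's implementation.

```python
def reusable_prefix_len(cached_tokens, new_prompt):
    """Number of leading cache entries that are still valid for new_prompt."""
    n = 0
    for old, new in zip(cached_tokens, new_prompt):
        if old != new:
            break  # everything from the first mismatch onward must be recomputed
        n += 1
    return n

turn_1 = ["What", "is", "2", "+", "2", "?", "4"]
turn_2 = turn_1 + ["What", "is", "3", "+", "3", "?"]

print(reusable_prefix_len(turn_1, turn_2))      # 7
print(reusable_prefix_len(turn_1, turn_2[1:]))  # 0
```

The second call drops the first token from the history, which shifts every position by one, so no cached entry lines up anymore.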
Problem 3: Sequence length change (e.g., pipeline parallelism)
If your model is split across GPUs (pipeline parallelism):
GPU 0: Layers 1-8
GPU 1: Layers 9-16
Forward pass for token generation:
GPU 0 computes K, V for layers 1-8
GPU 1 computes K, V for layers 9-16
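Each GPU keeps the cache for its own layers only. That part works: layer 9's K, V are only ever read by layer 9.
What breaks is moving a cache somewhere it wasn't built for. What if the layer split or the attention head layout differs between workers? What if one of them quantizes its cache?
Then K, V must be recomputed on the receiving GPU, because the sender's cache format might not match its expectations.
This means no cross-GPU cache reuse in those cases, which kills your speedup for the affected requests.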
Best practices
1. Reuse caches within a request
Always. The speedup is 5–10x. Do it.
2. Be careful with batch resizing
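BAD (cache stored as one tensor indexed by batch position, batch re-packed on every change):
Batch 1: [req1, req2]
Batch 2: [req1, req2, req3] ← Batch shape changed; the cache for req1, req2 must be reallocated and copied, or recomputed!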
GOOD:
Keep requests in the same slot in the batch
Batch 1: [req1, req2, None]
Batch 2: [req1, req2, req3]
Cache stays valid for req1 and req2
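A sketch of the "same slot" rule, assuming a fixed number of slots allocated up front. `SlotTable` is a made-up helper for illustration; the point is only that a request's index never changes while it is alive.

```python
class SlotTable:
    """Fixed-size batch where each request keeps its slot until it finishes."""

    def __init__(self, num_slots):
        self.slots = [None] * num_slots

    def add(self, request_id):
        """Place a request in the first free slot and return its index."""
        index = self.slots.index(None)  # raises ValueError if the batch is full
        self.slots[index] = request_id
        return index

    def remove(self, request_id):
        """Free the slot of a finished request without moving the others."""
        self.slots[self.slots.index(request_id)] = None

table = SlotTable(3)
table.add("req1")
table.add("req2")    # slots: [req1, req2, None]
table.add("req3")    # slots: [req1, req2, req3]
table.remove("req1")
table.add("req4")    # req4 takes slot 0; req2 and req3 stay where they were
print(table.slots)   # ['req4', 'req2', 'req3']
```

Because req2 and req3 never move, whatever cache rows belong to slots 1 and 2 stay valid across batch changes.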
3. Monitor cache memory usage
Cache size ≈ batch_size × seq_length × num_heads × head_dim × 2 (K + V) × num_layers × bytes_per_element
For llama-7b in fp16 (2 bytes per element) at batch_size 32, seq_length 2048:
≈ 32 × 2048 × 32 heads × 128 dim × 2 × 32 layers × 2 bytes
≈ 32 GiB
Your GPU might have less. Solution: PagedAttention (virtual memory for KV cache).
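The estimate above is easy to turn into a helper for checking your own configuration. The sketch below assumes fp16 storage (2 bytes per element) and standard multi-head attention where every head has its own K and V; `kv_cache_bytes` is a name I made up.

```python
def kv_cache_bytes(batch_size, seq_length, num_heads, head_dim, num_layers,
                   bytes_per_element=2):
    """Estimate KV-cache size in bytes; the factor 2 covers K and V."""
    elements = batch_size * seq_length * num_heads * head_dim * 2 * num_layers
    return elements * bytes_per_element

size = kv_cache_bytes(batch_size=32, seq_length=2048, num_heads=32,
                      head_dim=128, num_layers=32)
print(size / 2**30)  # 32.0 (GiB)
```

Halving `bytes_per_element` or the batch size halves the result, which is a quick way to see what fits on a given GPU.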
4. Use vLLM if possible
vLLM handles most of this for you. Don't reinvent it.
Conclusion
KV-cache reuse is a free 5–10x win for autoregressive generation. But it breaks in subtle ways (beam search, multi-turn, distributed inference).
Understand where it works. Know when to disable it. Your latency depends on it.