21 Inference Optimization
KV Caching, Speculative Decoding, and Serving at Scale
Training is batch-oriented: process thousands of examples together. Inference is latency-oriented: one user waits for one response.
The optimizations are completely different.
KV caching exploits the fact that recomputing K/V for past tokens produces identical results—the computation is redundant. By caching these values (a locality optimization—keep frequently accessed data in fast memory), we eliminate O(n) redundant recomputations per generated token.
Multi-Query Attention exploits redundancy across attention heads: different query heads can share the same K/V representations with minimal quality loss.
21.1 The Inference Workload
Training: maximize throughput (samples/second)
Inference: minimize latency (time per sample) while maximizing throughput (requests/second)
Key difference: Training gets large, predictable batches. Inference gets sporadic, variable-length requests.
This chapter focuses on autoregressive inference—generating sequences one token at a time—which dominates LLM serving.
21.2 Autoregressive Generation
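import torch

def generate(model, prompt, max_length=100):
    # tokenize, sample, and EOS are assumed to be defined elsewhere;
    # sample is assumed to return a 0-D tensor holding one token id.
    tokens = tokenize(prompt)  # 1-D tensor of token ids
    for _ in range(max_length):
        # Forward pass through the entire model over the full sequence
        logits = model(tokens)  # Shape: [seq_len, vocab_size]
        # Sample the next token from the last position's logits
        next_token = sample(logits[-1])
        tokens = torch.cat([tokens, next_token.reshape(1)])
        if next_token.item() == EOS:
            break
    return tokens
The problem: Each iteration processes the entire sequence, including tokens we’ve already seen.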
For a 100-token generation from a 10-token prompt:
- Step 1: Process 10 prompt tokens
- Step 2: Process 11 tokens (10 prompt + 1 generated)
- Step 3: Process 12 tokens
- …
- Step 100: Process 109 tokens (10 prompt + 99 generated)
Total tokens processed: 10 + 11 + 12 + … + 109 = 5,950 tokens
But we only generated 100 new tokens. We reprocessed prompt tokens 100 times!
21.3 KV Caching: The Foundation
Attention at position \(i\) depends on tokens \([0, ..., i]\):
\[\text{Attention}(Q_i, K_{0:i}, V_{0:i})\]
Key observation: \(K_{0:i-1}\) and \(V_{0:i-1}\) don’t change when we add token \(i\).
Idea: Cache the key/value vectors for all previous tokens.
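To make the key observation concrete, here is a minimal single-head, single-layer sketch in plain Python. The projection matrices, the embedding size, and the helper names (`project`, `attend`) are all illustrative assumptions, not part of any real model. It computes attention outputs twice, once recomputing K and V for every past token and once appending to a cache, and checks that the results agree.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def project(x, w):
    """Multiply vector x (length d_in) by matrix w (d_in rows, d_out columns)."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]

def attend(q, keys, values):
    """Single-head attention output for one query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def random_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d = 4                                   # toy embedding size (assumption)
w_q, w_k, w_v = random_matrix(d, d), random_matrix(d, d), random_matrix(d, d)
xs = random_matrix(6, d)                # 6 toy token embeddings

# Without a cache: recompute K and V for every past token at every step.
full = []
for i in range(len(xs)):
    keys = [project(x, w_k) for x in xs[: i + 1]]
    values = [project(x, w_v) for x in xs[: i + 1]]
    full.append(attend(project(xs[i], w_q), keys, values))

# With a cache: project only the new token and append to the stored K/V.
k_cache, v_cache, cached = [], [], []
for x in xs:
    k_cache.append(project(x, w_k))
    v_cache.append(project(x, w_v))
    cached.append(attend(project(x, w_q), k_cache, v_cache))

max_diff = max(abs(a - b) for fa, ca in zip(full, cached) for a, b in zip(fa, ca))
print(f"largest difference between the two paths: {max_diff}")
```

The cached path performs one K and one V projection per token, while the uncached path performs one per token per step. In a multi-layer model the same argument applies layer by layer, because causal masking keeps each past position's layer inputs unchanged.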
import torch

def generate_with_kv_cache(model, prompt, max_length=100):
    # Assumes the same tokenize, sample, and EOS helpers as before, and a model
    # that returns (logits, kv_cache) when caching is enabled.
    tokens = tokenize(prompt)
    kv_cache = None  # Will store (keys, values) for each layer
    for _ in range(max_length):
        if kv_cache is None:
            # First iteration: process the full prompt and build the cache
            logits, kv_cache = model(tokens, use_cache=True)
        else:
            # Subsequent iterations: process only the newest token
            logits, kv_cache = model(tokens[-1:], kv_cache=kv_cache)
        next_token = sample(logits[-1])
        tokens = torch.cat([tokens, next_token.reshape(1)])
        if next_token.item() == EOS:
            break
    return tokens
21.3.1 Memory Cost
For a model with:
- \(L\) layers
- \(n\) tokens in cache
- \(d\) hidden dimension
- \(h\) attention heads (so \(d = h \times \text{head\_dim}\))
KV cache size (standard multi-head attention): \(2 \times L \times n \times d \times \text{sizeof(dtype)}\)
With Grouped-Query Attention (GQA), K and V are shared across groups of heads, reducing the cache:
\[\text{KV cache} = 2 \times L \times n \times (\text{num\_kv\_heads} \times \text{head\_dim}) \times \text{sizeof(dtype)}\]
Example: LLaMA-2 70B (uses GQA)
- 80 layers
- 8 KV heads (not 64 query heads) with head_dim = 128
- 2048 token context
- FP16 (2 bytes)
\[\text{KV cache} = 2 \times 80 \times 2048 \times (8 \times 128) \times 2 = 0.67\text{GB per sequence}\]
Note: without GQA (full MHA with 64 KV heads), this would be about \(5.4\text{GB}\), an 8× difference. GQA is a major reason modern LLMs can serve many concurrent sequences.
On an 80GB A100, after model weights (~35GB in INT4), the remaining memory can support dozens of concurrent sequences.
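The arithmetic above can be checked with a short helper. This is a sketch under stated assumptions: the function name `kv_cache_bytes` is hypothetical, and the 45GB budget is simply 80GB minus the ~35GB of INT4 weights, ignoring activations and allocator overhead.

```python
def kv_cache_bytes(num_layers, num_tokens, num_kv_heads, head_dim, bytes_per_element=2):
    """Bytes of KV cache for one sequence: keys and values for every layer and token."""
    return 2 * num_layers * num_tokens * num_kv_heads * head_dim * bytes_per_element

# LLaMA-2 70B figures from the example: 80 layers, 2048 tokens, head_dim 128, FP16.
gqa_bytes = kv_cache_bytes(80, 2048, num_kv_heads=8, head_dim=128)
mha_bytes = kv_cache_bytes(80, 2048, num_kv_heads=64, head_dim=128)

# Assumed budget: 80GB total minus ~35GB of INT4 weights (overheads ignored).
free_bytes = 45 * 10**9
max_sequences = free_bytes // gqa_bytes

print(f"GQA: {gqa_bytes / 1e9:.2f} GB per sequence")
print(f"MHA: {mha_bytes / 1e9:.2f} GB per sequence")
print(f"Full-length sequences that fit in the assumed budget: {max_sequences}")
```

Real deployments fit fewer sequences than this upper bound, since activations, fragmentation, and framework overhead also consume memory.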
21.3.2 Compute Savings
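Counting token positions pushed through the model to generate n tokens:
- Without KV cache: O(n²) (quadratic)
- With KV cache: O(n) (linear)
Attention over the cached keys still costs O(n) per step, but the projections and MLP layers run only once per token.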
Example: Generate 100 tokens with 10-token prompt
- Without: 5,950 tokens processed
- With: 109 tokens processed (10 prompt tokens in the first step + 1 token in each of the 99 later steps)
- Speedup: about 55×
In practice: 10-50× speedup for typical generation lengths.
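The token counts can be reproduced by counting exactly what the two generation loops feed to the model at each step. The function names below are hypothetical; the sketch assumes step t (starting at 1) sees the prompt plus t − 1 generated tokens.

```python
def tokens_processed_without_cache(prompt_len, new_tokens):
    """Every step re-processes the prompt plus all tokens generated so far."""
    return sum(prompt_len + step for step in range(new_tokens))

def tokens_processed_with_cache(prompt_len, new_tokens):
    """The first step processes the prompt; each later step processes one token."""
    return prompt_len + (new_tokens - 1)

without_cache = tokens_processed_without_cache(prompt_len=10, new_tokens=100)
with_cache = tokens_processed_with_cache(prompt_len=10, new_tokens=100)
ratio = without_cache / with_cache

print(without_cache, with_cache, round(ratio, 1))
```

The ratio grows with generation length, which is why the measured benefit depends so strongly on how many tokens are generated.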
21.4 Multi-Query Attention
Standard attention: each head has separate Q, K, V projections.
Heads: 32
Hidden: 8192
Per-head: 256
Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 32 × 256] = 8,192 × 8,192 params
V: [8192 → 32 × 256] = 8,192 × 8,192 params
Multi-Query Attention (MQA): Share K and V across all heads.
Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 256] = 8,192 × 256 params ← Shared!
V: [8192 → 256] = 8,192 × 256 params ← Shared!
KV cache impact:
- Standard: 2 × L × n × d
- MQA: 2 × L × n × (d / h)
For 32 heads: 32× smaller KV cache
On an A100 (80GB), the number of sequences whose KV cache fits grows accordingly, but overall concurrency is still limited by weights, activations, and other overheads, so the real gain in concurrent sequences is smaller than 32×.
Tradeoff: Slight quality degradation; reported perplexity impact is model- and data-dependent (often a few percent). Worthwhile for inference.
Grouped-Query Attention (GQA): A middle ground that shares K/V across groups of heads (e.g., 4-8 heads per group). Used in LLaMA-2 70B.
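One way to see how MHA, GQA, and MQA relate is to write down which K/V head each query head reads. The sketch below assumes consecutive query heads are grouped together, which is a common convention rather than a requirement, and both function names are hypothetical. It uses the 32-head, 256-dim configuration from this section.

```python
def kv_head_for(query_head, num_query_heads, num_kv_heads):
    """Index of the K/V head that a given query head attends with."""
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size

def kv_elements_per_token(num_kv_heads, head_dim):
    """K and V elements cached per token per layer."""
    return 2 * num_kv_heads * head_dim

num_query_heads, head_dim = 32, 256
configs = {"MHA": 32, "GQA": 8, "MQA": 1}   # name -> number of K/V heads

for name, num_kv_heads in configs.items():
    mapping = [kv_head_for(q, num_query_heads, num_kv_heads) for q in range(num_query_heads)]
    size = kv_elements_per_token(num_kv_heads, head_dim)
    print(f"{name}: {size} cached elements per token per layer, "
          f"query heads 0-7 read K/V heads {mapping[:8]}")
```

MHA and MQA fall out as the two extremes: one K/V head per query head, or a single K/V head for all of them.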
21.5 PagedAttention: Efficient Memory Management
Problem: KV cache grows dynamically. Pre-allocating max length wastes memory.
Sequence A: "Hello, how are" [uses 20 slots / 2048 allocated]
Sequence B: "The quick brown..." [uses 100 slots / 2048 allocated]
Wasted memory: ~97% of allocation (120 of 4,096 slots used)!
PagedAttention (from vLLM): Borrow OS virtual memory concepts.
- Divide KV cache into fixed-size “pages” (e.g., 16 tokens)
- Allocate pages on-demand as sequence grows
- Share pages across sequences (for shared prefixes)
Prompt: "Translate to French: "
└─ page 0 (cached, shared)
Request A: "Translate to French: Hello"
└─ page 0 (shared) → page 1 (unique)
Request B: "Translate to French: Goodbye"
└─ page 0 (shared) → page 2 (unique)
Benefit: more concurrent requests fit in memory. Kwon et al. (2023) report 2-4× higher throughput at similar latency compared with earlier serving systems.
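The bookkeeping behind these three ideas can be sketched with a toy page table. This is not vLLM's implementation: the class `PagedKVAllocator` and its methods are hypothetical, and it tracks only page ownership, not the K/V tensors themselves. Shared pages are reference-counted, and a shared, partially filled page is replaced by a private one on the next write (copy-on-write).

```python
class PagedKVAllocator:
    """Toy page table mapping each sequence to a list of physical page ids."""

    def __init__(self, num_pages, page_size=16):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.ref_counts = {}    # physical page id -> number of sequences using it
        self.page_tables = {}   # sequence id -> list of physical page ids
        self.lengths = {}       # sequence id -> number of tokens stored

    def _allocate(self):
        if not self.free_pages:
            raise MemoryError("no free KV cache pages")
        page = self.free_pages.pop()
        self.ref_counts[page] = 1
        return page

    def add_sequence(self, seq_id):
        self.page_tables[seq_id] = []
        self.lengths[seq_id] = 0

    def append_token(self, seq_id):
        """Reserve room for one more token, allocating a page only when needed."""
        table = self.page_tables[seq_id]
        if self.lengths[seq_id] % self.page_size == 0:
            table.append(self._allocate())           # last page is full (or none yet)
        elif self.ref_counts[table[-1]] > 1:
            # Copy-on-write: a real system would also copy the page contents here.
            self.ref_counts[table[-1]] -= 1
            table[-1] = self._allocate()
        self.lengths[seq_id] += 1

    def fork(self, parent_id, child_id):
        """Create a sequence that shares all of the parent's pages (shared prefix)."""
        self.page_tables[child_id] = list(self.page_tables[parent_id])
        self.lengths[child_id] = self.lengths[parent_id]
        for page in self.page_tables[child_id]:
            self.ref_counts[page] += 1

    def free(self, seq_id):
        """Release a sequence; pages return to the pool when no one references them."""
        for page in self.page_tables.pop(seq_id):
            self.ref_counts[page] -= 1
            if self.ref_counts[page] == 0:
                del self.ref_counts[page]
                self.free_pages.append(page)
        del self.lengths[seq_id]

# A 16-token shared prefix fills exactly one page; two requests then diverge.
allocator = PagedKVAllocator(num_pages=8, page_size=16)
allocator.add_sequence("prefix")
for _ in range(16):
    allocator.append_token("prefix")
allocator.fork("prefix", "A")
allocator.fork("prefix", "B")
for _ in range(3):
    allocator.append_token("A")
    allocator.append_token("B")

print(allocator.page_tables, "free pages:", len(allocator.free_pages))
```

Requests A and B end up with the same first page and different second pages, mirroring the diagram above, and only three of the eight pages are in use.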
21.6 Continuous Batching
Traditional serving: Wait for batch to fill, process all together, wait again.
Time ─→
[Wait for 8 requests] [Process batch] [Wait for 8 requests] ...
↑ High latency ↑ Good throughput
Continuous batching: Add/remove requests from batch dynamically.
Batch: [A____B____C_______D__]
↑ finishes, remove from batch
New request E arrives:
Batch: [B____C_______D__E____]
Benefit: No waiting for batch to fill. Average latency reduced by 2-10×.
Implementation: vLLM, TensorRT-LLM, Text Generation Inference (TGI) all use this.
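A small simulation illustrates why refilling slots helps. This is a deliberately simplified model, not how any of the listed systems schedule work: it assumes all requests are already queued, every decode step costs the same, and each request needs a known number of decode steps. Both function names are hypothetical.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Decode steps when each batch runs until its longest request finishes."""
    steps = 0
    for start in range(0, len(lengths), batch_size):
        steps += max(lengths[start:start + batch_size])
    return steps

def continuous_batching_steps(lengths, batch_size):
    """Decode steps when a finished request's slot is refilled immediately."""
    waiting = deque(lengths)
    running = []                      # remaining decode steps per active request
    steps = 0
    while waiting or running:
        # Fill any free slots before the next decode step.
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())
        steps += 1
        # Each active request emits one token; finished requests leave the batch.
        running = [remaining - 1 for remaining in running if remaining > 1]
    return steps

# One long request followed by seven short ones, two batch slots.
lengths = [10, 2, 2, 2, 2, 2, 2, 2]
static_steps = static_batching_steps(lengths, batch_size=2)
continuous_steps = continuous_batching_steps(lengths, batch_size=2)
print(static_steps, continuous_steps)
```

In the static case the short request sharing a batch with the long one leaves its slot idle for most of that batch. In the continuous case the short requests cycle through the second slot while the long one keeps running.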
21.7 Speculative Decoding
Autoregressive generation is sequential—can’t parallelize across tokens.
Or can we?
Idea: Use a fast “draft” model to generate multiple tokens. Use the slow “target” model to verify in parallel.
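import torch

def speculative_decode(target_model, draft_model, prompt, k=4, max_new_tokens=100):
    # Assumes prompt is a 1-D tensor of token ids and each model maps such a
    # tensor to logits of shape [seq_len, vocab_size].
    tokens = prompt
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        n = len(tokens)
        # Draft model generates k tokens one at a time (fast)
        draft_tokens = []
        draft_dists = []
        for _ in range(k):
            q = torch.softmax(draft_model(tokens)[-1], dim=-1)
            t = torch.multinomial(q, 1)  # shape [1]
            draft_tokens.append(t.item())
            draft_dists.append(q)
            tokens = torch.cat([tokens, t])
        # Target model scores all k draft tokens at once (parallel!)
        target_probs = torch.softmax(target_model(tokens), dim=-1)  # Single forward pass
        # Row n - 1 + i holds the target's distribution for draft token i
        for i in range(k):
            p, q, t = target_probs[n - 1 + i], draft_dists[i], draft_tokens[i]
            if torch.rand(()) < torch.clamp(p[t] / q[t], max=1.0):
                continue  # Accepted: keep draft token i
            # Rejected: drop draft tokens i..k-1 and resample from the
            # normalized residual max(0, target - draft)
            residual = torch.clamp(p - q, min=0.0)
            new_token = torch.multinomial(residual / residual.sum(), 1)
            tokens = torch.cat([tokens[: n + i], new_token])
            break
        else:
            # All k accepted: the last target row gives one extra token for free
            bonus = torch.multinomial(target_probs[-1], 1)
            tokens = torch.cat([tokens, bonus])
    return tokens[:target_len]
Key insight: Target model processes all k draft tokens in one forward pass (parallel). If draft is good, we accept multiple tokens per iteration.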
Speedup: 2-3× when draft model is 10× faster and agrees with target 60-80% of the time.
Example draft models:
- Smaller version of the same model (LLaMA-70B target, LLaMA-7B draft)
- Quantized version
- Early-exit from target model
Tradeoff: Requires additional compute for the draft model. Since the draft model is typically 10-20× smaller than the target, total compute overhead is ~1.1-1.2×, not 2×. The technique is most beneficial when the target model is memory-bound (decode phase), where the draft model’s overhead is negligible compared to the time saved by accepting multiple tokens per target forward pass.
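The speedup figures can be sanity-checked with the expected-value analysis from Leviathan et al. (2023). The sketch below assumes each draft token is accepted independently with the same probability, which is a simplification, and that one draft forward pass costs a fixed fraction of one target forward pass. The function and parameter names are hypothetical.

```python
def expected_tokens_per_iteration(acceptance_rate, k):
    """Expected tokens produced per target forward pass with k draft tokens."""
    a = acceptance_rate
    if a == 1.0:
        return float(k + 1)           # every draft token plus the bonus token
    return (1 - a ** (k + 1)) / (1 - a)

def expected_speedup(acceptance_rate, k, draft_cost_ratio):
    """Expected wall-clock speedup over plain decoding with the target model.

    draft_cost_ratio is the time of one draft forward pass divided by the time
    of one target forward pass (0.1 means the draft is 10x faster).
    """
    cost_per_iteration = k * draft_cost_ratio + 1   # k draft passes + 1 target pass
    return expected_tokens_per_iteration(acceptance_rate, k) / cost_per_iteration

for rate in (0.6, 0.7, 0.8):
    print(f"acceptance {rate:.0%}: "
          f"{expected_tokens_per_iteration(rate, 4):.2f} tokens/iteration, "
          f"{expected_speedup(rate, 4, 0.1):.2f}x speedup")
```

Under these assumptions, a draft that is 10× faster with k = 4 gives roughly a 2× speedup at 70% acceptance and somewhat more at 80%, in line with the range quoted above.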
21.8 Quantization for Inference
Training typically needs 16-bit or 32-bit floating point (FP16/BF16/FP32) to preserve gradient precision. Inference tolerates lower precision.
INT8 quantization:
- 2× memory reduction relative to FP16
- 2-3× speedup (on tensor cores)
- Minimal quality loss with calibration
INT4 quantization (GPTQ, AWQ):
- 4× memory reduction relative to FP16
- Enables larger models on same hardware
- 1-2% perplexity degradation
Example: LLaMA-2 70B
- FP16: 140GB (doesn’t fit in 80GB A100)
- INT4: 35GB (fits easily, with room for KV cache)
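The weight-memory figures follow directly from parameter count times bits per weight. A minimal sketch, assuming 70 billion parameters and ignoring the small extra storage that quantization schemes need for scales and zero points (the function name is hypothetical):

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Approximate weight storage in GB (1 GB = 1e9 bytes), ignoring metadata."""
    return num_params * bits_per_weight / 8 / 1e9

num_params = 70e9
for name, bits in (("FP16", 16), ("INT8", 8), ("INT4", 4)):
    size = weight_memory_gb(num_params, bits)
    fits = "fits" if size < 80 else "does not fit"
    print(f"{name}: {size:.0f} GB ({fits} in one 80GB A100)")
```

Whatever is left after the weights is what the KV cache and activations have to share.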
See Quantization for full quantization details.
21.9 The Batch Size vs. Latency Tradeoff
Larger batches → higher throughput, higher latency per request.
Batch size 1: 10 ms latency, 100 req/sec throughput
Batch size 8: 30 ms latency, 266 req/sec throughput
Batch size 32: 80 ms latency, 400 req/sec throughput
Choosing batch size:
- Interactive applications (chatbots): small batch (1-4)
- Offline processing: large batch (32+)
- APIs with SLA: tune to meet latency targets
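The throughput column above is just batch size divided by batch latency, and tuning for an SLA amounts to picking the largest batch whose latency still meets the target. A minimal sketch using the three illustrative measurements from this section (the function names and the 50 ms target are assumptions):

```python
# batch size -> latency in milliseconds (illustrative figures from the text)
measurements = {1: 10.0, 8: 30.0, 32: 80.0}

def throughput_rps(batch_size, latency_ms):
    """Requests completed per second when one batch takes latency_ms."""
    return batch_size / (latency_ms / 1000.0)

def largest_batch_within_sla(measurements, sla_ms):
    """Largest measured batch size whose latency meets the SLA, or None."""
    eligible = [batch for batch, latency in measurements.items() if latency <= sla_ms]
    return max(eligible) if eligible else None

for batch_size, latency_ms in measurements.items():
    print(f"batch {batch_size}: {throughput_rps(batch_size, latency_ms):.1f} req/sec")

print("largest batch under a 50 ms SLA:", largest_batch_within_sla(measurements, 50.0))
```

In practice the latency table would come from benchmarking the actual model and hardware at each batch size.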
21.10 Memory-Bound vs. Compute-Bound Regimes
Prefill phase (processing prompt):
- Large batch of tokens processed together
- High arithmetic intensity (matmuls)
- Compute-bound
Decode phase (generating tokens):
- One token at a time per sequence
- Low arithmetic intensity (each step re-reads the weights and the KV cache from memory)
- Memory-bound
Optimization strategies differ:
- Prefill: Standard optimizations (FlashAttention, fusion, quantization)
- Decode: KV cache optimization; memory bandwidth is critical
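A rough roofline calculation shows why the two phases land in different regimes. The sketch below looks at a single 8192 × 8192 FP16 weight matrix and compares FLOPs per byte moved against the hardware's ratio of peak compute to peak bandwidth. The peak figures are approximate published A100 numbers and should be treated as assumptions, as should the function name and the choice of 2048 prefill tokens.

```python
def matmul_arithmetic_intensity(num_tokens, d_in, d_out, bytes_per_element=2):
    """FLOPs per byte moved for a [num_tokens, d_in] x [d_in, d_out] matmul."""
    flops = 2 * num_tokens * d_in * d_out
    bytes_moved = bytes_per_element * (
        d_in * d_out              # weights read once
        + num_tokens * d_in       # input activations
        + num_tokens * d_out      # output activations
    )
    return flops / bytes_moved

# Approximate A100 peaks (assumptions): ~312 TFLOP/s FP16 tensor core, ~2.0 TB/s HBM.
PEAK_FLOPS = 312e12
PEAK_BANDWIDTH = 2.0e12
ridge_point = PEAK_FLOPS / PEAK_BANDWIDTH   # intensity needed to be compute-bound

decode_intensity = matmul_arithmetic_intensity(num_tokens=1, d_in=8192, d_out=8192)
prefill_intensity = matmul_arithmetic_intensity(num_tokens=2048, d_in=8192, d_out=8192)

for name, intensity in (("decode", decode_intensity), ("prefill", prefill_intensity)):
    regime = "compute-bound" if intensity > ridge_point else "memory-bound"
    print(f"{name}: {intensity:.1f} FLOPs/byte vs ridge {ridge_point:.0f} -> {regime}")
```

Batching more sequences into each decode step raises `num_tokens` and therefore the intensity, which is one reason serving systems try to keep decode batches full.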
21.11 Serving Infrastructure
Real production systems combine everything:
# Pseudocode for production LLM server
# (load_model_quantized, PagedKVCache, ContinuousBatch, and sample are placeholders)
class LLMServer:
    def __init__(self):
        self.model = load_model_quantized()  # INT4/INT8
        self.kv_cache_manager = PagedKVCache()
        self.batch = ContinuousBatch()

    async def generate(self, request):
        # Add to continuous batch
        self.batch.add(request)
        # Reserve KV cache pages before the first forward pass writes to them
        request.kv_cache = self.kv_cache_manager.allocate(request)
        # Generate with KV caching
        while not request.done:
            if request.is_prefill:
                # Prefill: process all prompt tokens (compute-bound)
                logits = self.model.forward(
                    request.tokens,
                    kv_cache=request.kv_cache,
                    use_flash_attention=True,
                )
                request.is_prefill = False
            else:
                # Decode: process one token (memory-bound)
                logits = self.model.forward(
                    request.last_token,
                    kv_cache=request.kv_cache,
                )
            # Sample next token
            next_token = sample(logits)
            request.tokens.append(next_token)
            # Continuous batching: may add/remove other requests here
            await self.batch.step()
        return request.tokens
Components:
- Continuous batching (low latency)
- KV cache with paging (high throughput)
- Quantization (larger models)
- FlashAttention (memory efficiency)
- Mixed prefill/decode batching (maximize utilization)
21.12 Key Metrics
Latency:
- Time to first token (TTFT): User-perceived latency
- Inter-token latency (ITL): Generation smoothness
- Total latency: TTFT + ITL × (num_tokens − 1), since TTFT already covers the first token
Throughput:
- Requests per second
- Tokens per second
- GPU utilization
Efficiency:
- Memory bandwidth utilization
- Compute utilization
- Concurrent sequences per GPU
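The latency metrics can be computed from per-token timestamps recorded on the client or server. A minimal sketch, assuming timestamps in seconds and a hypothetical function name. Because the first token is already covered by TTFT, the mean inter-token gap multiplies the remaining n − 1 tokens.

```python
def latency_metrics(request_start, token_times):
    """TTFT, mean inter-token latency, and total latency from token timestamps."""
    if not token_times:
        raise ValueError("need at least one generated token")
    ttft = token_times[0] - request_start
    gaps = [later - earlier for earlier, later in zip(token_times, token_times[1:])]
    itl = sum(gaps) / len(gaps) if gaps else 0.0
    total = token_times[-1] - request_start
    return {"ttft": ttft, "itl": itl, "total": total}

# Request sent at t=0.0; four tokens arrive at the times below (made-up values).
metrics = latency_metrics(0.0, [0.50, 0.55, 0.60, 0.65])
print(metrics)
```

Reporting percentiles of these values across many requests, rather than a single mean, usually gives a better picture of what users experience.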
21.13 Connections
Memory Hierarchy: KV cache is fundamentally a locality optimization—keep frequently accessed data in fast memory.
Bandwidth: Decode phase is memory-bound. KV cache size directly determines bandwidth requirements.
FlashAttention: Essential for prefill phase; enables long context with limited memory.
Quantization: Critical for fitting large models and large KV caches in GPU memory.
21.14 Key Takeaways
KV caching is mandatory: 10-50× speedup for autoregressive generation.
Multi-Query Attention: up to h× smaller KV cache (32× for 32 heads) with minimal quality loss.
PagedAttention: 2-4× higher throughput through on-demand page allocation and prefix sharing.
Continuous batching: 2-10× latency reduction by avoiding batch fill waits.
Speculative decoding: 2-3× speedup when memory-bound with good draft model.
Prefill vs. decode: Different regimes need different optimizations.
Quantization enables scale: INT4 weights take 4× less memory than FP16, so roughly 4× larger models fit on the same hardware.
- Pope et al. (2022). “Efficiently Scaling Transformer Inference”
- Shazeer (2019). “Fast Transformer Decoding: One Write-Head is All You Need” (Multi-Query Attention)
- Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention” (vLLM)
- Leviathan et al. (2023). “Fast Inference from Transformers via Speculative Decoding”
- vLLM documentation: https://vllm.readthedocs.io/