21 Inference Optimization
KV Caching, Speculative Decoding, and Serving at Scale
Training is batch-oriented: process thousands of examples together. Inference is latency-oriented: one user waits for one response.
The optimizations are completely different.
21.1 The Inference Workload
Training: maximize throughput (samples/second)
Inference: minimize latency (time per sample) while maximizing throughput (requests/second)
Key difference: Training gets large, predictable batches. Inference gets sporadic, variable-length requests.
This chapter focuses on autoregressive inference—generating sequences one token at a time—which dominates LLM serving.
21.2 Autoregressive Generation
def generate(model, prompt, max_length=100):
    tokens = tokenize(prompt)                    # shape: [seq_len]
    for _ in range(max_length):
        # Forward pass over the ENTIRE sequence generated so far
        logits = model(tokens)                   # shape: [seq_len, vocab_size]
        # Sample the next token from the distribution at the last position
        next_token = sample(logits[-1])          # shape: [1]
        tokens = torch.cat([tokens, next_token])
        if next_token.item() == EOS:
            break
    return tokens

The problem: each iteration processes the entire sequence, including tokens we've already seen.
For a 100-token generation from a 10-token prompt:
- Step 1: process the 10 prompt tokens
- Step 2: process 11 tokens (10 prompt + 1 generated)
- Step 3: process 12 tokens
- ...
- Step 100: process 109 tokens

Total tokens processed: 10 + 11 + 12 + ... + 109 = 5,950 tokens

But we only generated 100 new tokens. Every prompt token was processed 100 times!
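A two-line check of that arithmetic (illustrative only):

prompt_len, num_generated = 10, 100
total = sum(prompt_len + i for i in range(num_generated))   # 10 + 11 + ... + 109
print(total)                                                # 5950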
21.3 KV Caching: The Foundation
Attention at position \(i\) depends on tokens \([0, ..., i]\):
\[\text{Attention}(Q_i, K_{0:i}, V_{0:i})\]
Key observation: \(K_{0:i-1}\) and \(V_{0:i-1}\) don’t change when we add token \(i\).
Idea: Cache the key/value vectors for all previous tokens.
def generate_with_kv_cache(model, prompt, max_length=100):
    tokens = tokenize(prompt)
    kv_cache = None                              # will hold (keys, values) for each layer
    for _ in range(max_length):
        if kv_cache is None:
            # First iteration (prefill): process the full prompt
            logits, kv_cache = model(tokens, use_cache=True)
        else:
            # Subsequent iterations (decode): process only the newest token
            logits, kv_cache = model(tokens[-1:], kv_cache=kv_cache, use_cache=True)
        next_token = sample(logits[-1])
        tokens = torch.cat([tokens, next_token])
        if next_token.item() == EOS:
            break
    return tokens

21.3.1 Memory Cost
For a model with:
- \(L\) layers
- \(n\) tokens in cache
- \(d\) hidden dimension
- \(h\) attention heads
KV cache size: \(2 \times L \times n \times d \times \text{sizeof(dtype)}\)
Example: LLaMA-2 70B
- 80 layers
- 8192 hidden dim
- 2048 token context
- FP16 (2 bytes)
\[\text{KV cache} = 2 \times 80 \times 2048 \times 8192 \times 2 \text{ bytes} \approx 5.4\text{ GB per sequence}\]
Even ignoring the model weights, an 80 GB A100 fits only about 15 such caches before running out of memory.
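A small helper reproduces the estimate (a sketch using the dimensions assumed above; it ignores weight and activation memory):

def kv_cache_bytes(layers, tokens, hidden_dim, bytes_per_value=2):
    # 2 tensors (K and V) per layer, one hidden_dim-sized row per cached token
    return 2 * layers * tokens * hidden_dim * bytes_per_value

per_seq = kv_cache_bytes(layers=80, tokens=2048, hidden_dim=8192)
print(f"{per_seq / 1e9:.1f} GB per sequence")          # ~5.4 GB
print(f"{80e9 / per_seq:.0f} sequences fit in 80 GB")  # ~15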
21.3.2 Compute Savings
Without KV cache: O(n²) operations to generate n tokens (quadratic)
With KV cache: O(n) operations (linear)
Example: generate 100 tokens from a 10-token prompt
- Without cache: 5,950 tokens processed
- With cache: 109 tokens processed (10 at prefill + 99 single-token decode steps)
- Speedup: ≈55×
In practice: 10-50× speedup for typical generation lengths.
21.4 Multi-Query Attention
Standard attention: each head has separate Q, K, V projections.
Heads: 32
Hidden: 8192
Per-head: 256
Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 32 × 256] = 8,192 × 8,192 params
V: [8192 → 32 × 256] = 8,192 × 8,192 params
Multi-Query Attention (MQA): Share K and V across all heads.
Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 256] = 8,192 × 256 params ← Shared!
V: [8192 → 256] = 8,192 × 256 params ← Shared!
KV cache impact:
- Standard: 2 × L × n × d
- MQA: 2 × L × n × (d / h)
For 32 heads: 32× smaller KV cache
On A100 (80GB): ~500 concurrent sequences instead of 15.
Tradeoff: Slight quality degradation (2-3% worse perplexity). Worthwhile for inference.
Grouped-Query Attention (GQA): a middle ground that shares K/V across groups of heads (e.g., 4-8 query heads per group). Used in LLaMA-2 70B.
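To make the savings concrete, here is a back-of-the-envelope comparison using the dimensions from the example above (head dim 256, 32 query heads). The GQA and MQA head counts are illustrative, except that LLaMA-2 70B does use 8 KV heads:

d_model, head_dim = 8192, 256   # 32 query heads in every variant

for name, n_kv_heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    k_proj_params = d_model * n_kv_heads * head_dim      # parameters in the K projection
    kv_bytes_per_token = 2 * n_kv_heads * head_dim * 2   # one K row + one V row, FP16
    print(name, k_proj_params, kv_bytes_per_token)
# MHA: 67,108,864 params, 32 KiB of cache per token per layer
# GQA: 16,777,216 params,  8 KiB per token per layer (LLaMA-2 70B: 8 KV heads)
# MQA:  2,097,152 params,  1 KiB per token per layer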
21.5 PagedAttention: Efficient Memory Management
Problem: KV cache grows dynamically. Pre-allocating max length wastes memory.
Sequence A: "Hello, how are" [uses 20 slots / 2048 allocated]
Sequence B: "The quick brown..." [uses 100 slots / 2048 allocated]
Wasted memory: ~95% of allocation!
PagedAttention (from vLLM): Borrow OS virtual memory concepts.
- Divide KV cache into fixed-size “pages” (e.g., 16 tokens)
- Allocate pages on-demand as sequence grows
- Share pages across sequences (for shared prefixes)
Prompt: "Translate to French: "
└─ page 0 (cached, shared)
Request A: "Translate to French: Hello"
└─ page 0 (shared) → page 1 (unique)
Request B: "Translate to French: Goodbye"
└─ page 0 (shared) → page 2 (unique)
Benefit: 3-4× higher throughput (more concurrent requests fit in memory).
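A toy block-table allocator (a simplified sketch, not the vLLM implementation; prefix sharing and eviction are omitted) illustrates the on-demand paging idea:

PAGE_SIZE = 16   # tokens per page

class ToyPagedKVCache:
    def __init__(self, num_pages: int):
        self.free_pages = list(range(num_pages))    # pool of physical page ids
        self.block_tables = {}                      # seq_id -> list of physical page ids

    def slot_for(self, seq_id, token_index):
        """Return (page, offset) for a token's K/V, allocating a page on demand."""
        table = self.block_tables.setdefault(seq_id, [])
        if token_index // PAGE_SIZE == len(table):  # crossed into a new page
            table.append(self.free_pages.pop())
        return table[token_index // PAGE_SIZE], token_index % PAGE_SIZE

    def release(self, seq_id):
        """Return a finished sequence's pages to the pool immediately."""
        self.free_pages.extend(self.block_tables.pop(seq_id, []))

The attention kernel then gathers keys and values through the block table instead of from one contiguous per-sequence buffer.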
21.6 Continuous Batching
Traditional serving: Wait for batch to fill, process all together, wait again.
Time ─→
[Wait for 8 requests] [Process batch] [Wait for 8 requests] ...
↑ High latency ↑ Good throughput
Continuous batching: Add/remove requests from batch dynamically.
Batch: [A____B____C_______D__]
↑ finishes, remove from batch
New request E arrives:
Batch: [B____C_______D__E____]
Benefit: No waiting for batch to fill. Average latency reduced by 2-10×.
Implementation: vLLM, TensorRT-LLM, Text Generation Inference (TGI) all use this.
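A simplified scheduler loop sketches the idea; the `decode_step` method and the request object's `tokens`/`is_finished` attributes are hypothetical stand-ins for a real engine's API:

import queue

def serve_forever(model, requests: "queue.Queue", max_batch_size: int = 32):
    active = []    # requests currently in the decode batch
    while True:
        # Admit new requests whenever there is a free slot
        # (block only when the batch is completely empty)
        while len(active) < max_batch_size:
            try:
                active.append(requests.get(block=(len(active) == 0)))
            except queue.Empty:
                break
        # One batched step: decode a single new token for every active request
        next_tokens = model.decode_step(active)
        for req, tok in zip(active, next_tokens):
            req.tokens.append(tok)
        # Retire finished requests immediately; their slots free up right away
        active = [req for req in active if not req.is_finished()]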
21.7 Speculative Decoding
Autoregressive generation is sequential—can’t parallelize across tokens.
Or can we?
Idea: Use a fast “draft” model to generate multiple tokens. Use the slow “target” model to verify in parallel.
def speculative_decode(target_model, draft_model, prompt, k=4, max_length=100):
    tokens = tokenize(prompt)
    while len(tokens) < max_length:
        # 1. Draft model proposes k tokens autoregressively (fast)
        draft_tokens, draft_dists = [], []
        for _ in range(k):
            p = softmax(draft_model(tokens)[-1])      # draft distribution for the next token
            t = sample(p)
            draft_tokens.append(t)
            draft_dists.append(p)
            tokens = torch.cat([tokens, t])

        # 2. Target model scores all k drafted positions in ONE forward pass
        target_probs = softmax(target_model(tokens))  # one distribution per position
        # Row j predicts token j+1, so the k rows *before* the last one judge the drafts
        q = target_probs[-(k + 1):-1]

        # 3. Accept each draft with probability min(1, q/p); stop at the first rejection
        for i in range(k):
            t = draft_tokens[i]
            accept_prob = min(1.0, (q[i][t] / draft_dists[i][t]).item())
            if random.random() >= accept_prob:
                tokens = tokens[:-(k - i)]            # drop the rejected draft and everything after it
                # Resample the corrected token from norm(max(0, q[i] - draft_dists[i]));
                # see the sketch below.
                break
        # If all k drafts are accepted, the target's final row yields one bonus token for free.
    return tokens

Key insight: the target model scores all k draft tokens in a single forward pass, in parallel. When the draft agrees with the target, we accept several tokens per target-model call.
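The resampling step elided in the code above can be sketched as follows. Given the target and draft distributions q and p at the rejected position, sampling from the normalized residual max(0, q - p) keeps the output distribution identical to the target model's (Leviathan et al., 2023):

import torch

def resample_residual(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Sample the corrected token after a rejection: t ~ norm(max(0, q - p))."""
    residual = torch.clamp(q - p, min=0.0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, num_samples=1)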
Speedup: 2-3× when draft model is 10× faster and agrees with target 60-80% of the time.
Example draft models:
- A smaller model from the same family (LLaMA-70B target, LLaMA-7B draft)
- A quantized version of the target
- Early exit from the target model
Tradeoff: Extra compute for the draft passes plus target work on rejected tokens. Only beneficial when decoding is memory-bound, not compute-bound.
21.8 Quantization for Inference
Training requires FP16/FP32 for gradient precision. Inference doesn’t.
INT8 quantization:
- 2× memory reduction
- 2-3× speedup (on tensor cores)
- Minimal quality loss with calibration

INT4 quantization (GPTQ, AWQ):
- 4× memory reduction
- Enables larger models on the same hardware
- 1-2% perplexity degradation

Example: LLaMA-2 70B
- FP16: 140 GB (doesn't fit in an 80 GB A100)
- INT4: 35 GB (fits easily, with room for the KV cache)
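A quick check of those weight-only numbers (activations and KV cache excluded):

params = 70e9                                    # LLaMA-2 70B parameter count
print(f"FP16: {params * 2 / 1e9:.0f} GB")        # ~140 GB
print(f"INT8: {params * 1 / 1e9:.0f} GB")        # ~70 GB
print(f"INT4: {params * 0.5 / 1e9:.0f} GB")      # ~35 GB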
See Chapter 13 for full quantization details.
21.9 The Batch Size vs. Latency Tradeoff
Larger batches → higher throughput, higher latency per request.
Batch size 1: 10 ms latency, 100 req/sec throughput
Batch size 8: 30 ms latency, 266 req/sec throughput
Batch size 32: 80 ms latency, 400 req/sec throughput
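These illustrative figures follow from throughput ≈ batch_size / latency (assuming full batches):

# Throughput implied by the latency numbers above
for batch_size, latency_s in [(1, 0.010), (8, 0.030), (32, 0.080)]:
    print(batch_size, int(batch_size / latency_s), "req/sec")   # 100, 266, 400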
Choosing batch size:
- Interactive applications (chatbots): small batch (1-4)
- Offline processing: large batch (32+)
- APIs with SLAs: tune to meet latency targets
21.10 Memory-Bound vs. Compute-Bound Regimes
Prefill phase (processing the prompt):
- Large batch of tokens processed together
- High arithmetic intensity (matmuls)
- Compute-bound

Decode phase (generating tokens):
- One token at a time per sequence
- Low arithmetic intensity (memory I/O for KV cache)
- Memory-bound

Optimization strategies differ:
- Prefill: standard optimizations (FlashAttention, fusion, quantization)
- Decode: KV cache optimization; memory bandwidth is critical
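A rough roofline-style estimate shows why. The hardware numbers are approximate A100 specs, and activation and KV-cache traffic is ignored:

def matmul_intensity(tokens_per_pass: int) -> float:
    # Per FP16 weight element: 2 FLOPs (multiply + add) for each token in the pass,
    # but the weight is read from HBM only once (2 bytes), regardless of batch size.
    flops_per_weight = 2 * tokens_per_pass
    bytes_per_weight = 2
    return flops_per_weight / bytes_per_weight       # FLOPs per byte

ridge_point = 312e12 / 2.0e12    # ~156 FLOP/byte (A100: ~312 TFLOPS FP16, ~2 TB/s HBM)
print(matmul_intensity(1), ridge_point)       # decode (1 token/pass): 1 FLOP/byte → memory-bound
print(matmul_intensity(512), ridge_point)     # prefill (512-token prompt): 512 FLOP/byte → compute-bound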
21.11 Serving Infrastructure
Real production systems combine everything:
# Pseudocode for a production LLM server
class LLMServer:
    def __init__(self):
        self.model = load_model_quantized()        # INT4/INT8 weights
        self.kv_cache_manager = PagedKVCache()     # paged KV allocation
        self.batch = ContinuousBatch()             # dynamic batching

    async def generate(self, request):
        # Add to the continuous batch
        self.batch.add(request)

        while not request.done:
            if request.is_prefill:
                # Prefill: process all prompt tokens at once (compute-bound)
                request.kv_cache = self.kv_cache_manager.allocate(request)
                logits = self.model.forward(
                    request.tokens,
                    kv_cache=request.kv_cache,
                    use_flash_attention=True,
                )
                request.is_prefill = False
            else:
                # Decode: process one new token (memory-bound)
                logits = self.model.forward(
                    request.last_token,
                    kv_cache=request.kv_cache,
                )

            # Sample the next token and append it
            next_token = sample(logits[-1])
            request.tokens.append(next_token)

            # Continuous batching: other requests may be added or removed here
            await self.batch.step()

        return request.tokens

Components:
- Continuous batching (low latency)
- KV cache with paging (high throughput)
- Quantization (larger models)
- FlashAttention (memory efficiency)
- Mixed prefill/decode batching (maximize utilization)
21.12 Key Metrics
Latency:
- Time to first token (TTFT): user-perceived latency
- Inter-token latency (ITL): generation smoothness
- Total latency: TTFT + (ITL × num_tokens)
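For example, with an assumed 200 ms TTFT and 20 ms ITL:

# Illustrative numbers only
ttft_ms, itl_ms, num_tokens = 200, 20, 500
total_ms = ttft_ms + itl_ms * num_tokens    # 10,200 ms for a 500-token response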
Throughput:
- Requests per second
- Tokens per second
- GPU utilization
Efficiency:
- Memory bandwidth utilization
- Compute utilization
- Concurrent sequences per GPU
21.13 Connections
Chapter 1 (Memory Hierarchy): KV cache is fundamentally a locality optimization—keep frequently accessed data in fast memory.
Chapter 2 (Bandwidth): Decode phase is memory-bound. KV cache size directly determines bandwidth requirements.
Chapter 10 (FlashAttention): Essential for prefill phase; enables long context with limited memory.
Chapter 13 (Quantization): Critical for fitting large models and large KV caches in GPU memory.
21.14 Key Takeaways
KV caching is mandatory: 10-50× speedup for autoregressive generation.
Multi-Query Attention: 32× smaller KV cache with minimal quality loss.
PagedAttention: 3-4× better memory utilization through dynamic allocation.
Continuous batching: 2-10× latency reduction by avoiding batch fill waits.
Speculative decoding: 2-3× speedup when memory-bound with good draft model.
Prefill vs. decode: Different regimes need different optimizations.
Quantization enables scale: INT4 lets you serve 4× larger models.
21.15 References
- Pope et al. (2022). "Efficiently Scaling Transformer Inference"
- Shazeer (2019). “Fast Transformer Decoding: One Write-Head is All You Need” (Multi-Query Attention)
- Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention” (vLLM)
- Leviathan et al. (2023). “Fast Inference from Transformers via Speculative Decoding”
- vLLM documentation: https://vllm.readthedocs.io/