21  Inference Optimization

KV Caching, Speculative Decoding, and Serving at Scale

Training is batch-oriented: process thousands of examples together. Inference is latency-oriented: one user waits for one response.

The optimizations are completely different.

21.1 The Inference Workload

Training: maximize throughput (samples/second).
Inference: minimize latency (time per sample) while maximizing throughput (requests/second).

Key difference: Training gets large, predictable batches. Inference gets sporadic, variable-length requests.

This chapter focuses on autoregressive inference—generating sequences one token at a time—which dominates LLM serving.

21.2 Autoregressive Generation

import torch

def generate(model, prompt, max_length=100):
    tokens = tokenize(prompt)  # 1-D tensor of token ids, shape [prompt_len]

    for _ in range(max_length):
        # Forward pass over the ENTIRE sequence, every iteration
        logits = model(tokens)  # Shape: [seq_len, vocab_size]

        # Sample the next token from the distribution at the last position
        next_token = sample(logits[-1])  # 1-element tensor
        tokens = torch.cat([tokens, next_token])

        if next_token.item() == EOS:
            break

    return tokens

The problem: Each iteration processes the entire sequence, including tokens we’ve already seen.

For a 100-token generation from a 10-token prompt:

  • Step 1: process the 10 prompt tokens
  • Step 2: process 11 tokens (10 prompt + 1 generated)
  • Step 3: process 12 tokens
  • …
  • Step 100: process 109 tokens

Total tokens processed: 10 + 11 + 12 + … + 109 = 5,950 tokens

But we only generated 100 new tokens. We reprocessed the prompt (and every previously generated token) on every single step!

21.3 KV Caching: The Foundation

Attention at position \(i\) depends on tokens \([0, ..., i]\):

\[\text{Attention}(Q_i, K_{0:i}, V_{0:i})\]

Key observation: \(K_{0:i-1}\) and \(V_{0:i-1}\) don’t change when we add token \(i\).

Idea: Cache the key/value vectors for all previous tokens.

def generate_with_kv_cache(model, prompt, max_length=100):
    tokens = tokenize(prompt)
    kv_cache = None  # Will hold (keys, values) for each layer

    for _ in range(max_length):
        if kv_cache is None:
            # First iteration (prefill): process the full prompt and build the cache
            logits, kv_cache = model(tokens, use_cache=True)
        else:
            # Subsequent iterations (decode): only process the newest token
            logits, kv_cache = model(tokens[-1:], kv_cache=kv_cache, use_cache=True)

        next_token = sample(logits[-1])
        tokens = torch.cat([tokens, next_token])

        if next_token.item() == EOS:
            break

    return tokens

21.3.1 Memory Cost

For a model with:

  • \(L\) layers
  • \(n\) tokens in the cache
  • \(d\) hidden dimension
  • \(h\) attention heads

KV cache size: \(2 \times L \times n \times d \times \text{sizeof(dtype)}\)

Example: LLaMA-2 70B

  • 80 layers
  • 8192 hidden dim
  • 2048-token context
  • FP16 (2 bytes)

\[\text{KV cache} = 2 \times 80 \times 2048 \times 8192 \times 2 \text{ bytes} \approx 5.4\text{ GB per sequence}\]

On an 80 GB A100, that is only ~15 concurrent sequences' worth of cache before memory runs out, and that is before accounting for the model weights themselves.
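
The same arithmetic as a quick Python check (a throwaway helper for illustration; like the formula above, it assumes every attention head stores its own K and V):

def kv_cache_bytes(layers, tokens, hidden_dim, dtype_bytes=2):
    """Keys + values for every layer and token, per sequence (standard MHA)."""
    return 2 * layers * tokens * hidden_dim * dtype_bytes

size = kv_cache_bytes(layers=80, tokens=2048, hidden_dim=8192)
print(f"{size / 1e9:.1f} GB per sequence")                     # ~5.4 GB
print(f"~{round(80e9 / size)} sequences fit in 80 GB of HBM")  # ~15, ignoring weights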

21.3.2 Compute Savings

Without KV cache: O(n²) token positions run through the model to generate n tokens (quadratic).
With KV cache: O(n) token positions (linear).

Example: generate 100 tokens from a 10-token prompt

  • Without cache: 5,950 token positions processed
  • With cache: 109 token positions processed (10 at prefill, then 1 per decode step)
  • Speedup: ~55×

In practice: 10-50× speedup for typical generation lengths.
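
A quick way to reproduce these counts (a throwaway helper, not part of any library):

def tokens_processed(prompt_len, gen_len, use_cache):
    """Token positions run through the model to generate gen_len tokens."""
    if use_cache:
        # One prefill pass over the prompt, then one token per remaining step
        return prompt_len + (gen_len - 1)
    # Step k re-processes the prompt plus the k-1 tokens generated so far
    return sum(prompt_len + k for k in range(gen_len))

print(tokens_processed(10, 100, use_cache=False))  # 5950
print(tokens_processed(10, 100, use_cache=True))   # 109  -> ~55x fewer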

21.4 Multi-Query Attention

Standard attention: each head has separate Q, K, V projections.

Heads: 32
Hidden: 8192
Per-head: 256

Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 32 × 256] = 8,192 × 8,192 params
V: [8192 → 32 × 256] = 8,192 × 8,192 params

Multi-Query Attention (MQA): Share K and V across all heads.

Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 256]      = 8,192 × 256 params  ← Shared!
V: [8192 → 256]      = 8,192 × 256 params  ← Shared!

KV cache impact:

  • Standard: \(2 \times L \times n \times d\)
  • MQA: \(2 \times L \times n \times (d / h)\)

For 32 heads: 32× smaller KV cache

On A100 (80GB): ~500 concurrent sequences instead of 15.

Tradeoff: Slight quality degradation (2-3% worse perplexity). Worthwhile for inference.

Grouped-Query Attention (GQA): a middle ground that shares K/V across groups of query heads (e.g., 4-8 heads per group). Used in LLaMA-2 70B.
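
A minimal sketch of how the projections change, using a hypothetical GQAProjections module (not taken from any library): with n_kv_heads=1 this is MQA, with n_kv_heads=n_heads it reduces to standard multi-head attention, and 8 KV heads matches LLaMA-2 70B's configuration.

import torch
import torch.nn as nn

class GQAProjections(nn.Module):
    """Q gets n_heads heads; K/V get only n_kv_heads heads, shared by query groups."""
    def __init__(self, hidden=8192, n_heads=32, n_kv_heads=8):
        super().__init__()
        self.head_dim = hidden // n_heads
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.q_proj = nn.Linear(hidden, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, n_kv_heads * self.head_dim, bias=False)

    def forward(self, x):                        # x: [batch, seq, hidden]
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(b, s, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(b, s, self.n_kv_heads, self.head_dim)
        # The KV cache stores k and v in this compact form, so it is
        # n_kv_heads / n_heads the size of a standard MHA cache.
        group = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(group, dim=2)    # expand to match the query heads
        v = v.repeat_interleave(group, dim=2)
        return q, k, v                           # all [batch, seq, n_heads, head_dim]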

21.5 PagedAttention: Efficient Memory Management

Problem: KV cache grows dynamically. Pre-allocating max length wastes memory.

Sequence A: "Hello, how are"     [uses 20 slots / 2048 allocated]
Sequence B: "The quick brown..." [uses 100 slots / 2048 allocated]

Wasted memory: ~95% of allocation!

PagedAttention (from vLLM): Borrow OS virtual memory concepts.

  • Divide KV cache into fixed-size “pages” (e.g., 16 tokens)
  • Allocate pages on-demand as sequence grows
  • Share pages across sequences (for shared prefixes)

Prompt: "Translate to French: "
        └─ page 0 (cached, shared)

Request A: "Translate to French: Hello"
           └─ page 0 (shared) → page 1 (unique)

Request B: "Translate to French: Goodbye"
           └─ page 0 (shared) → page 2 (unique)

Benefit: 3-4× higher throughput (more concurrent requests fit in memory).
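
The core bookkeeping can be sketched in a few lines. This is a toy allocator, not vLLM's actual implementation (which also handles prefix sharing, copy-on-write, and preemption): each sequence maps to a list of page indices into one pre-allocated pool, and a new page is grabbed only when the sequence grows past a page boundary.

class PagedKVAllocator:
    """Toy page table for KV cache blocks (illustration only)."""
    def __init__(self, num_pages, page_size=16):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.page_table = {}  # seq_id -> list of page indices

    def append_token(self, seq_id, token_pos):
        pages = self.page_table.setdefault(seq_id, [])
        if token_pos % self.page_size == 0:       # first token, or current page full
            if not self.free_pages:
                raise MemoryError("no free KV pages; preempt or queue the request")
            pages.append(self.free_pages.pop())   # allocate on demand
        return pages[-1], token_pos % self.page_size  # (page, slot) for this K/V entry

    def free(self, seq_id):
        self.free_pages.extend(self.page_table.pop(seq_id, []))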

21.6 Continuous Batching

Traditional serving: Wait for batch to fill, process all together, wait again.

Time ─→

[Wait for 8 requests] [Process batch] [Wait for 8 requests] ...
     ↑ High latency           ↑ Good throughput

Continuous batching: Add/remove requests from batch dynamically.

Batch: [A____B____C_______D__]
       ↑ finishes, remove from batch

New request E arrives:
Batch: [B____C_______D__E____]

Benefit: No waiting for batch to fill. Average latency reduced by 2-10×.

Implementation: vLLM, TensorRT-LLM, Text Generation Inference (TGI) all use this.
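
The scheduling idea fits in a short loop. The sketch below is hypothetical (step_batched_decode and the request attributes are stand-ins; real servers add prefill scheduling, preemption, and KV-memory checks):

def serving_loop(model, request_queue, max_batch=32):
    """Hypothetical continuous-batching scheduler (names are illustrative)."""
    active = []  # requests currently in the decode phase

    while True:
        # Admit new requests whenever slots are free -- never wait for a full batch
        while len(active) < max_batch and not request_queue.empty():
            active.append(request_queue.get_nowait())

        if not active:
            continue

        # One decode step for every active request, executed as a single batch
        step_batched_decode(model, active)

        # Retire finished requests immediately so their slots (and KV pages) free up
        active = [r for r in active if not r.finished]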

21.7 Speculative Decoding

Autoregressive generation is sequential—can’t parallelize across tokens.

Or can we?

Idea: Use a fast “draft” model to generate multiple tokens. Use the slow “target” model to verify in parallel.

def speculative_decode(target_model, draft_model, prompt, k=4, max_length=100):
    tokens = prompt

    while len(tokens) < max_length:
        # Draft model proposes k tokens autoregressively (fast)
        draft_tokens = []
        draft_probs = []

        for _ in range(k):
            q = draft_model(tokens)          # draft distribution over the next token
            t = sample(q)
            draft_tokens.append(t)
            draft_probs.append(q[t])
            tokens = torch.cat([tokens, t])

        # Target model scores all k draft tokens in ONE forward pass (parallel!)
        target_logits = target_model(tokens)
        # The logits at position j predict token j+1, so the k draft tokens are
        # scored by positions -(k+1) ... -2, not by the last k positions
        target_probs = softmax(target_logits)[-(k + 1):-1]

        # Accept/reject each draft token, left to right
        all_accepted = True
        for i in range(k):
            accept_prob = min(1, target_probs[i][draft_tokens[i]] / draft_probs[i])

            if random.random() >= accept_prob:
                # Rejected: drop this token and everything drafted after it, then
                # resample position i from the residual distribution
                # normalize(max(0, target_probs[i] - draft_distribution_i))
                tokens = tokens[:-(k - i)]
                all_accepted = False
                break

        if all_accepted:
            # All k accepted: take one bonus token from the target's last position
            tokens = torch.cat([tokens, sample(softmax(target_logits)[-1])])

    return tokens

Key insight: Target model processes all k draft tokens in one forward pass (parallel). If draft is good, we accept multiple tokens per iteration.

Speedup: 2-3× when draft model is 10× faster and agrees with target 60-80% of the time.

Example draft models:

  • A smaller model from the same family (LLaMA-70B target, LLaMA-7B draft)
  • A quantized version of the target
  • Early exit from the target model

Tradeoff: Total compute goes up (every token is processed by both the draft and target models, and rejected drafts are wasted work). Only beneficial when decoding is memory-bound, not compute-bound.

21.8 Quantization for Inference

Training requires FP16/FP32 for gradient precision. Inference doesn’t.

INT8 quantization:

  • 2× memory reduction
  • 2-3× speedup (on tensor cores)
  • Minimal quality loss with calibration

INT4 quantization (GPTQ, AWQ):

  • 4× memory reduction
  • Enables larger models on the same hardware
  • 1-2% perplexity degradation

Example: LLaMA-2 70B

  • FP16: 140 GB (doesn't fit on an 80 GB A100)
  • INT4: 35 GB (fits easily, with room for the KV cache)
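
The arithmetic behind those numbers, as a quick back-of-the-envelope check:

def weight_memory_gb(n_params, bits_per_weight):
    """Memory for the weights alone; activations and KV cache come on top."""
    return n_params * bits_per_weight / 8 / 1e9

print(weight_memory_gb(70e9, 16))  # 140.0 GB -> does not fit on one 80 GB GPU
print(weight_memory_gb(70e9, 8))   #  70.0 GB -> fits, with little headroom
print(weight_memory_gb(70e9, 4))   #  35.0 GB -> fits, with room for the KV cache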

See Chapter 13 for full quantization details.

21.9 The Batch Size vs. Latency Tradeoff

Larger batches → higher throughput, higher latency per request.

Batch size 1:  10 ms latency, 100 req/sec throughput
Batch size 8:  30 ms latency, 266 req/sec throughput
Batch size 32: 80 ms latency, 400 req/sec throughput

Choosing batch size:

  • Interactive applications (chatbots): small batches (1-4)
  • Offline processing: large batches (32+)
  • APIs with SLAs: tune to meet latency targets

21.10 Memory-Bound vs. Compute-Bound Regimes

Prefill phase (processing the prompt):

  • Many tokens processed together in one pass
  • High arithmetic intensity (large matmuls)
  • Compute-bound

Decode phase (generating tokens):

  • One token at a time per sequence
  • Low arithmetic intensity (memory I/O for the KV cache)
  • Memory-bound

Optimization strategies differ:

  • Prefill: standard optimizations (FlashAttention, fusion, quantization)
  • Decode: KV cache optimization; memory bandwidth is critical
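
A back-of-the-envelope estimate (illustrative numbers, batch size 1) of why decode is memory-bound:

# Rough arithmetic-intensity estimate for one decode step at batch size 1
params = 70e9               # 70B-parameter model
bytes_moved = params * 2    # every FP16 weight is read from HBM once per token
flops = 2 * params          # ~one multiply-add per weight
print(flops / bytes_moved)  # ~1 FLOP per byte moved
# An A100 offers roughly 312 TFLOP/s of FP16 tensor throughput but only ~2 TB/s
# of HBM bandwidth, so it needs on the order of 150 FLOPs per byte to stay
# compute-bound. Decode at small batch sizes is therefore bandwidth-limited.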

21.11 Serving Infrastructure

Real production systems combine everything:

# Pseudocode for a production LLM server
class LLMServer:
    def __init__(self):
        self.model = load_model_quantized()     # INT4/INT8 weights
        self.kv_cache_manager = PagedKVCache()  # paged KV cache allocator
        self.batch = ContinuousBatch()          # continuous batching scheduler

    async def generate(self, request):
        # Add to the continuous batch
        self.batch.add(request)

        # Generate with KV caching
        while not request.done:
            if request.is_prefill:
                # Prefill: process all prompt tokens at once (compute-bound)
                request.kv_cache = self.kv_cache_manager.allocate(request)
                logits = self.model.forward(
                    request.tokens,
                    kv_cache=request.kv_cache,
                    use_flash_attention=True
                )
                request.is_prefill = False  # switch to decode from now on
            else:
                # Decode: process only the newest token (memory-bound)
                logits = self.model.forward(
                    request.last_token,
                    kv_cache=request.kv_cache
                )

            # Sample the next token and check for termination
            next_token = sample(logits)
            request.tokens.append(next_token)
            request.done = (next_token == EOS)  # or max generation length reached

            # Continuous batching: other requests may be added/removed here
            await self.batch.step()

        return request.tokens

Components:

  • Continuous batching (low latency)
  • KV cache with paging (high throughput)
  • Quantization (larger models)
  • FlashAttention (memory efficiency)
  • Mixed prefill/decode batching (maximize utilization)

21.12 Key Metrics

Latency:

  • Time to first token (TTFT): user-perceived latency
  • Inter-token latency (ITL): generation smoothness
  • Total latency: TTFT + (ITL × num_tokens)

Throughput:

  • Requests per second
  • Tokens per second
  • GPU utilization

Efficiency:

  • Memory bandwidth utilization
  • Compute utilization
  • Concurrent sequences per GPU
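
Plugging illustrative numbers into the total-latency formula above:

# Illustrative numbers only
ttft_ms, itl_ms, num_tokens = 200, 30, 250
total_ms = ttft_ms + itl_ms * num_tokens
print(total_ms / 1000)  # 7.7 seconds of user-perceived wait for a 250-token reply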

21.13 Connections

Chapter 1 (Memory Hierarchy): KV cache is fundamentally a locality optimization—keep frequently accessed data in fast memory.

Chapter 2 (Bandwidth): Decode phase is memory-bound. KV cache size directly determines bandwidth requirements.

Chapter 10 (FlashAttention): Essential for prefill phase; enables long context with limited memory.

Chapter 13 (Quantization): Critical for fitting large models and large KV caches in GPU memory.

21.14 Key Takeaways

  1. KV caching is mandatory: 10-50× speedup for autoregressive generation.

  2. Multi-Query Attention: 32× smaller KV cache with minimal quality loss.

  3. PagedAttention: 3-4× better memory utilization through dynamic allocation.

  4. Continuous batching: 2-10× latency reduction by avoiding batch fill waits.

  5. Speculative decoding: 2-3× speedup when memory-bound with good draft model.

  6. Prefill vs. decode: Different regimes need different optimizations.

  7. Quantization enables scale: INT4 lets you serve 4× larger models.

Further Reading
  • Pope et al. (2022). “Efficiently Scaling Transformer Inference”
  • Shazeer (2019). “Fast Transformer Decoding: One Write-Head is All You Need” (Multi-Query Attention)
  • Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention” (vLLM)
  • Leviathan et al. (2023). “Fast Inference from Transformers via Speculative Decoding”
  • vLLM documentation: https://vllm.readthedocs.io/