21  Inference Optimization

KV Caching, Speculative Decoding, and Serving at Scale

Training is batch-oriented: process thousands of examples together. Inference is latency-oriented: one user waits for one response.

The optimizations are completely different.

Property Spotlight: Redundancy + Locality

KV caching exploits the fact that recomputing K/V for past tokens produces identical results—the computation is redundant. By caching these values (a locality optimization—keep frequently accessed data in fast memory), we eliminate O(n) redundant recomputations per generated token.

Multi-Query Attention exploits redundancy across attention heads: different query heads can share the same K/V representations with minimal quality loss.

21.1 The Inference Workload

Training: maximize throughput (samples/second).
Inference: minimize latency (time per sample) while maximizing throughput (requests/second).

Key difference: Training gets large, predictable batches. Inference gets sporadic, variable-length requests.

This chapter focuses on autoregressive inference—generating sequences one token at a time—which dominates LLM serving.

21.2 Autoregressive Generation

def generate(model, prompt, max_length=100):
    tokens = tokenize(prompt)  # 1-D tensor of token ids

    for _ in range(max_length):
        # Forward pass over the ENTIRE sequence generated so far
        logits = model(tokens)  # Shape: [seq_len, vocab_size]

        # Sample the next token from the distribution at the last position
        next_token = sample(logits[-1])
        tokens = torch.cat([tokens, next_token])

        if next_token == EOS:
            break

    return tokens

The problem: Each iteration processes the entire sequence, including tokens we’ve already seen.

For a 100-token generation from a 10-token prompt:

  • Step 1: process 10 prompt tokens
  • Step 2: process 11 tokens (10 prompt + 1 generated)
  • Step 3: process 12 tokens
  • …
  • Step 100: process 109 tokens

Total tokens processed: 10 + 11 + 12 + … + 109 = 5,950 tokens

But we only generated 100 new tokens. We reprocessed the prompt (and every earlier generated token) at every single step!

21.3 KV Caching: The Foundation

Attention at position \(i\) depends on tokens \([0, ..., i]\):

\[\text{Attention}(Q_i, K_{0:i}, V_{0:i})\]

Key observation: \(K_{0:i-1}\) and \(V_{0:i-1}\) don’t change when we add token \(i\).

Idea: Cache the key/value vectors for all previous tokens.

def generate_with_kv_cache(model, prompt, max_length=100):
    tokens = tokenize(prompt)
    kv_cache = None  # Will store (keys, values) for each layer

    for _ in range(max_length):
        # Only process the NEW token
        if kv_cache is None:
            # First iteration: process full prompt
            logits, kv_cache = model(tokens, use_cache=True)
        else:
            # Subsequent: only process last token
            logits, kv_cache = model(tokens[-1:], kv_cache=kv_cache)

        next_token = sample(logits[-1])
        tokens = torch.cat([tokens, next_token])

        if next_token == EOS:
            break

    return tokens

21.3.1 Memory Cost

For a model with:

  • \(L\) layers
  • \(n\) tokens in cache
  • \(d\) hidden dimension
  • \(h\) attention heads

KV cache size (standard multi-head attention): \(2 \times L \times n \times d \times \text{sizeof(dtype)}\)

With Grouped-Query Attention (GQA), K and V are shared across groups of heads, reducing the cache:

\[\text{KV cache} = 2 \times L \times n \times (\text{num\_kv\_heads} \times \text{head\_dim}) \times \text{sizeof(dtype)}\]

Example: LLaMA-2 70B (uses GQA with 8 KV heads, head_dim = 128):

  • 80 layers
  • 8 KV heads (not 64 query heads), head_dim = 128
  • 2048-token context
  • FP16 (2 bytes)

\[\text{KV cache} = 2 \times 80 \times 2048 \times (8 \times 128) \times 2 = 0.67\text{GB per sequence}\]

Note: without GQA (full MHA with 64 KV heads), this would be roughly \(5.4\text{GB}\), an 8× difference. GQA is a big part of why modern LLMs can serve many concurrent sequences.
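These sizes follow directly from the formula. A minimal calculator (an illustrative sketch; the function name and arguments are my own):

def kv_cache_bytes(num_layers, num_tokens, num_kv_heads, head_dim, dtype_bytes=2):
    """Per-sequence KV cache: 2 (K and V) x layers x tokens x KV width x bytes per element."""
    return 2 * num_layers * num_tokens * num_kv_heads * head_dim * dtype_bytes

# LLaMA-2 70B: 80 layers, 2048-token context, FP16
gqa = kv_cache_bytes(80, 2048, num_kv_heads=8, head_dim=128)    # GQA: 8 KV heads
mha = kv_cache_bytes(80, 2048, num_kv_heads=64, head_dim=128)   # full MHA: 64 KV heads
print(f"GQA: {gqa / 1e9:.2f} GB, MHA: {mha / 1e9:.2f} GB")      # ~0.67 GB vs ~5.37 GB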

On an 80GB A100, after model weights (~35GB in INT4), the remaining memory can support dozens of concurrent sequences.

21.3.2 Compute Savings

Without KV cache: O(n²) operations to generate n tokens (quadratic).
With KV cache: O(n) operations (linear).

Example: generate 100 tokens from a 10-token prompt:

  • Without: 5,950 tokens processed
  • With: 109 tokens processed (10 in the prefill pass, then 99 decode steps of 1 token each)
  • Speedup: ~55×

In practice: 10-50× speedup for typical generation lengths.
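The counts above are easy to verify (illustrative arithmetic only):

prompt_len, gen_len = 10, 100

# Without a cache, step i reprocesses the prompt plus all previously generated tokens
without_cache = sum(prompt_len + i for i in range(gen_len))   # 5,950

# With a cache, the prompt is processed once, then one new token per remaining step
with_cache = prompt_len + (gen_len - 1)                       # 109

print(without_cache, with_cache, without_cache / with_cache)  # 5950 109 ~54.6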

21.4 Multi-Query Attention

Standard attention: each head has separate Q, K, V projections.

Heads: 32
Hidden: 8192
Per-head: 256

Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 32 × 256] = 8,192 × 8,192 params
V: [8192 → 32 × 256] = 8,192 × 8,192 params

Multi-Query Attention (MQA): Share K and V across all heads.

Q: [8192 → 32 × 256] = 8,192 × 8,192 params
K: [8192 → 256]      = 8,192 × 256 params  ← Shared!
V: [8192 → 256]      = 8,192 × 256 params  ← Shared!

KV cache impact:

  • Standard MHA: 2 × L × n × d
  • MQA: 2 × L × n × (d / h)

For 32 heads: 32× smaller KV cache

On A100 (80GB): KV cache capacity scales accordingly, but overall concurrency is still limited by weights and other overheads (so expect far fewer than 500 sequences in practice).

Tradeoff: Slight quality degradation; reported perplexity impact is model- and data-dependent (often a few percent). Worthwhile for inference.

Grouped-Query Attention (GQA): The middle ground: share K/V across groups of heads (e.g., 4-8 heads per group). Used in LLaMA-2 70B.
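To make the sharing concrete, here is a minimal grouped-query attention sketch in PyTorch (shapes and names are illustrative, not any particular library's API). With num_kv_heads = 1 it reduces to MQA; with num_kv_heads = num_heads it is standard multi-head attention.

import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, num_heads, num_kv_heads):
    # q: [seq, num_heads, head_dim]; k, v: [seq, num_kv_heads, head_dim]
    # Only k and v (with num_kv_heads << num_heads) need to live in the KV cache.
    group_size = num_heads // num_kv_heads

    # Each group of query heads attends to the same shared K/V head
    k = k.repeat_interleave(group_size, dim=1)   # [seq, num_heads, head_dim]
    v = v.repeat_interleave(group_size, dim=1)

    q, k, v = (x.transpose(0, 1) for x in (q, k, v))          # [num_heads, seq, head_dim]
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)   # [num_heads, seq, seq]
    # (Causal mask omitted for brevity)
    out = F.softmax(scores, dim=-1) @ v                       # [num_heads, seq, head_dim]
    return out.transpose(0, 1)                                # [seq, num_heads, head_dim]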

21.5 PagedAttention: Efficient Memory Management

Problem: KV cache grows dynamically. Pre-allocating max length wastes memory.

Sequence A: "Hello, how are"     [uses 20 slots / 2048 allocated]
Sequence B: "The quick brown..." [uses 100 slots / 2048 allocated]

Wasted memory: ~95% of allocation!

PagedAttention (from vLLM): Borrow OS virtual memory concepts.

  • Divide KV cache into fixed-size “pages” (e.g., 16 tokens)
  • Allocate pages on-demand as sequence grows
  • Share pages across sequences (for shared prefixes)

For example, requests that share a prompt prefix can share the pages holding that prefix:

Prompt: "Translate to French: "
        └─ page 0 (cached, shared)

Request A: "Translate to French: Hello"
           └─ page 0 (shared) → page 1 (unique)

Request B: "Translate to French: Goodbye"
           └─ page 0 (shared) → page 2 (unique)

Benefit: 3-4× higher throughput (more concurrent requests fit in memory).
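A toy page-table allocator shows the mechanism (a sketch under simplifying assumptions, not vLLM's actual implementation; class and method names are made up):

class PagedKVAllocator:
    """Toy allocator: maps each sequence to a list of fixed-size KV-cache pages."""

    def __init__(self, num_pages, page_size=16):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))   # physical page ids
        self.page_tables = {}                      # seq_id -> list of page ids

    def append_token(self, seq_id, tokens_so_far):
        """Allocate a new physical page only when a sequence crosses a page boundary."""
        table = self.page_tables.setdefault(seq_id, [])
        if tokens_so_far % self.page_size == 0:    # all current pages are full
            table.append(self.free_pages.pop())
        return table[-1]                           # page that will hold the new token's K/V

    def free(self, seq_id):
        """Return a finished sequence's pages to the free pool."""
        self.free_pages.extend(self.page_tables.pop(seq_id, []))

Prefix sharing adds reference counting on top of this, so requests with identical prompt prefixes can point at the same physical pages.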

21.6 Continuous Batching

Traditional serving: Wait for batch to fill, process all together, wait again.

Time ─→

[Wait for 8 requests] [Process batch] [Wait for 8 requests] ...
     ↑ High latency           ↑ Good throughput

Continuous batching: Add/remove requests from batch dynamically.

Batch: [A____B____C_______D__]
       ↑ finishes, remove from batch

New request E arrives:
Batch: [B____C_______D__E____]

Benefit: No waiting for batch to fill. Average latency reduced by 2-10×.

Implementation: vLLM, TensorRT-LLM, Text Generation Inference (TGI) all use this.
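A simplified scheduler loop illustrates the idea (hypothetical names; a real server also interleaves prefill with decode and respects KV-cache memory limits):

def scheduler_loop(model_step, waiting, max_batch_size=32):
    """Toy continuous-batching loop: requests join and leave the batch every step."""
    active, finished = [], []
    while waiting or active:
        # Admit new requests as soon as batch slots (and KV-cache memory) free up
        while waiting and len(active) < max_batch_size:
            active.append(waiting.pop(0))

        # One batched forward step: prefill for newly admitted requests, decode for the rest
        model_step(active)

        # Retire finished sequences immediately instead of waiting for the whole batch
        finished.extend(r for r in active if r.done)
        active = [r for r in active if not r.done]
    return finished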

21.7 Speculative Decoding

Autoregressive generation is sequential—can’t parallelize across tokens.

Or can we?

Idea: Use a fast “draft” model to generate multiple tokens. Use the slow “target” model to verify in parallel.

def speculative_decode(target_model, draft_model, prompt, k=4, max_length=100):
    tokens = prompt

    while len(tokens) < max_length:
        # Draft model proposes k tokens autoregressively (fast, small model)
        draft_dists = []
        for _ in range(k):
            p = draft_model(tokens)          # distribution over the next token
            t = sample(p)
            draft_dists.append(p)
            tokens = torch.cat([tokens, t])

        # Target model scores the whole extended sequence in ONE forward pass.
        # The distribution at position -(k+1)+i predicts draft token i;
        # the distribution at position -1 predicts a potential "bonus" token.
        target_dists = softmax(target_model(tokens))

        # Accept/reject draft tokens left to right (Leviathan et al. rule)
        for i in range(k):
            draft_token = tokens[-(k - i)]
            q = draft_dists[i]                     # draft distribution at this position
            p = target_dists[-(k + 1) + i]         # target distribution at this position
            accept_prob = min(1, p[draft_token] / q[draft_token])

            if random.random() >= accept_prob:
                # Reject: drop this and all later draft tokens, then resample from
                # the adjusted distribution max(0, target - draft), renormalized
                tokens = tokens[:-(k - i)]
                adjusted = torch.clamp(p - q, min=0)
                tokens = torch.cat([tokens, sample(adjusted / adjusted.sum())])
                break
        else:
            # All k drafts accepted: take one extra "bonus" token from the target
            tokens = torch.cat([tokens, sample(target_dists[-1])])

        if tokens[-1] == EOS:
            break

    return tokens

Key insight: Target model processes all k draft tokens in one forward pass (parallel). If draft is good, we accept multiple tokens per iteration.

Speedup: 2-3× when draft model is 10× faster and agrees with target 60-80% of the time.
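Under the simplifying assumption of an i.i.d. per-token acceptance rate \(\alpha\), the analysis in Leviathan et al. gives an expected \((1 - \alpha^{k+1}) / (1 - \alpha)\) tokens per target forward pass. A quick check of the numbers above:

def expected_tokens_per_target_pass(alpha, k):
    """Expected tokens emitted per target forward pass (accepted drafts + bonus token)."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.6, 0.7, 0.8):
    print(alpha, round(expected_tokens_per_target_pass(alpha, k=4), 2))
# 0.6 -> 2.31, 0.7 -> 2.77, 0.8 -> 3.36 tokens per target pass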

Example draft models:

  • A smaller version of the same model (LLaMA-70B target, LLaMA-7B draft)
  • A quantized version of the target
  • Early exit from the target model itself

Tradeoff: Requires additional compute for the draft model. Since the draft model is typically 10-20× smaller than the target, total compute overhead is ~1.1-1.2×, not 2×. The technique is most beneficial when the target model is memory-bound (decode phase), where the draft model’s overhead is negligible compared to the time saved by accepting multiple tokens per target forward pass.

21.8 Quantization for Inference

Training requires FP16/FP32 for gradient precision. Inference doesn’t.

INT8 quantization:

  • 2× memory reduction
  • 2-3× speedup (on tensor cores)
  • Minimal quality loss with calibration

INT4 quantization (GPTQ, AWQ):

  • 4× memory reduction
  • Enables larger models on the same hardware
  • ~1-2% perplexity degradation

Example: LLaMA-2 70B:

  • FP16: 140GB (doesn't fit on an 80GB A100)
  • INT4: 35GB (fits easily, with room for the KV cache)
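These sizes are just parameter count times bytes per weight (illustrative; this ignores quantization metadata such as per-group scales):

def weight_memory_gb(num_params, bits_per_weight):
    return num_params * bits_per_weight / 8 / 1e9

print(weight_memory_gb(70e9, 16))   # 140.0 GB in FP16
print(weight_memory_gb(70e9, 4))    # 35.0 GB in INT4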

See the Quantization chapter for full details.

21.9 The Batch Size vs. Latency Tradeoff

Larger batches → higher throughput, higher latency per request.

Batch size 1:  10 ms latency, 100 req/sec throughput
Batch size 8:  30 ms latency, 266 req/sec throughput
Batch size 32: 80 ms latency, 400 req/sec throughput

Choosing batch size:

  • Interactive applications (chatbots): small batches (1-4)
  • Offline processing: large batches (32+)
  • APIs with an SLA: tune to meet latency targets
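The throughput column above is just batch size divided by per-batch latency (illustrative numbers; the small difference at batch 8 is rounding):

for batch_size, latency_ms in [(1, 10), (8, 30), (32, 80)]:
    print(f"batch {batch_size}: {batch_size * 1000 / latency_ms:.0f} req/sec")
# batch 1: 100 req/sec, batch 8: 267 req/sec, batch 32: 400 req/sec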

21.10 Memory-Bound vs. Compute-Bound Regimes

Prefill phase (processing the prompt):

  • Large batch of tokens processed together
  • High arithmetic intensity (big matmuls)
  • Compute-bound

Decode phase (generating tokens):

  • One token at a time per sequence
  • Low arithmetic intensity (memory I/O for weights and the KV cache)
  • Memory-bound

Optimization strategies differ:

  • Prefill: standard optimizations (FlashAttention, kernel fusion, quantization)
  • Decode: KV cache optimization; memory bandwidth is critical
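A back-of-the-envelope roofline check makes the decode regime concrete: at batch size 1, every decode step must stream all model weights plus the KV cache from HBM, so the step time is bounded below by bytes moved divided by memory bandwidth (a sketch with illustrative numbers):

# Rough lower bound on per-token decode latency: INT4 70B model, one A100 80GB
weight_bytes = 35e9        # ~35 GB of INT4 weights
kv_cache_bytes = 0.67e9    # ~0.67 GB KV cache (2048 tokens, GQA, FP16)
hbm_bandwidth = 2.0e12     # ~2 TB/s HBM bandwidth

min_step_time = (weight_bytes + kv_cache_bytes) / hbm_bandwidth
print(f"{min_step_time * 1e3:.1f} ms/token")   # ~17.8 ms: bandwidth, not FLOPs, sets the floor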

21.11 Serving Infrastructure

Real production systems combine everything:

# Pseudocode for production LLM server
class LLMServer:
    def __init__(self):
        self.model = load_model_quantized()  # INT4/INT8
        self.kv_cache_manager = PagedKVCache()
        self.batch = ContinuousBatch()

    async def generate(self, request):
        # Add to continuous batch
        self.batch.add(request)

        # Generate with KV caching
        while not request.done:
            # Prefill or decode
            if request.is_prefill:
                # Process all prompt tokens at once (compute-bound)
                logits = self.model.forward(
                    request.tokens,
                    use_flash_attention=True
                )
                request.kv_cache = self.kv_cache_manager.allocate(request)
                request.is_prefill = False  # subsequent steps are decode steps
            else:
                # Process one token (memory-bound)
                logits = self.model.forward(
                    request.last_token,
                    kv_cache=request.kv_cache
                )

            # Sample the next token from the last position's distribution
            next_token = sample(logits[-1])
            request.tokens.append(next_token)

            # Continuous batching: may add/remove other requests here
            await self.batch.step()

        return request.tokens

Components:

  • Continuous batching (low latency)
  • KV cache with paging (high throughput)
  • Quantization (larger models)
  • FlashAttention (memory efficiency)
  • Mixed prefill/decode batching (maximize utilization)

21.12 Key Metrics

Latency:

  • Time to first token (TTFT): user-perceived latency
  • Inter-token latency (ITL): generation smoothness
  • Total latency: TTFT + (ITL × num_tokens)

Throughput:

  • Requests per second
  • Tokens per second
  • GPU utilization

Efficiency:

  • Memory bandwidth utilization
  • Compute utilization
  • Concurrent sequences per GPU

21.13 Connections

Memory Hierarchy: KV cache is fundamentally a locality optimization—keep frequently accessed data in fast memory.

Bandwidth: Decode phase is memory-bound. KV cache size directly determines bandwidth requirements.

FlashAttention: Essential for prefill phase; enables long context with limited memory.

Quantization: Critical for fitting large models and large KV caches in GPU memory.

21.14 Key Takeaways

  1. KV caching is mandatory: 10-50× speedup for autoregressive generation.

  2. Multi-Query Attention: 32× smaller KV cache with minimal quality loss.

  3. PagedAttention: 3-4× better memory utilization through dynamic allocation.

  4. Continuous batching: 2-10× latency reduction by avoiding batch fill waits.

  5. Speculative decoding: 2-3× speedup in the memory-bound decode phase with a good draft model.

  6. Prefill vs. decode: Different regimes need different optimizations.

  7. Quantization enables scale: INT4 lets you serve 4× larger models.

Further Reading
  • Pope et al. (2022). “Efficiently Scaling Transformer Inference”
  • Shazeer (2019). “Fast Transformer Decoding: One Write-Head is All You Need” (Multi-Query Attention)
  • Kwon et al. (2023). “Efficient Memory Management for Large Language Model Serving with PagedAttention” (vLLM)
  • Leviathan et al. (2023). “Fast Inference from Transformers via Speculative Decoding”
  • vLLM documentation: https://vllm.readthedocs.io/