23  Long-Context Attention

Scaling to Millions of Tokens


Claude handles 200,000 tokens. Gemini claims 1 million. GPT-4 Turbo supports 128,000.

Standard attention is O(n²). How do you attend to a million tokens?

Property Spotlight: Locality & Sparsity

This chapter combines locality and sparsity—the third and fourth properties from our Algebraic Framework.

Locality: Sliding window attention exploits the observation that most token interactions are local—tokens rarely need information from distant tokens. This lets us replace O(n²) full attention with O(n·w) local windows.

Sparsity: Sparse attention patterns (Longformer, BigBird) keep only a small fraction of the n×n attention matrix: global tokens, local windows, and a few random connections. Because the pattern touches so few entries, the full matrix is never computed.

Both properties break the quadratic barrier through different mechanisms: locality restricts which computations matter; sparsity encodes that restriction structurally.
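
To make the two properties concrete, here is a minimal sketch (hypothetical sizes, single head, boolean mask) of a Longformer-style pattern that combines a local sliding window with a handful of global tokens; only the True entries would ever be computed.

import torch

def local_plus_global_mask(seq_len, window, global_tokens):
    """Boolean mask: sliding window (locality) + a few global tokens (sparsity)."""
    idx = torch.arange(seq_len)
    # Locality: each query attends to keys within +/- window positions
    mask = (idx[:, None] - idx[None, :]).abs() <= window
    # Sparsity: designated global tokens attend to, and are attended by, everyone
    mask[global_tokens, :] = True
    mask[:, global_tokens] = True
    return mask

mask = local_plus_global_mask(seq_len=1024, window=64, global_tokens=[0, 1])
print(f"fraction of attention entries kept: {mask.float().mean():.3f}")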

23.1 The Long-Context Challenge

Attention has quadratic complexity in sequence length:

Memory and compute for self-attention (FP32 attention scores, 4 bytes per entry):

Sequence    Memory (QK^T)    FLOPs
─────────────────────────────────────
1K          4 MB             2B
4K          64 MB            32B
16K         1 GB             512B
64K         16 GB            8T
256K        256 GB           128T
1M          4 TB             2P

A100 80GB: Can fit ~64K sequence with full attention matrices
H100 80GB: Same constraint applies
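
The memory column can be reproduced in a few lines, assuming FP32 attention scores (4 bytes per entry) as in the table; the FLOPs column grows with the same n² factor, multiplied by the model width.

# Size of the n x n score matrix at 4 bytes per entry
for n in [1_000, 4_000, 16_000, 64_000, 256_000, 1_000_000]:
    bytes_scores = n * n * 4
    print(f"{n:>9} tokens: {bytes_scores / 1e6:,.0f} MB for QK^T scores")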

FlashAttention (Chapter 10) solves the memory problem through chunking—but compute is still O(n²).

This chapter explores techniques for scaling beyond what a single GPU can handle.

23.2 FlashAttention Evolution

23.2.1 FlashAttention-1 Recap

Core insight: Stream through K/V in chunks, never materialize the full n×n matrix.

Memory: O(n) instead of O(n²)
Compute: Still O(n²), but IO-efficient

23.2.2 FlashAttention-2 Improvements

FA2 (2023) improved parallelism and work partitioning:

1. Better Parallelism

FA1 parallelizes over batch and heads. FA2 also parallelizes over the sequence dimension:

FA1: Each thread block handles all of Q for one head
     Limited parallelism for long sequences

FA2: Thread blocks handle Q chunks + parallelize within sequence
     Better GPU utilization for long sequences

2. Reduced Non-Matmul FLOPs

FA1 has overhead from online softmax bookkeeping. FA2 reduces this:

# FA1: More register pressure from bookkeeping
# FA2: Fused and streamlined operations

# Result: 2x speedup over FA1 for long sequences

3. Work Partitioning

FA2 changes the loop order for better memory access:

FA1: Outer loop over K/V blocks, inner loop over Q blocks
     K/V loaded once, Q loaded repeatedly

FA2: Outer loop over Q blocks, inner loop over K/V blocks
     Q loaded once, K/V streamed
     Better for causal attention (can skip future K/V)
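
A single-head, non-causal sketch of the FA2 loop order in plain PyTorch: the outer loop walks Q blocks, the inner loop streams K/V blocks, and the online-softmax bookkeeping keeps the full score matrix from ever being materialized. The real kernel does this in shared memory and registers; the block size here is illustrative.

import torch

def blocked_attention_fa2_order(Q, K, V, block=128):
    """FA2-style loop order: outer over Q blocks, inner over K/V blocks."""
    n, d = Q.shape
    O = torch.zeros_like(Q)
    for qs in range(0, n, block):
        q = Q[qs:qs + block] * d ** -0.5                   # [bq, d]
        m = torch.full((q.shape[0], 1), float('-inf'))     # running row max
        l = torch.zeros(q.shape[0], 1)                     # running sum of exp
        acc = torch.zeros(q.shape[0], d)                   # unnormalized output
        for ks in range(0, n, block):
            s = q @ K[ks:ks + block].T                     # [bq, bk] scores
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            correction = torch.exp(m - m_new)              # rescale old stats
            l = l * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ V[ks:ks + block]
            m = m_new
        O[qs:qs + block] = acc / l
    return O

# Sanity check against naive softmax attention
Q, K, V = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
assert torch.allclose(blocked_attention_fa2_order(Q, K, V), ref, atol=1e-4)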

23.2.3 FlashAttention-3 (Hopper)

FA3 (2024) exploits H100-specific features:

1. Tensor Memory Accelerator (TMA)

Hardware for async memory transfers:

A100: Software-managed shared memory loading
H100: TMA handles loads automatically, overlaps with compute

Result: Better compute/memory overlap

2. FP8 Support

Native FP8 attention for 2x speedup with minimal quality loss.

3. Warp Specialization

Different warps do different work (producers vs consumers):

Traditional: All warps do same work in lockstep
FA3: Some warps load data, others compute
     Pipelined execution hides latency

Performance gains:

Sequence length 16K, head dim 128:
  FA2 on A100: 1.0x baseline
  FA2 on H100: 1.5x (faster hardware)
  FA3 on H100: 2.5x (software + hardware)

23.3 Sequence Parallelism

When a single GPU can’t fit the full sequence, split it across GPUs.

23.3.1 The Problem

For very long sequences, even FlashAttention hits limits:

Activations per layer (transformer, FP16, keeping ~2 large tensors per layer):
  = batch × seq × hidden × num_tensors × bytes/element
  = 1 × 100K × 8192 × 2 × 2 bytes
  ≈ 3.3 GB per layer

80-layer model: ≈ 260 GB of activations
Plus KV cache and gradients: nowhere near fitting in 80 GB

23.3.2 Naive Sequence Parallelism

Split the sequence across GPUs:

Sequence: [token_0, ..., token_99999] (100K tokens)

GPU 0: [token_0, ..., token_24999]
GPU 1: [token_25000, ..., token_49999]
GPU 2: [token_50000, ..., token_74999]
GPU 3: [token_75000, ..., token_99999]

Problem: Attention is all-to-all. Token 50000 needs to attend to token 0.

23.3.3 Ring Attention

Key insight: Stream K/V around a ring of GPUs while computing attention incrementally.

Ring Attention for 4 GPUs:

Initial state:
  GPU 0: Q_0, K_0, V_0
  GPU 1: Q_1, K_1, V_1
  GPU 2: Q_2, K_2, V_2
  GPU 3: Q_3, K_3, V_3

Step 1: Each GPU computes local attention + sends K,V to next GPU
  GPU 0: Attend(Q_0, K_0, V_0) → partial_0
         Send K_0, V_0 → GPU 1
  GPU 1: Attend(Q_1, K_1, V_1) → partial_1
         Send K_1, V_1 → GPU 2
  ...

Step 2: Receive K,V from previous GPU, continue attention
  GPU 0: Receive K_3, V_3
         Attend(Q_0, K_3, V_3) → update partial_0
  GPU 1: Receive K_0, V_0
         Attend(Q_1, K_0, V_0) → update partial_1
  ...

After N-1 steps: Each GPU has attended to all K,V

The magic: Uses FlashAttention’s online softmax trick. Partial results can be combined as new K/V arrive.

import torch
import torch.distributed as dist

def ring_attention(Q_local, K_local, V_local, world_size, rank):
    """
    Ring attention: stream K/V around the ring while computing attention.

    Assumes flash_attention_with_lse returns the attention output together
    with the per-query log-sum-exp needed for the online combine below.
    """
    # Initialize with attention over the local K/V chunk
    O, lse = flash_attention_with_lse(Q_local, K_local, V_local)

    # Ring communication
    K_recv = torch.empty_like(K_local)
    V_recv = torch.empty_like(V_local)

    for step in range(world_size - 1):
        # Async send K,V to the next GPU, blocking receive from the previous one
        # (production code typically uses dist.batch_isend_irecv to avoid
        # ordering/deadlock issues; this is the simplest form)
        send_rank = (rank + 1) % world_size
        recv_rank = (rank - 1) % world_size

        send_req = dist.isend(K_local, send_rank)
        dist.recv(K_recv, recv_rank)
        send_req.wait()

        send_req = dist.isend(V_local, send_rank)
        dist.recv(V_recv, recv_rank)
        send_req.wait()

        # Compute attention with received K,V
        O_new, lse_new = flash_attention_with_lse(Q_local, K_recv, V_recv)

        # Online update (same as FlashAttention chunking)
        O, lse = online_softmax_update(O, lse, O_new, lse_new)

        # Prepare for next iteration
        K_local, K_recv = K_recv, K_local
        V_local, V_recv = V_recv, V_local

    return O

def online_softmax_update(O1, lse1, O2, lse2):
    """
    Combine two partial attention results.

    lse = log(sum(exp(scores))) - the log-sum-exp for normalization
    """
    # New log-sum-exp
    lse_max = torch.maximum(lse1, lse2)
    lse_new = lse_max + torch.log(
        torch.exp(lse1 - lse_max) + torch.exp(lse2 - lse_max)
    )

    # Reweight and combine outputs
    w1 = torch.exp(lse1 - lse_new)
    w2 = torch.exp(lse2 - lse_new)
    O_new = w1 * O1 + w2 * O2

    return O_new, lse_new
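
The combine rule is easy to sanity-check. Below, a dense stand-in plays the role of flash_attention_with_lse (which is assumed to return the output together with the per-query log-sum-exp): attending to two halves of K/V separately and combining must reproduce attention over all of K/V at once.

import torch

def attention_with_lse(Q, K, V):
    """Dense stand-in for flash_attention_with_lse (single head, no mask)."""
    scores = Q @ K.T / Q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)    # [num_queries, 1]
    return torch.exp(scores - lse) @ V, lse                 # softmax(scores) @ V

Q = torch.randn(32, 64)
K, V = torch.randn(64, 64), torch.randn(64, 64)

# Attend to the two halves of K/V separately, then combine the partials...
O1, lse1 = attention_with_lse(Q, K[:32], V[:32])
O2, lse2 = attention_with_lse(Q, K[32:], V[32:])
O_ring, _ = online_softmax_update(O1, lse1, O2, lse2)

# ...which must match attending to all of K/V at once
O_full, _ = attention_with_lse(Q, K, V)
assert torch.allclose(O_ring, O_full, atol=1e-5)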

23.3.4 Communication Analysis

Ring attention communication pattern:

Per step: send the local K and V chunks
  = 2 * seq_local * head_dim * num_heads * bytes/element
Steps: world_size - 1
Total per GPU: 2 * (N-1)/N * seq * head_dim * num_heads * bytes/element

For 100K sequence, 4 GPUs, hidden 8192 (head_dim * num_heads), FP16:
  = 2 * 3/4 * 100K * 8192 * 2 bytes
  ≈ 2.4 GB total communication

At 400 GB/s NVLink: ~6 ms
Compared to compute: often overlapped with the attention itself

Key advantage: communication is pipelined with compute. While attention runs on the current K/V chunk, the next chunk is already being received.
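
The same arithmetic as a quick script, using the illustrative numbers above:

# Ring-attention communication for one GPU (100K tokens, 4 GPUs, hidden 8192, FP16)
seq, hidden, world_size, bytes_per_elem = 100_000, 8192, 4, 2
per_step = 2 * (seq // world_size) * hidden * bytes_per_elem    # K and V chunks
total = per_step * (world_size - 1)
print(f"per step: {per_step / 1e9:.2f} GB, total: {total / 1e9:.2f} GB")
print(f"at 400 GB/s NVLink: {total / 400e9 * 1e3:.1f} ms")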

23.3.5 Causal Attention Optimization

For causal (autoregressive) attention, tokens only attend to earlier tokens:

Full attention:              Causal attention:
[× × × ×]                    [× - - -]
[× × × ×]                    [× × - -]
[× × × ×]                    [× × × -]
[× × × ×]                    [× × × ×]

Ring attention optimization: Skip unnecessary K/V blocks.

GPU 0 (tokens 0-24999):
  - Attends to: K_0 only (causal)

GPU 3 (tokens 75000-99999):
  - Attends to: K_0, K_1, K_2, K_3 (all previous)

Result: ~50% less communication for causal attention
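
A quick enumeration for the 4-GPU example makes the ~50% figure concrete (each GPU's own block never travels; only earlier blocks do):

# K/V blocks each rank needs under causal masking (4-GPU example above)
world_size = 4
for rank in range(world_size):
    print(f"GPU {rank}: attends to K/V blocks {list(range(rank + 1))}")

# Blocks that must be streamed around the ring (everything except own block):
causal = sum(range(world_size))               # 0 + 1 + 2 + 3 = 6
full = world_size * (world_size - 1)          # 4 * 3 = 12
print(f"causal ring streams {causal}/{full} of the K/V blocks a full ring would")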

23.4 Context Parallelism (Megatron-LM)

Megatron’s approach to long sequences combines sequence parallelism with tensor parallelism.

23.4.1 Ulysses: Sequence Parallel Attention

Split sequence dimension, use all-to-all for attention:

def ulysses_attention(Q, K, V, seq_parallel_group):
    """
    Ulysses-style sequence-parallel attention (sketch).

    Q, K, V are local sequence chunks of shape
    [batch, seq_local, num_heads, head_dim].
    """
    # All-to-all: trade the sequence sharding for head sharding, so each GPU
    # now holds the FULL sequence for a subset of the heads
    Q_heads = all_to_all_seq_to_heads(Q, group=seq_parallel_group)
    K_heads = all_to_all_seq_to_heads(K, group=seq_parallel_group)
    V_heads = all_to_all_seq_to_heads(V, group=seq_parallel_group)

    # Ordinary (flash) attention over the full sequence, local heads only
    O_heads = flash_attention(Q_heads, K_heads, V_heads)

    # All-to-all back: restore sequence sharding with all heads
    O_local = all_to_all_heads_to_seq(O_heads, group=seq_parallel_group)

    return O_local

Tradeoff vs Ring Attention:
- Ulysses: two all-to-alls, then ordinary local attention (simpler); communication scales with hidden size, and the parallel degree is capped by the number of heads
- Ring: pipelined point-to-point communication overlapped with compute (more complex, better overlap, no head-count limit)

23.4.2 Combining Parallelism Dimensions

Real systems use multiple dimensions:

Example: 1M token training on 128 GPUs

Tensor Parallel: 8-way (within node)
Sequence Parallel: 4-way (across 4 nodes)
Data Parallel: 4-way (replicas)

8 × 4 × 4 = 128 GPUs

Each GPU's local chunk: 1M / 4 = 250K tokens
(tensor and data parallelism do not split the sequence)
Ring attention across each 4-way sequence-parallel group then stitches
the 250K chunks into full 1M-token attention
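
A two-line sanity check of the layout (dimensions as above; only the sequence-parallel degree divides the sequence):

# Tensor parallelism splits hidden dims, data parallelism adds replicas;
# only sequence parallelism divides the 1M-token sequence
total_tokens, tp, sp, dp = 1_000_000, 8, 4, 4
assert tp * sp * dp == 128
print(f"local sequence per GPU: {total_tokens // sp:,} tokens")   # 250,000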

23.5 Hierarchical Attention

For extremely long contexts, hierarchical approaches reduce complexity.

23.5.1 Block-wise Attention

Divide sequence into blocks, full attention within blocks, sparse across blocks:

Block size: 4096 tokens
Sequence: 1M tokens ≈ 244 blocks

Within block: Full O(n²) attention → O(4096²) = 16M ops per block
Across blocks: Summary vectors or sparse patterns

Total: O(n * block_size) instead of O(n²)
     = O(1M * 4096) = 4B ops
     vs O(1M²) = 1T ops

Examples: Longformer, BigBird, LongT5
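
A minimal sketch of the intra-block half of this idea: single head, sequence length assumed divisible by the block size, and without the cross-block summaries or global/random links that Longformer and BigBird add.

import torch

def blockwise_local_attention(Q, K, V, block=4096):
    """Full attention inside each block only (sketch; no cross-block terms)."""
    n, d = Q.shape
    nb = n // block                                      # assumes n % block == 0
    q = Q.view(nb, block, d)
    k = K.view(nb, block, d)
    v = V.view(nb, block, d)
    scores = q @ k.transpose(-1, -2) / d ** 0.5          # [nb, block, block]
    return (torch.softmax(scores, dim=-1) @ v).reshape(n, d)

Q, K, V = (torch.randn(8192, 64) for _ in range(3))
O = blockwise_local_attention(Q, K, V, block=1024)       # 8 blocks of 1024 tokens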

23.5.2 Retrieval-Augmented Attention

Don’t attend to everything—retrieve relevant context:

def retrieval_attention(query, context_db, k=100):
    """
    Attend only to retrieved relevant context.
    """
    # Retrieve the top-k most relevant chunks (context_db: any vector index
    # exposing a search(query, k) method that returns embedded text chunks)
    relevant_chunks = context_db.search(query, k=k)

    # Concatenate and attend
    context = torch.cat(relevant_chunks)
    output = attention(query, context, context)

    return output

Tradeoff: Loses some information but enables arbitrary context length.

23.6 KV Cache for Long Contexts

Inference with long contexts has unique challenges.

23.6.1 The KV Cache Problem

LLaMA-2-70B KV cache (GQA: 8 KV heads, head_dim 128, 80 layers, FP16):
  = 2 (K and V) × layers × seq × kv_heads × head_dim × 2 bytes
  = 2 × 80 × 100K × 8 × 128 × 2
  ≈ 33 GB per 100K-token sequence, or ~4 GB per GPU with 8-way tensor parallelism

For 8 concurrent sequences: 32 GB per GPU just for the KV cache
Leaves only ~48 GB per GPU for the weight shard, activations, and workspace
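
The same calculation as a small helper, using LLaMA-2-70B-like dimensions (80 layers, 8 KV heads under GQA, head dimension 128); the 8-way tensor-parallel split is the illustrative one from above.

# KV-cache size for a given model config (FP16; GQA-style kv_heads)
def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, bytes_per_elem=2):
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_elem  # K and V

full = kv_cache_bytes(layers=80, kv_heads=8, head_dim=128, seq_len=100_000)
print(f"full cache:     {full / 1e9:.1f} GB per 100K-token sequence")   # ~33 GB
print(f"8-way TP shard: {full / 8 / 1e9:.1f} GB per GPU")               # ~4 GB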

23.6.2 KV Cache Compression

H2O (Heavy-Hitter Oracle): Keep only important tokens.

def h2o_kv_cache(K, V, attention_scores, budget):
    """
    Keep only high-attention ("heavy hitter") tokens in the cache.

    K, V:              [num_heads, seq, head_dim]
    attention_scores:  [num_heads, num_queries, seq]
    """
    # Cumulative attention each key token has received (sum over queries)
    importance = attention_scores.sum(dim=-2)                 # [num_heads, seq]

    # Top `budget` tokens per head by accumulated attention
    keep_indices = importance.topk(budget, dim=-1).indices    # [num_heads, budget]

    # Gather the kept tokens along the sequence dimension
    idx = keep_indices.unsqueeze(-1).expand(-1, -1, K.shape[-1])
    K_compressed = torch.gather(K, dim=1, index=idx)
    V_compressed = torch.gather(V, dim=1, index=idx)

    return K_compressed, V_compressed

StreamingLLM: Keep first tokens (attention sinks) + recent window.

def streaming_llm_cache(K, V, window_size, sink_size=4):
    """
    Keep attention sinks + sliding window.
    """
    K_sink = K[:, :sink_size]
    V_sink = V[:, :sink_size]

    K_window = K[:, -window_size:]
    V_window = V[:, -window_size:]

    return torch.cat([K_sink, K_window], dim=1), \
           torch.cat([V_sink, V_window], dim=1)

23.6.3 Quantized KV Cache

Reduce precision for KV cache:

FP16 KV cache: 4 GB per GPU per 100K sequence (the example above)
INT8 KV cache: 2 GB per GPU per 100K sequence
INT4 KV cache: 1 GB per GPU per 100K sequence

Quality impact: usually <1% perplexity degradation
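
A minimal per-token symmetric INT8 sketch of the idea; production systems typically use per-channel or per-group scales and dequantize inside the attention kernel.

import torch

def quantize_kv_int8(X):
    """Per-token symmetric INT8 quantization: X is [seq, head_dim]."""
    scale = X.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    X_q = torch.round(X / scale).to(torch.int8)
    return X_q, scale                       # store int8 values + FP16/FP32 scales

def dequantize_kv_int8(X_q, scale):
    return X_q.float() * scale

K = torch.randn(100_000, 128)
K_q, s = quantize_kv_int8(K)
err = (dequantize_kv_int8(K_q, s) - K).abs().max()
print(f"bytes: {K.numel() * 2} (FP16) -> {K_q.numel() + s.numel() * 2} (INT8), "
      f"max abs error {err:.4f}")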

23.7 Practical Considerations

23.7.1 Positional Encodings for Long Context

Standard positional encodings fail at lengths beyond training:

Learned positions: Limited to training length
Sinusoidal: Extrapolates poorly
RoPE: Better extrapolation, but still degrades

Solutions:
- RoPE with NTK-aware scaling (sketched below)
- YaRN (Yet another RoPE extension)
- ALiBi (relative positions via bias)
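
As a sketch of NTK-aware scaling (first item above): the RoPE base is enlarged so that low-frequency components are interpolated while high-frequency, local components stay close to what the model saw in training. The base-adjustment rule below is the commonly used form; dimensions are illustrative.

import torch

def rope_inv_freq(head_dim, base=10000.0):
    """Standard RoPE inverse frequencies for even dimensions 0, 2, ..."""
    return 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)

def ntk_scaled_inv_freq(head_dim, train_len, target_len, base=10000.0):
    """NTK-aware scaling: enlarge the base so low frequencies are interpolated
    while high frequencies (local positions) stay nearly unchanged."""
    scale = target_len / train_len                        # e.g. 8x extension
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return rope_inv_freq(head_dim, new_base)

inv_freq = ntk_scaled_inv_freq(head_dim=128, train_len=4096, target_len=32768)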

23.7.2 Memory-Efficient Training

Long-context training needs:

# Gradient checkpointing (effectively mandatory for long contexts); the exact
# call depends on the framework, e.g. torch.utils.checkpoint per layer or
# model.gradient_checkpointing_enable() in Hugging Face Transformers
model.gradient_checkpointing_enable()

# Activation offloading to CPU (or NVMe); illustrative config object,
# the API varies by framework
config = ActivationOffloadConfig(
    offload_to='cpu',  # or 'nvme'
    prefetch=True,
)

# Mixed precision with careful loss scaling
from torch.cuda.amp import GradScaler
scaler = GradScaler(init_scale=2**10)  # start lower than the default for stability

23.7.3 Choosing Your Approach

Sequence Length    Technique
─────────────────────────────────────────────
<16K              Standard FlashAttention
16K-64K           FlashAttention-2/3
64K-256K          Ring Attention (multi-GPU)
256K-1M           Ring Attention + KV compression
>1M               Hierarchical/Retrieval + Ring Attention

23.8 Connection to Other Chapters

Chapter 10 (FlashAttention): Foundation—online softmax enables ring attention.

Chapter 13 (Distributed): Ring attention is a distributed attention pattern.

Chapter 14 (Inference): KV cache compression essential for long-context serving.

Chapter 24 (MoE): MoE + long context = complex distributed setup.

23.9 Key Takeaways

  1. FlashAttention evolves: FA2/FA3 bring 2-4x improvements through better parallelism and hardware utilization.

  2. Ring attention enables scale: Stream K/V around GPU ring while computing, limited by communication bandwidth.

  3. Online softmax is key: Same trick enables both chunked attention and distributed attention.

  4. Causal helps: Autoregressive models skip half the communication.

  5. KV cache dominates inference: Compression and quantization are essential for long-context serving.

  6. No free lunch: All techniques trade something—compute, quality, or complexity.

Try It Yourself

The accompanying notebook lets you:

  • Implement ring attention from scratch
  • Visualize communication patterns
  • Compare FA1 vs FA2 vs FA3 performance
  • Experiment with KV cache compression


23.10 Further Reading

  • Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”
  • Shah et al. (2024). “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision”
  • Liu et al. (2023). “Ring Attention with Blockwise Transformers for Near-Infinite Context”
  • Zhang et al. (2023). “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models”
  • Xiao et al. (2023). “Efficient Streaming Language Models with Attention Sinks”