23 Long-Context Attention
Scaling to Millions of Tokens
Claude handles 200,000 tokens. Gemini claims 1 million. GPT-4 Turbo supports 128,000.
Standard attention is O(n²). How do you attend to a million tokens?
This chapter combines locality and sparsity—the third and fourth properties from our Algebraic Framework.
Locality: Sliding window attention exploits the observation that most token interactions are local—tokens rarely need information from distant tokens. This lets us replace O(n²) full attention with O(n·w) local windows.
Sparsity: Sparse attention patterns (Longformer, BigBird) keep only a small fraction of the full attention matrix—global tokens, local windows, random connections. The full n×n matrix is sparse, so we never compute it.
Both properties break the quadratic barrier through different mechanisms: locality restricts which computations matter; sparsity encodes that restriction structurally.
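To make the two properties concrete, here is a minimal sketch (our own helper, not taken from any library) of a Longformer-style mask that combines a local sliding window with a few global tokens; window and global_idx are illustrative parameters.

import torch

def local_plus_global_mask(seq_len, window, global_idx):
    """Boolean mask: True where attention is allowed.
    Locality: a sliding window of +/- `window` positions.
    Sparsity: a handful of global tokens that see, and are seen by, everyone."""
    pos = torch.arange(seq_len)
    local = (pos[:, None] - pos[None, :]).abs() <= window   # sliding window
    mask = local.clone()
    mask[global_idx, :] = True                               # global tokens attend everywhere
    mask[:, global_idx] = True                               # and everyone attends to them
    return mask

mask = local_plus_global_mask(seq_len=4096, window=256, global_idx=[0])
print(f"{mask.float().mean():.1%} of the full n x n matrix is used")

With a 256-token window on a 4K sequence that is roughly 13% of the matrix, and the fraction shrinks linearly as the sequence grows.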
23.1 The Long-Context Challenge
Attention has quadratic complexity in sequence length:
Memory and compute for self-attention:
Sequence   Memory (QK^T, FP32)   Attention FLOPs
─────────────────────────────────────────────────
1K         4 MB                  2 B
4K         64 MB                 32 B
16K        1 GB                  512 B
64K        16 GB                 8 T
256K       256 GB                128 T
1M         4 TB                  2 P
A100 / H100 80 GB: a single 64K × 64K FP32 score matrix is already 16 GB, so materializing full attention matrices stops being practical around the 64K scale on either GPU.
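The table is easy to reproduce. The helper below is ours; it counts the FP32 score matrix and the FLOPs of the QK^T matmul alone, assuming a model dimension of 1024.

def attention_cost(seq_len, hidden=1024, score_bytes=4):
    """Rough per-layer cost of full attention: the n x n score matrix in FP32
    plus the multiply-accumulates of the QK^T matmul."""
    memory = seq_len ** 2 * score_bytes      # bytes for the score matrix
    flops = 2 * seq_len ** 2 * hidden        # QK^T multiply-accumulates
    return memory, flops

for n in (1_024, 4_096, 16_384, 65_536, 262_144, 1_048_576):
    mem, fl = attention_cost(n)
    print(f"{n:>9} tokens: {mem / 2**20:>12,.0f} MiB  {fl / 1e9:>14,.0f} GFLOPs")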
FlashAttention (Chapter 10) solves the memory problem through chunking—but compute is still O(n²).
This chapter explores techniques for scaling beyond what a single GPU can handle.
23.2 FlashAttention Evolution
23.2.1 FlashAttention-1 Recap
Core insight: Stream through K/V in chunks, never materialize the full n×n matrix.
Memory: O(n) instead of O(n²)
Compute: Still O(n²), but IO-efficient
23.2.2 FlashAttention-2 Improvements
FA2 (2023) improved parallelism and work partitioning:
1. Better Parallelism
FA1 parallelizes over batch and heads. FA2 also parallelizes over the sequence dimension:
FA1: Each thread block handles all of Q for one head
Limited parallelism for long sequences
FA2: Thread blocks handle Q chunks + parallelize within sequence
Better GPU utilization for long sequences
2. Reduced Non-Matmul FLOPs
FA1 has overhead from online softmax bookkeeping. FA2 reduces this:
# FA1: More register pressure from bookkeeping
# FA2: Fused and streamlined operations
# Result: 2x speedup over FA1 for long sequences

3. Work Partitioning
FA2 changes the loop order for better memory access:
FA1: Outer loop over K/V blocks, inner loop over Q blocks
K/V loaded once, Q loaded repeatedly
FA2: Outer loop over Q blocks, inner loop over K/V blocks
Q loaded once, K/V streamed
Better for causal attention (can skip future K/V)
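The loop-order change is easier to see in code. The sketch below is plain PyTorch with none of the shared-memory tiling that makes FlashAttention fast; it only mirrors FA2's structure: outer loop over Q blocks, inner loop streaming K/V blocks, partial results folded in with the online softmax.

import torch

def blockwise_attention(Q, K, V, block=128):
    """FA2-style loop order: outer over Q blocks, inner over K/V blocks,
    combining partials with the online softmax (no n x n matrix)."""
    n, d = Q.shape
    scale = d ** -0.5
    O = torch.empty_like(Q)
    for qs in range(0, n, block):
        q = Q[qs:qs + block]                               # this Q block stays resident
        m = torch.full((q.shape[0], 1), -float("inf"))     # running row max
        l = torch.zeros(q.shape[0], 1)                     # running softmax denominator
        acc = torch.zeros_like(q)                          # running weighted sum of V
        for ks in range(0, n, block):                      # stream K/V blocks
            s = (q @ K[ks:ks + block].T) * scale
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            correction = torch.exp(m - m_new)              # rescale earlier partials
            l = l * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ V[ks:ks + block]
            m = m_new
        O[qs:qs + block] = acc / l
    return O

# Sanity check against the reference formulation
Q, K, V = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((Q @ K.T) / 64 ** 0.5, dim=-1) @ V
assert torch.allclose(blockwise_attention(Q, K, V), ref, atol=1e-5)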
23.2.3 FlashAttention-3 (Hopper)
FA3 (2024) exploits H100-specific features:
1. Tensor Memory Accelerator (TMA)
Hardware for async memory transfers:
A100: Software-managed shared memory loading
H100: TMA handles loads automatically, overlaps with compute
Result: Better compute/memory overlap
2. FP8 Support
Native FP8 attention for 2x speedup with minimal quality loss.
3. Warp Specialization
Different warps do different work (producers vs consumers):
Traditional: All warps do same work in lockstep
FA3: Some warps load data, others compute
Pipelined execution hides latency
Performance gains:
Sequence length 16K, head dim 128:
FA2 on A100: 1.0x baseline
FA2 on H100: 1.5x (faster hardware)
FA3 on H100: 2.5x (software + hardware)
23.3 Sequence Parallelism
When a single GPU can’t fit the full sequence, split it across GPUs.
23.3.1 The Problem
For very long sequences, even FlashAttention hits limits:
Activations per layer (rough estimate):
  ≈ batch × seq × hidden × 4 bytes of saved activations per element
    (e.g., two FP16 tensors per layer)
  = 1 × 100K × 8192 × 4 bytes
  = 3.2 GB per layer
80-layer model: 256 GB of activations
Plus KV cache and gradients: doesn't fit in 80 GB
23.3.2 Naive Sequence Parallelism
Split the sequence across GPUs:
Sequence: [token_0, ..., token_99999] (100K tokens)
GPU 0: [token_0, ..., token_24999]
GPU 1: [token_25000, ..., token_49999]
GPU 2: [token_50000, ..., token_74999]
GPU 3: [token_75000, ..., token_99999]
Problem: Attention is all-to-all. Token 50000 needs to attend to token 0.
23.3.3 Ring Attention
Key insight: Stream K/V around a ring of GPUs while computing attention incrementally.
Ring Attention for 4 GPUs:
Initial state:
GPU 0: Q_0, K_0, V_0
GPU 1: Q_1, K_1, V_1
GPU 2: Q_2, K_2, V_2
GPU 3: Q_3, K_3, V_3
Step 1: Each GPU computes local attention + sends K,V to next GPU
GPU 0: Attend(Q_0, K_0, V_0) → partial_0
Send K_0, V_0 → GPU 1
GPU 1: Attend(Q_1, K_1, V_1) → partial_1
Send K_1, V_1 → GPU 2
...
Step 2: Receive K,V from previous GPU, continue attention
GPU 0: Receive K_3, V_3
Attend(Q_0, K_3, V_3) → update partial_0
GPU 1: Receive K_0, V_0
Attend(Q_1, K_0, V_0) → update partial_1
...
After N-1 steps: Each GPU has attended to all K,V
The magic: Uses FlashAttention’s online softmax trick. Partial results can be combined as new K/V arrive.
import torch
import torch.distributed as dist

def ring_attention(Q_local, K_local, V_local, world_size, rank):
    """
    Ring attention: stream K/V blocks around a ring of GPUs while each GPU
    computes attention for its local Q block incrementally.
    flash_attention_with_lse returns the output and the per-row log-sum-exp.
    """
    # Start with attention over the local K/V block
    O, lse = flash_attention_with_lse(Q_local, K_local, V_local)

    # Buffers for the blocks arriving from the previous rank
    K_recv = torch.empty_like(K_local)
    V_recv = torch.empty_like(V_local)

    send_rank = (rank + 1) % world_size
    recv_rank = (rank - 1) % world_size

    for step in range(world_size - 1):
        # Pass the current K/V block to the next GPU, receive the previous GPU's block
        send_req = dist.isend(K_local, send_rank)
        dist.recv(K_recv, recv_rank)
        send_req.wait()
        send_req = dist.isend(V_local, send_rank)
        dist.recv(V_recv, recv_rank)
        send_req.wait()

        # Attention of the local Q block against the newly received K/V block
        O_new, lse_new = flash_attention_with_lse(Q_local, K_recv, V_recv)

        # Fold the partial result in, exactly as FlashAttention does across chunks
        O, lse = online_softmax_update(O, lse, O_new, lse_new)

        # The received block becomes the block we forward on the next step
        K_local, K_recv = K_recv, K_local
        V_local, V_recv = V_recv, V_local

    return O
def online_softmax_update(O1, lse1, O2, lse2):
    """
    Combine two partial attention results.
    lse = log(sum(exp(scores))), the log-sum-exp used for normalization.
    """
    # Numerically stable log-sum-exp of the two normalizers
    lse_max = torch.maximum(lse1, lse2)
    lse_new = lse_max + torch.log(
        torch.exp(lse1 - lse_max) + torch.exp(lse2 - lse_max)
    )
    # Reweight each partial output by its share of the combined normalizer.
    # The unsqueeze assumes lse has shape [..., seq] while O has [..., seq, head_dim],
    # so the weights broadcast over the last dimension.
    w1 = torch.exp(lse1 - lse_new).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse_new).unsqueeze(-1)
    O_new = w1 * O1 + w2 * O2
    return O_new, lse_new

23.3.4 Communication Analysis
Ring attention communication pattern:
Per step: send one K block and one V block
  = 2 * seq_local * head_dim * num_heads * bytes_per_element
Steps: world_size - 1
Total: 2 * (N-1)/N * seq * head_dim * num_heads * bytes_per_element
For 100K sequence, 4 GPUs, 8192 hidden, FP16:
= 2 * 3/4 * 100K * 8192 * 2 bytes
= 2.4 GB total communication
At 400 GB/s NVLink: 6 ms
Relative to compute: small, and usually hidden entirely.
Key advantage: communication is pipelined with compute. While a GPU computes attention over its current K/V block, the next block is already arriving.
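A quick back-of-the-envelope check of that claim, using our own helper with assumed numbers (FP16 K/V, 400 GB/s links, roughly 300 TFLOPs of sustained attention throughput):

def ring_step_times(seq_local, hidden, bandwidth_gbps=400, sustained_tflops=300):
    """Per-step K/V transfer time vs. attention compute time for one rank."""
    comm_bytes = 2 * seq_local * hidden * 2               # K and V blocks, FP16
    comm_ms = comm_bytes / (bandwidth_gbps * 1e9) * 1e3
    flops = 4 * seq_local * seq_local * hidden            # QK^T and PV for one block pair
    compute_ms = flops / (sustained_tflops * 1e12) * 1e3
    return comm_ms, compute_ms

print(ring_step_times(seq_local=25_000, hidden=8192))     # roughly 2 ms comm vs 68 ms compute

Under these assumptions each K/V transfer costs a couple of milliseconds against tens of milliseconds of attention, so the ring hop hides comfortably behind compute.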
23.3.5 Causal Attention Optimization
For causal (autoregressive) attention, tokens only attend to earlier tokens:
Full attention: Causal attention:
[× × × ×] [× - - -]
[× × × ×] [× × - -]
[× × × ×] [× × × -]
[× × × ×] [× × × ×]
Ring attention optimization: Skip unnecessary K/V blocks.
GPU 0 (tokens 0-24999):
- Attends to: K_0 only (causal)
GPU 3 (tokens 75000-99999):
- Attends to: K_0, K_1, K_2, K_3 (all previous)
Result: roughly half of the block-level attention work can be skipped for causal attention (and, with careful scheduling, a similar share of the K/V traffic); the small helper below makes the count explicit.
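A tiny helper (ours) that counts which K/V blocks each rank actually needs with contiguous sharding and causal masking:

def needed_kv_blocks(rank):
    """With contiguous sequence sharding, rank r holds later tokens than
    ranks 0..r-1, so under causal masking it only needs K/V blocks 0..r."""
    return list(range(rank + 1))

world_size = 4
needed = sum(len(needed_kv_blocks(r)) for r in range(world_size))
print(needed, "of", world_size ** 2, "block pairs")   # 10 of 16, tending to half as the ring grows

One caveat: contiguous sharding leaves rank 0 with a single block of work and the last rank with all of them, so practical implementations interleave blocks across ranks to balance the load.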
23.4 Context Parallelism (Megatron-LM)
Frameworks take two main approaches here: Megatron-LM's context parallelism (ring-style exchange, as above) and DeepSpeed's Ulysses. Both combine sequence parallelism with tensor parallelism.
23.4.1 Ulysses: Sequence Parallel Attention
Ulysses keeps activations sharded along the sequence dimension and uses all-to-all collectives around attention, so that during attention each GPU holds the full sequence for a subset of heads:
def ulysses_attention(Q, K, V, seq_parallel_group):
    """
    Ulysses-style sequence parallel attention (sketch).
    Inputs are local sequence chunks of shape [batch, seq_local, heads, head_dim].
    all_to_all here is a thin helper that scatters one dimension and gathers
    another (cf. torch.distributed.all_to_all_single).
    """
    # Re-shard from "split over sequence" to "split over heads":
    # each GPU now holds the FULL sequence for heads / world_size heads
    Q_h = all_to_all(Q, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    K_h = all_to_all(K, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    V_h = all_to_all(V, scatter_dim=2, gather_dim=1, group=seq_parallel_group)

    # Ordinary (Flash) attention over the full sequence, local heads only
    O_h = flash_attention(Q_h, K_h, V_h)

    # All-to-all back: re-shard the output over the sequence dimension
    O_local = all_to_all(O_h, scatter_dim=1, gather_dim=2, group=seq_parallel_group)
    return O_local

Tradeoff vs Ring Attention:
- Ulysses: two all-to-all collectives around a plain local attention (simpler), but the sequence-parallel degree is capped by the number of attention heads
- Ring: pipelined point-to-point communication interleaved with compute (more complex, better overlap, no head-count limit)
23.4.2 Combining Parallelism Dimensions
Real systems combine several parallelism dimensions:
Example: 1M-token training on 128 GPUs
  Tensor parallel:   8-way (within a node)
  Sequence parallel: 4-way (across 4 nodes)
  Data parallel:     4-way (replicas)
  8 × 4 × 4 = 128 GPUs
Only the sequence-parallel dimension shards tokens: each sequence-parallel rank holds 1M / 4 = 250K tokens, and ring attention across that 4-way group covers the full 1M-token context. Tensor parallelism splits heads and hidden within the node; data parallelism replicates the model over different batches and does not shorten the per-replica sequence.
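A small helper (ours) makes the bookkeeping explicit; note that the data-parallel and tensor-parallel degrees do not appear in the local sequence length at all.

def local_seq_len(seq_len, seq_parallel, data_parallel=1, tensor_parallel=1):
    """Tokens held by one GPU. Only sequence/context parallelism shards tokens;
    TP shards heads/hidden and DP replicates the model over different batches."""
    return seq_len // seq_parallel

print(local_seq_len(1_000_000, seq_parallel=4, data_parallel=4, tensor_parallel=8))  # 250000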
23.5 Hierarchical Attention
For extremely long contexts, hierarchical approaches reduce complexity.
23.5.1 Block-wise Attention
Divide sequence into blocks, full attention within blocks, sparse across blocks:
Block size: 4096 tokens
Sequence: 1M tokens = 244 blocks
Within block: Full O(n²) attention → O(4096²) = 16M ops per block
Across blocks: Summary vectors or sparse patterns
Total: O(n * block_size) instead of O(n²)
= O(1M * 4096) = 4B ops
vs O(1M²) = 1T ops
Examples: Longformer, BigBird, LongT5
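A minimal sketch (ours) of the within-block part only, with no cross-block summaries or global tokens, shows where the O(n · block_size) count comes from:

import torch

def block_local_attention(Q, K, V, block=4096):
    """Full attention inside each block, nothing across blocks.
    Score entries: (n / block) * block^2 = n * block, instead of n^2."""
    n, d = Q.shape
    O = torch.empty_like(Q)
    for start in range(0, n, block):
        q, k, v = (t[start:start + block] for t in (Q, K, V))
        scores = (q @ k.T) * d ** -0.5
        O[start:start + block] = torch.softmax(scores, dim=-1) @ v
    return O

Longformer, BigBird, and LongT5 layer sliding windows, global tokens, and (for BigBird) random links on top of this block-local core so that information can still cross block boundaries.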
23.5.2 Retrieval-Augmented Attention
Don’t attend to everything—retrieve relevant context:
import torch

def retrieval_attention(query, context_db, k=100):
    """
    Attend only to the top-k retrieved chunks of context.
    context_db is any vector store exposing a .search(query, k) method.
    """
    # Retrieve the k most relevant chunks for this query
    relevant_chunks = context_db.search(query, k=k)
    # Concatenate the chunks and attend over them (keys = values = context)
    context = torch.cat(relevant_chunks)
    output = attention(query, context, context)
    return output

Tradeoff: loses some information, but enables effectively unbounded context length.
23.6 KV Cache for Long Contexts
Inference with long contexts has unique challenges.
23.6.1 The KV Cache Problem
LLaMA-2-70B KV cache (GQA: 8 KV heads, head_dim 128, 80 layers):
= 2 (K and V) × layers × seq × num_kv_heads × head_dim × 2 bytes (FP16)
= 2 × 80 × 100K × 8 × 128 × 2
≈ 33 GB per 100K-token sequence
≈ 4 GB per GPU per sequence with the KV heads sharded 8-way (tensor parallel)
For 8 concurrent sequences: ~32 GB per GPU just for KV cache
Leaves less than 48 GB per GPU for the weight shard, activations, and compute buffers
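The same estimate as a helper (ours), so that other model shapes and parallel layouts are easy to plug in:

def kv_cache_bytes(layers, seq_len, kv_heads, head_dim, bytes_per_elem=2, tp=1):
    """Per-sequence KV cache size; with tensor parallelism the KV heads are
    sharded tp ways, so this returns the per-GPU share."""
    return 2 * layers * seq_len * kv_heads * head_dim * bytes_per_elem // tp

# LLaMA-2-70B-like shape: 80 layers, 8 KV heads (GQA), head_dim 128, FP16
print(kv_cache_bytes(80, 100_000, 8, 128, tp=8) / 1e9, "GB per GPU per sequence")  # ~4.1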
23.6.2 KV Cache Compression
H2O (Heavy-Hitter Oracle): Keep only important tokens.
import torch

def h2o_kv_cache(K, V, attention_scores, budget):
    """
    Keep only the `budget` highest-attention ("heavy hitter") tokens.
    K, V: [batch, seq, ...]; attention_scores: [num_queries, seq] for a single
    sequence (batch/head dimensions reduced beforehand).
    """
    # Cumulative attention each key position has received across queries
    importance = attention_scores.sum(dim=-2)
    # Indices of the most-attended positions
    keep_indices = importance.topk(budget).indices
    K_compressed = K[:, keep_indices]
    V_compressed = V[:, keep_indices]
    return K_compressed, V_compressed

StreamingLLM: Keep the first few tokens (attention sinks) plus a recent sliding window.
def streaming_llm_cache(K, V, window_size, sink_size=4):
    """
    Keep the attention sinks (first few tokens) plus a sliding window of the
    most recent tokens. K, V: [batch, seq, ...].
    Assumes seq > sink_size + window_size, so the two slices do not overlap.
    """
    K_sink, V_sink = K[:, :sink_size], V[:, :sink_size]
    K_window, V_window = K[:, -window_size:], V[:, -window_size:]
    return (torch.cat([K_sink, K_window], dim=1),
            torch.cat([V_sink, V_window], dim=1))

23.6.3 Quantized KV Cache
Reduce precision for KV cache:
FP16 KV cache: ~4 GB per GPU per 100K-token sequence (as above)
INT8 KV cache: ~2 GB
INT4 KV cache: ~1 GB
Quality impact: Usually <1% perplexity degradation
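A minimal per-token INT8 round trip (ours; production servers typically quantize per channel or per group and fuse the dequantization into the attention kernel):

import torch

def quantize_kv_int8(x):
    """Symmetric per-token quantization of a K or V tensor [..., head_dim].
    Returns int8 values plus an FP16 scale per token."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale.to(torch.float16)

def dequantize_kv_int8(q, scale):
    # Cast back up before use; in practice this happens inside the attention kernel
    return q.float() * scale.float()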
23.7 Practical Considerations
23.7.1 Positional Encodings for Long Context
Standard positional encodings fail at lengths beyond training:
Learned positions: Limited to training length
Sinusoidal: Extrapolates poorly
RoPE: Better extrapolation, but still degrades
Solutions:
- RoPE with NTK-aware scaling (see the sketch after this list)
- YaRN (Yet another RoPE extension)
- ALiBi (relative positions via bias)
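As one concrete example, the NTK-aware variant stretches the RoPE base so that low-frequency components interpolate while high-frequency ones keep their resolution. A sketch (ours), using the base-scaling rule described in the YaRN paper:

import torch

def ntk_scaled_rope_freqs(head_dim, scale, base=10000.0):
    """RoPE inverse frequencies with NTK-aware base scaling.
    scale = target_context / training_context (e.g. 8 for 4K -> 32K)."""
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return 1.0 / (new_base ** (torch.arange(0, head_dim, 2).float() / head_dim))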
23.7.2 Memory-Efficient Training
Long-context training needs:
# Gradient checkpointing (effectively mandatory for long contexts).
# Illustrative API; in PyTorch, wrap blocks with torch.utils.checkpoint.checkpoint.
model = gradient_checkpoint(model, checkpoint_every_n_layers=2)

# Activation offloading to host memory or NVMe (illustrative config object;
# exact names vary by framework)
config = ActivationOffloadConfig(
    offload_to='cpu',  # or 'nvme'
    prefetch=True
)

# Mixed precision with careful loss scaling
scaler = GradScaler(init_scale=2**10)  # start lower than the default for stability

23.7.3 Choosing Your Approach
Sequence Length Technique
─────────────────────────────────────────────
<16K Standard FlashAttention
16K-64K FlashAttention-2/3
64K-256K Ring Attention (multi-GPU)
256K-1M Ring Attention + KV compression
>1M Hierarchical/Retrieval + Ring Attention
23.8 Connection to Other Chapters
Chapter 10 (FlashAttention): Foundation—online softmax enables ring attention.
Chapter 13 (Distributed): Ring attention is a distributed attention pattern.
Chapter 14 (Inference): KV cache compression essential for long-context serving.
Chapter 24 (MoE): MoE + long context = complex distributed setup.
23.9 Key Takeaways
FlashAttention evolves: FA2/FA3 bring 2-4x improvements through better parallelism and hardware utilization.
Ring attention enables scale: Stream K/V around GPU ring while computing, limited by communication bandwidth.
Online softmax is key: Same trick enables both chunked attention and distributed attention.
Causal helps: Autoregressive models skip half the communication.
KV cache dominates inference: Compression and quantization are essential for long-context serving.
No free lunch: All techniques trade something—compute, quality, or complexity.
23.10 Further Reading
- Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”
- Shah et al. (2024). “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision”
- Liu et al. (2023). “Ring Attention with Blockwise Transformers for Near-Infinite Context”
- Zhang et al. (2023). “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models”
- Xiao et al. (2023). “Efficient Streaming Language Models with Attention Sinks”