23 Long-Context Attention
Scaling to Millions of Tokens
As of public announcements in 2024–2025, Claude lists 200,000 tokens, Gemini 1 million, and GPT-4 Turbo 128,000. Availability varies by model version and tier, and these numbers change quickly.
Standard attention is O(n²). How do you attend to a million tokens?
This chapter combines locality and sparsity—the third and fourth properties from our Algebraic Framework.
Locality: Sliding window attention exploits the observation that most token interactions are local; a token rarely needs information from distant ones. This lets us replace O(n²) full attention with O(n·w) local windows.
Sparsity: Sparse attention patterns (Longformer, BigBird) keep only a small fraction of the full attention matrix—global tokens, local windows, random connections. The full n×n matrix is sparse, so we never compute it.
Both properties break the quadratic barrier through different mechanisms: locality restricts which computations matter; sparsity encodes that restriction structurally.
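The locality idea can be made concrete as a mask: with window w, each token attends only to its w nearest neighbors, so cost drops from O(n²) to O(n·w). A minimal sketch (window size and shapes are illustrative, not from any particular model):

```python
import torch

def sliding_window_mask(n, w):
    """Boolean [n, n] mask: True where |i - j| <= w // 2."""
    idx = torch.arange(n)
    return (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= w // 2

mask = sliding_window_mask(8, 4)
# Each row has at most w + 1 = 5 True entries (fewer at the edges)
```

Sparse patterns like Longformer's add global tokens and other connections on top of this local band; the point is that the allowed set stays O(n·w), not O(n²).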
23.1 The Long-Context Challenge
Attention has quadratic complexity in sequence length:
Memory and compute for self-attention:
Sequence   Memory (QK^T)   FLOPs
─────────────────────────────────
1K         4 MB            2B
4K         64 MB           32B
16K        1 GB            512B
64K        16 GB           8T
256K       256 GB          128T
1M         4 TB            2P
These numbers assume a single FP32 attention score matrix (one layer, one head, batch size 1), with FLOPs estimated as 2·n²·d for d ≈ 1024. Full-model memory and compute scale with layers and heads.
A100 80GB: Can hold a handful of full attention matrices at 64K, but full-model full attention does not fit.
H100 80GB: Similar constraint unless using multi-GPU sharding or sparse/streaming attention.
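The table's scaling is easy to reproduce with a small helper. It assumes one FP32 score matrix and d = 1024, the same assumptions as the table; the function name is illustrative:

```python
def attention_cost(n, d_model=1024, bytes_per_elem=4):
    """Memory for one n×n score matrix and FLOPs for QK^T (2·n²·d)."""
    mem_bytes = n * n * bytes_per_elem
    flops = 2 * n * n * d_model
    return mem_bytes, flops

mem, flops = attention_cost(4096)
# → 64 MiB of scores and ~32B FLOPs, matching the 4K row above
```

Doubling the sequence quadruples both columns, which is the whole problem this chapter attacks.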
FlashAttention (from the FlashAttention chapter) solves the memory problem through chunking—but compute is still O(n²).
This chapter explores techniques for scaling beyond what a single GPU can handle.
23.2 FlashAttention Evolution
23.2.1 FlashAttention-1 Recap
Core insight: Stream through K/V in chunks, never materialize the full n×n matrix.
Memory: O(n) instead of O(n²)
Compute: Still O(n²), but IO-efficient
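The streaming pattern can be sketched for a single query vector: keep a running max, normalizer, and output, and fold in one K/V chunk at a time. The result is numerically equivalent to full softmax attention (a single-head sketch, nothing like the real fused kernel):

```python
import torch

def chunked_attention(q, K, V, chunk=128):
    """q: [d], K, V: [n, d]. Stream K/V chunks with online softmax."""
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))   # running max of scores
    s = torch.tensor(0.0)             # running softmax normalizer
    o = torch.zeros_like(q)           # running (unnormalized) output
    for i in range(0, K.shape[0], chunk):
        scores = K[i:i + chunk] @ q / d ** 0.5
        m_new = torch.maximum(m, scores.max())
        corr = torch.exp(m - m_new)   # rescale old accumulators
        p = torch.exp(scores - m_new)
        s = s * corr + p.sum()
        o = o * corr + p @ V[i:i + chunk]
        m = m_new
    return o / s
```

The same `(max, normalizer, output)` triple reappears later in this chapter as the `lse` bookkeeping that makes ring attention possible.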
23.2.2 FlashAttention-2 Improvements
FA2 (2023) improved parallelism and work partitioning:
1. Better Parallelism
FA1 parallelizes over batch and heads. FA2 also parallelizes over the sequence dimension:
FA1: Each thread block handles all of Q for one head
Limited parallelism for long sequences
FA2: Thread blocks handle Q chunks + parallelize within sequence
Better GPU utilization for long sequences
2. Reduced Non-Matmul FLOPs
FA1 has overhead from online softmax bookkeeping. FA2 reduces this:
# FA1: More register pressure from bookkeeping
# FA2: Fused and streamlined operations
# Result: 2x speedup over FA1 for long sequences
3. Work Partitioning
FA2 changes the loop order for better memory access:
FA1: Outer loop over K/V blocks, inner loop over Q blocks
K/V loaded once, Q loaded repeatedly
FA2: Outer loop over Q blocks, inner loop over K/V blocks
Q loaded once, K/V streamed
Better for causal attention (can skip future K/V)
23.2.3 FlashAttention-3 (Hopper)
FA3 (2024) exploits H100-specific features:
1. Tensor Memory Accelerator (TMA)
Hardware for async memory transfers:
A100: Software-managed shared memory loading
H100: TMA handles loads automatically, overlaps with compute
Result: Better compute/memory overlap
2. FP8 Support
Native FP8 attention for 2x speedup with minimal quality loss.
3. Warp Specialization
Different warps do different work (producers vs consumers):
Traditional: All warps do same work in lockstep
FA3: Some warps load data, others compute
Pipelined execution hides latency
Performance gains:
Sequence length 16K, head dim 128:
FA2 on A100: 1.0x baseline
FA2 on H100: 1.5x (faster hardware)
FA3 on H100: 2.5x (software + hardware)
23.3 Sequence Parallelism
When a single GPU can’t fit the full sequence, split it across GPUs.
23.3.1 The Problem
For very long sequences, even FlashAttention hits limits:
Activations per layer (transformer, FP16):
= batch × seq × hidden × num_tensors × 2 bytes
= 1 × 100K × 8192 × 2 × 2 bytes
≈ 3.3 GB per layer
80-layer model: ~260 GB of activations
Plus KV cache and gradients: doesn't fit in 80 GB
23.3.2 Naive Sequence Parallelism
Split the sequence across GPUs:
Sequence: [token_0, ..., token_99999] (100K tokens)
GPU 0: [token_0, ..., token_24999]
GPU 1: [token_25000, ..., token_49999]
GPU 2: [token_50000, ..., token_74999]
GPU 3: [token_75000, ..., token_99999]
Problem: Attention is all-to-all. Token 50000 needs to attend to token 0.
23.3.3 Ring Attention
Key insight: Stream K/V around a ring of GPUs while computing attention incrementally.
Ring Attention for 4 GPUs:
Initial state:
GPU 0: Q_0, K_0, V_0
GPU 1: Q_1, K_1, V_1
GPU 2: Q_2, K_2, V_2
GPU 3: Q_3, K_3, V_3
Step 1: Each GPU computes local attention + sends K,V to next GPU
GPU 0: Attend(Q_0, K_0, V_0) → partial_0
Send K_0, V_0 → GPU 1
GPU 1: Attend(Q_1, K_1, V_1) → partial_1
Send K_1, V_1 → GPU 2
...
Step 2: Receive K,V from previous GPU, continue attention
GPU 0: Receive K_3, V_3
Attend(Q_0, K_3, V_3) → update partial_0
GPU 1: Receive K_0, V_0
Attend(Q_1, K_0, V_0) → update partial_1
...
After N-1 steps: Each GPU has attended to all K,V
The magic: Uses FlashAttention’s online softmax trick. Partial results can be combined as new K/V arrive.
def ring_attention(Q_local, K_local, V_local, world_size, rank):
    """
    Ring attention: stream K/V around the ring while computing attention.
    """
    # Initialize with local attention
    O, lse = flash_attention_with_lse(Q_local, K_local, V_local)
    # Buffers for incoming K/V
    K_recv = torch.empty_like(K_local)
    V_recv = torch.empty_like(V_local)
    for step in range(world_size - 1):
        # Async send K/V to the next GPU, blocking receive from the previous
        send_rank = (rank + 1) % world_size
        recv_rank = (rank - 1) % world_size
        send_req = dist.isend(K_local, send_rank)
        dist.recv(K_recv, recv_rank)
        send_req.wait()
        send_req = dist.isend(V_local, send_rank)
        dist.recv(V_recv, recv_rank)
        send_req.wait()
        # Compute attention with the received K/V
        O_new, lse_new = flash_attention_with_lse(Q_local, K_recv, V_recv)
        # Online update (same trick as FlashAttention chunking)
        O, lse = online_softmax_update(O, lse, O_new, lse_new)
        # Received K/V becomes the payload for the next iteration
        K_local, K_recv = K_recv, K_local
        V_local, V_recv = V_recv, V_local
    return O
def online_softmax_update(O1, lse1, O2, lse2):
    """
    Combine two partial attention results.
    lse = log(sum(exp(scores))), the log-sum-exp used for normalization.
    """
    # Combined log-sum-exp (numerically stable)
    lse_max = torch.maximum(lse1, lse2)
    lse_new = lse_max + torch.log(
        torch.exp(lse1 - lse_max) + torch.exp(lse2 - lse_max)
    )
    # Reweight and combine outputs
    w1 = torch.exp(lse1 - lse_new)
    w2 = torch.exp(lse2 - lse_new)
    O_new = w1 * O1 + w2 * O2
    return O_new, lse_new

23.3.4 Communication Analysis
Ring attention communication pattern:
Per step: send 2 × seq_local × num_heads × head_dim × bytes (K and V)
Steps: world_size − 1
Total per GPU: 2 × (N−1) × seq_local × num_heads × head_dim × bytes
For a 100K sequence, 4 GPUs, hidden 8192 (= num_heads × head_dim), FP16:
= 2 × 3 × 25K × 8192 × 2 bytes
≈ 2.46 GB total communication
At 400 GB/s NVLink: ~6 ms
Compared to compute: Often overlapped with attention
Key advantage: Communication is pipelined with compute. While computing attention on current K/V, receiving next K/V.
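Under these assumptions (FP16, both K and V sent each step, hidden = num_heads × head_dim), the totals above fall out of a short helper; the function name is illustrative:

```python
def ring_comm_bytes(seq, n_gpus, num_heads, head_dim, bytes_per_elem=2):
    """Total bytes each GPU sends over its world_size - 1 ring steps."""
    seq_local = seq // n_gpus
    per_step = 2 * seq_local * num_heads * head_dim * bytes_per_elem  # K and V
    return per_step * (n_gpus - 1)

total = ring_comm_bytes(100_000, 4, 64, 128)  # hidden = 64 × 128 = 8192
# ≈ 2.46 GB, matching the estimate above; at 400 GB/s that is ~6 ms
```

Note the total is independent of how many GPUs you add for a fixed sequence: more GPUs means smaller chunks but more steps, so bandwidth, not GPU count, is the limit.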
23.3.5 Causal Attention Optimization
For causal (autoregressive) attention, tokens only attend to earlier tokens:
Full attention: Causal attention:
[× × × ×] [× - - -]
[× × × ×] [× × - -]
[× × × ×] [× × × -]
[× × × ×] [× × × ×]
Ring attention optimization: Skip unnecessary K/V blocks.
GPU 0 (tokens 0-24999):
- Attends to: K_0 only (causal)
GPU 3 (tokens 75000-99999):
- Attends to: K_0, K_1, K_2, K_3 (all previous)
Result: ~50% less communication for causal attention
23.4 Context Parallelism (Megatron-LM)
Megatron’s approach to long sequences combines sequence parallelism with tensor parallelism.
23.4.1 Ulysses: Sequence Parallel Attention
Split sequence dimension, use all-to-all for attention:
def ulysses_attention(Q, K, V, seq_parallel_group):
    """
    Ulysses-style sequence-parallel attention (sketch).
    Inputs hold the local sequence chunk for all heads:
    [batch, seq/P, heads, head_dim]
    """
    # All-to-all: trade the sequence shard for a head shard, so each
    # GPU gets the full sequence for a subset of heads:
    # [batch, seq, heads/P, head_dim]
    Q = all_to_all(Q, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    K = all_to_all(K, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    V = all_to_all(V, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    # Ordinary (flash) attention over the full sequence, local heads only
    O = flash_attention(Q, K, V)
    # All-to-all back: [batch, seq/P, heads, head_dim]
    return all_to_all(O, scatter_dim=1, gather_dim=2, group=seq_parallel_group)

Tradeoff vs Ring Attention:
- Ulysses: two all-to-alls, then plain local attention (simpler)
- Ring: pipelined communication + compute (more complex, better overlap)
23.4.2 Combining Parallelism Dimensions
Real systems use multiple dimensions:
Example: 1M-token training on 128 GPUs
Tensor Parallel: 8-way (within node)
Sequence Parallel: 4-way (across nodes)
Data Parallel: 4-way (replicas)
8 × 4 × 4 = 128 GPUs
Each sequence-parallel rank holds 1M / 4 = 250K local tokens
(data parallelism replicates sequences rather than sharding them);
ring attention across the GPUs within each node shards that further.
23.5 Hierarchical Attention
For extremely long contexts, hierarchical approaches reduce complexity.
23.5.1 Block-wise Attention
Divide sequence into blocks, full attention within blocks, sparse across blocks:
Block size: 4096 tokens
Sequence: 1M tokens = 244 blocks
Within block: Full O(n²) attention → O(4096²) = 16M ops per block
Across blocks: Summary vectors or sparse patterns
Total: O(n * block_size) instead of O(n²)
= O(1M * 4096) = 4B ops
vs O(1M²) = 1T ops
Examples: Longformer, BigBird, LongT5
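A Longformer-style pattern (local blocks plus a few global tokens) can be sketched as a boolean mask; block size and the number of global tokens here are illustrative, and real implementations never materialize the n×n mask:

```python
import torch

def blockwise_mask(n, block, n_global=1):
    """True where attention is allowed: local blocks + global tokens."""
    mask = torch.zeros(n, n, dtype=torch.bool)
    for start in range(0, n, block):
        end = min(start + block, n)
        mask[start:end, start:end] = True   # full attention within a block
    mask[:, :n_global] = True               # everyone attends to global tokens
    mask[:n_global, :] = True               # global tokens attend everywhere
    return mask
```

The allowed entries number roughly n·block + 2·n·n_global, which is the O(n · block_size) count above rather than O(n²).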
23.5.2 Retrieval-Augmented Attention
Don’t attend to everything—retrieve relevant context:
def retrieval_attention(query, context_db, k=100):
    """
    Attend only to retrieved relevant context.
    """
    # Retrieve the top-k most relevant chunks
    relevant_chunks = context_db.search(query, k=k)
    # Concatenate and attend (chunks serve as both keys and values)
    context = torch.cat(relevant_chunks)
    return attention(query, context, context)

Tradeoff: Loses some information but enables arbitrary context length.
23.6 KV Cache for Long Contexts
Inference with long contexts has unique challenges.
23.6.1 The KV Cache Problem
LLaMA-70B KV cache (GQA: 8 KV heads, head_dim 128):
= 2 × layers × seq × kv_heads × head_dim × 2 bytes
= 2 × 80 × 100K × 8 × 128 × 2
≈ 33 GB per sequence
Even a single 100K-token sequence consumes a large slice of an 80 GB GPU, before counting model weights. Serving several concurrent long sequences makes the KV cache the dominant memory cost.
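Cache size depends heavily on the KV layout. A small helper (name illustrative; kv_heads is 1 for MQA, 8 for LLaMA-70B-style GQA, equal to num_heads for plain MHA):

```python
def kv_cache_bytes(layers, seq, kv_heads, head_dim, bytes_per_elem=2):
    """KV cache size per sequence: K and V, FP16 by default."""
    return 2 * layers * seq * kv_heads * head_dim * bytes_per_elem

# 80 layers, 100K tokens, GQA with 8 KV heads of dim 128:
kv_cache_bytes(80, 100_000, 8, 128)  # ≈ 33 GB
```

This is why grouped-query attention and the compression techniques below matter so much for long-context serving: every factor in the formula except `layers` is in play.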
23.6.2 KV Cache Compression
H2O (Heavy-Hitter Oracle): Keep only important tokens.
def h2o_kv_cache(K, V, attention_scores, budget):
    """
    Keep only high-attention tokens in the cache.
    K, V: [batch, seq, dim]; attention_scores: [batch, queries, seq]
    """
    # Cumulative attention each key position receives (sum over queries)
    importance = attention_scores.sum(dim=-2)          # [batch, seq]
    # Top-`budget` token indices per batch element
    keep = importance.topk(budget, dim=-1).indices     # [batch, budget]
    idx = keep.unsqueeze(-1).expand(-1, -1, K.shape[-1])
    return torch.gather(K, 1, idx), torch.gather(V, 1, idx)

StreamingLLM: Keep first tokens (attention sinks) + recent window.
def streaming_llm_cache(K, V, window_size, sink_size=4):
    """
    Keep attention sinks + a sliding window of recent tokens.
    Assumes seq > sink_size + window_size (no overlap).
    """
    K_sink, V_sink = K[:, :sink_size], V[:, :sink_size]
    K_window, V_window = K[:, -window_size:], V[:, -window_size:]
    return (torch.cat([K_sink, K_window], dim=1),
            torch.cat([V_sink, V_window], dim=1))

23.6.3 Quantized KV Cache
Reduce precision for KV cache:
FP16 KV cache: 4 GB per 100K sequence
INT8 KV cache: 2 GB per 100K sequence
INT4 KV cache: 1 GB per 100K sequence
Quality impact: Usually <1% perplexity degradation
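A minimal per-token absmax INT8 scheme illustrates the mechanics (a sketch; production systems use grouped or per-channel scales and fused kernels):

```python
import torch

def quantize_int8(x):
    """Symmetric absmax quantization along the last dim."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.float() * scale

k = torch.randn(4, 128)
q8, s = quantize_int8(k)
# Storage halves (int8 + a small per-token scale);
# roundtrip error is bounded by about half a quantization step
```

The scales must be kept alongside the cache, but at one value per token per layer they add a negligible fraction of the saved memory.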
23.7 Practical Considerations
23.7.1 Positional Encodings for Long Context
Without positional encodings, self-attention is permutation-equivariant—the same output for any ordering of tokens. For sequences, we need position to matter. Positional encodings break this symmetry intentionally.
But for long contexts, the choice of positional encoding becomes critical.
23.7.1.1 The Problem with Traditional Approaches
Learned positions:
- Fixed vocabulary of positions (e.g., 0-2047)
- Cannot extrapolate beyond training length
- Must retrain for longer contexts
Sinusoidal positions:
- Can represent any position mathematically
- But attention patterns learned on positions 0-2047
don't generalize to position 100,000
- The model hasn't learned what "far away" means
The fundamental issue: traditional position encodings encode absolute position. But attention really cares about relative position—how far apart are two tokens?
23.7.1.2 Rotary Position Embeddings (RoPE)
RoPE solves this elegantly using rotation matrices. The key insight: rotations naturally encode relative position through their composition properties.
The Mathematical Foundation
A 2D rotation matrix has the form:
\[R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}\]
Rotation matrices have special properties that make them perfect for positional encoding:
| Property | Mathematical Form | Why It Enables RoPE |
|---|---|---|
| Orthogonality | \(R^T R = I\) | Preserves vector norms and dot products |
| Relative composition | \(R_m^T R_n = R_{n-m}\) | Relative position from absolute |
| Commutativity | \(R_\theta R_\phi = R_{\theta+\phi}\) | Positions compose cleanly |
The crucial property is the second one. When we compute attention between positions \(m\) and \(n\):
\[(R_m q)^T (R_n k) = q^T R_m^T R_n k = q^T R_{n-m} k\]
The absolute positions \(m\) and \(n\) vanish—only the relative position \(n-m\) remains. This is exactly what we want for attention.
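This identity is easy to verify numerically: rotate q at position m and k at position n, and the dot product depends only on n − m (float64 for tight tolerances; a standalone sketch, not production RoPE):

```python
import torch

def rotate(x, pos, base=10000):
    """Rotate consecutive dim pairs of x (shape [d]) by pos × freq."""
    half = x.shape[-1] // 2
    freqs = 1.0 / base ** (torch.arange(half, dtype=torch.float64) / half)
    ang = pos * freqs
    c, s = torch.cos(ang), torch.sin(ang)
    xp = x.view(half, 2)
    return torch.stack([xp[:, 0] * c - xp[:, 1] * s,
                        xp[:, 0] * s + xp[:, 1] * c], dim=-1).reshape(-1)

q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)
# Same relative distance (4) at very different absolute positions:
a = rotate(q, 5) @ rotate(k, 9)
b = rotate(q, 1000) @ rotate(k, 1004)
# a and b agree to floating-point precision
```

Orthogonality shows up here too: rotation leaves every vector's norm unchanged, so RoPE never distorts the scale of attention scores.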
Applying Rotation to Embeddings
RoPE treats consecutive pairs of embedding dimensions as 2D vectors and rotates each pair:
def rope_embedding(x, positions, base=10000):
    """
    Apply Rotary Position Embedding to the input tensor.
    x: [batch, seq, heads, head_dim]
    positions: [seq] position indices
    """
    head_dim = x.shape[-1]
    dim_pairs = head_dim // 2
    # Different frequency for each dimension pair:
    # low pair indices rotate fastest (shortest wavelength)
    freqs = 1.0 / (base ** (torch.arange(dim_pairs, dtype=torch.float32) / dim_pairs))
    # Rotation angles: position × frequency → [seq, dim_pairs]
    angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)
    # Insert a heads axis so cos/sin broadcast over batch and heads
    cos = torch.cos(angles).unsqueeze(1)   # [seq, 1, dim_pairs]
    sin = torch.sin(angles).unsqueeze(1)
    # Rotate each pair of dimensions: [x, y] → [x·cos − y·sin, x·sin + y·cos]
    x_pairs = x.view(*x.shape[:-1], dim_pairs, 2)
    x_rot = torch.stack([
        x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
        x_pairs[..., 0] * sin + x_pairs[..., 1] * cos,
    ], dim=-1)
    return x_rot.view(*x.shape)

Multi-Frequency Encoding
Each dimension pair uses a different rotation frequency:
With d = head_dim / 2 dimension pairs:
Dimension pair 0: θ = position × (1 / base^0) = position
Dimension pair 1: θ = position × (1 / base^(1/d))
Dimension pair 2: θ = position × (1 / base^(2/d))
...
Dimension pair d−1: θ ≈ position / base
Low pairs: fast rotation, fine position detail
High pairs: slow rotation, coarse position info
This is analogous to Fourier features or the sinusoidal encoding—but applied through rotations that preserve the relative-position property.
Why RoPE Works Better for Long Contexts
The relative-position formulation helps RoPE extrapolate:
Traditional absolute position:
- Train on positions 0-4096
- Position 10000 is "unknown" to the model
RoPE relative position:
- Train on relative distances -4096 to +4096
- At position 10000, nearby tokens still have
familiar relative distances
- Long-range attention may degrade, but
local attention remains well-calibrated
23.7.1.3 RoPE Extensions for Extreme Lengths
Even RoPE degrades beyond training length. Several extensions address this:
NTK-Aware Scaling
Key insight: don’t just interpolate positions—adjust the frequency base.
def ntk_aware_scaling(base, scale_factor, dim):
    """
    NTK-aware interpolation: adjust the frequency base for length extension.
    Instead of: position → position / scale_factor (interpolation)
    Do:         base → base × scale_factor^(dim / (dim - 2))
    """
    # This preserves relative-position resolution better
    # than naive position interpolation
    alpha = scale_factor ** (dim / (dim - 2))
    return base * alpha

The intuition: NTK (Neural Tangent Kernel) theory suggests the frequency spectrum matters. Naive interpolation compresses high frequencies too aggressively; NTK-aware scaling maintains the frequency distribution better.
YaRN (Yet another RoPE extensioN)
YaRN combines multiple techniques:
- NTK-aware interpolation for the base frequency
- Attention scaling to compensate for longer contexts
- Dimension-dependent interpolation (different scaling per frequency)
def yarn_scale_factor(num_rotations, scale, alpha=1, beta=32):
    """
    YaRN: dimension-dependent frequency scaling (sketch).
    num_rotations: full rotations this dimension completes over the
    original context length (context_len / wavelength).
    Low-frequency dims (few rotations): interpolate fully
    High-frequency dims (many rotations): leave unscaled (local info)
    """
    # Linear ramp from 0 to 1 between alpha and beta rotations
    ramp = (num_rotations - alpha) / (beta - alpha)
    ramp = min(1.0, max(0.0, ramp))
    # Blend between interpolated (1/scale) and original (1) frequency
    return (1 - ramp) / scale + ramp

ALiBi (Attention with Linear Biases)
A different approach: don’t modify embeddings, add position bias directly to attention scores.
def alibi_attention(Q, K, V, num_heads):
    """
    ALiBi: add a linear position bias to attention scores.
    Q, K, V: [batch, num_heads, seq, head_dim]
    """
    seq_len, d = Q.shape[-2], Q.shape[-1]
    # Relative position matrix: [seq, seq]
    positions = torch.arange(seq_len)
    relative_pos = positions.unsqueeze(0) - positions.unsqueeze(1)
    # One slope per head (geometric sequence: 2^(-8/h), 2^(-16/h), ...)
    slopes = 2.0 ** (-8.0 / num_heads * torch.arange(1, num_heads + 1))
    # Attention scores with position bias
    scores = Q @ K.transpose(-1, -2) / math.sqrt(d)
    # Negative bias grows with distance (encourages attending nearby)
    bias = slopes.view(1, num_heads, 1, 1) * relative_pos.abs()
    scores = scores - bias
    return torch.softmax(scores, dim=-1) @ V

ALiBi's advantages:
- No learned parameters for position
- Naturally handles any length (the bias is mathematically defined)
- Different heads specialize in different ranges
Choosing Your Approach
Approach     Extrapolation   Compute   Notes
───────────────────────────────────────────────────────────
Learned      None            Cheap     Fails beyond train length
Sinusoidal   Poor            Cheap     Degrades badly
RoPE         Moderate        Cheap     Good locally, degrades long-range
RoPE+NTK     Good            Cheap     Minor fine-tuning helps
YaRN         Very good       Cheap     Best with fine-tuning
ALiBi        Excellent       Cheap     No fine-tuning needed
RoPE is a beautiful example of exploiting rotation symmetry (see Chapter: Symmetry). The SO(2) rotation group has exactly the properties needed for relative positional encoding:
- Orthogonality preserves the geometry of attention
- Group composition (\(R_m^T R_n = R_{n-m}\)) naturally encodes relative position
- The continuous nature of rotation enables smooth interpolation
This is why rotation matrices specifically—not arbitrary transformations—work so well. The algebraic structure of the rotation group matches the structure of what we want to compute.
23.7.2 Memory-Efficient Training
Long-context training needs:
# Gradient checkpointing (mandatory for long contexts)
model = gradient_checkpoint(model, checkpoint_every_n_layers=2)

# Activation offloading
config = ActivationOffloadConfig(
    offload_to='cpu',  # or 'nvme'
    prefetch=True
)

# Mixed precision with careful loss scaling
scaler = GradScaler(init_scale=2**10)  # Start lower for stability

23.7.3 Choosing Your Approach
Sequence Length Technique
─────────────────────────────────────────────
<16K Standard FlashAttention
16K-64K FlashAttention-2/3
64K-256K Ring Attention (multi-GPU)
256K-1M Ring Attention + KV compression
>1M Hierarchical/Retrieval + Ring Attention
23.8 Connections
FlashAttention: Foundation—online softmax enables ring attention.
Distributed: Ring attention is a distributed attention pattern.
Inference: KV cache compression essential for long-context serving.
MoE: MoE + long context = complex distributed setup.
23.9 Key Takeaways
FlashAttention evolves: FA2/FA3 bring 2-4x improvements through better parallelism and hardware utilization.
Ring attention enables scale: Stream K/V around GPU ring while computing, limited by communication bandwidth.
Online softmax is key: Same trick enables both chunked attention and distributed attention.
Causal helps: Autoregressive models skip half the communication.
KV cache dominates inference: Compression and quantization are essential for long-context serving.
No free lunch: All techniques trade something—compute, quality, or complexity.
The accompanying notebook (in progress) will let you:
- Implement ring attention from scratch
- Visualize communication patterns
- Compare FA1 vs FA2 vs FA3 performance
- Experiment with KV cache compression
Until it lands, run the examples locally and benchmark context-length trade-offs on your own hardware.
23.10 Further Reading
- Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”
- Shah et al. (2024). “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision”
- Liu et al. (2023). “Ring Attention with Blockwise Transformers for Near-Infinite Context”
- Zhang et al. (2023). “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models”
- Xiao et al. (2023). “Efficient Streaming Language Models with Attention Sinks”
- Su et al. (2021). “RoFormer: Enhanced Transformer with Rotary Position Embedding”
- Peng et al. (2023). “YaRN: Efficient Context Window Extension of Large Language Models”
- Press et al. (2022). “Train Short, Test Long: Attention with Linear Biases Enables Input Length Generalization” (ALiBi)
- bloc97 (2023). “NTK-Aware Scaled RoPE allows LLaMA models to have extended context” (Reddit post that introduced NTK-aware scaling)