23  Long-Context Attention

Scaling to Millions of Tokens

As of public announcements in 2024 and 2025, Claude lists a 200,000-token context window, Gemini lists 1 million tokens, and GPT-4 Turbo lists 128,000. Availability varies by model version and tier, and these numbers change quickly.

Standard attention is O(n²). How do you attend to a million tokens?

Note: Property Spotlight: Locality & Sparsity

This chapter combines locality and sparsity—the third and fourth properties from our Algebraic Framework.

Locality: Sliding window attention exploits the observation that most token interactions are local—tokens rarely need information from distant tokens. This lets us replace O(n²) full attention with O(n·w) local windows.

Sparsity: Sparse attention patterns (Longformer, BigBird) keep only a small fraction of the full attention matrix—global tokens, local windows, random connections. The full n×n matrix is sparse, so we never compute it.

Both properties break the quadratic barrier through different mechanisms: locality restricts which computations matter; sparsity encodes that restriction structurally.
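To make locality concrete, here is a toy sliding-window attention sketch in PyTorch. It is illustrative, not a kernel: it still materializes the full score matrix and merely masks it, whereas a real implementation computes only the banded entries.

```python
import torch

def sliding_window_attention(Q, K, V, window):
    """Causal sliding-window attention: query i sees keys i-window+1 .. i.
    The useful work is O(n * window); this toy version still builds the
    full n x n matrix and masks it, which a real kernel would not."""
    n, d = Q.shape
    scores = Q @ K.T / d ** 0.5
    idx = torch.arange(n)
    dist = idx.unsqueeze(1) - idx.unsqueeze(0)   # dist[i, j] = i - j
    scores = scores.masked_fill((dist < 0) | (dist >= window), float("-inf"))
    return torch.softmax(scores, dim=-1) @ V
```

Query 0 can only attend to itself, so its output is exactly `V[0]`; later queries blend their local window.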

23.1 The Long-Context Challenge

Attention has quadratic complexity in sequence length:

Memory and compute for self-attention:

Sequence    Memory (QK^T)    FLOPs
─────────────────────────────────────
1K          4 MB             2B
4K          64 MB            32B
16K         1 GB             512B
64K         16 GB            8T
256K        256 GB           128T
1M          4 TB             2P

These numbers assume a single attention-score matrix (one layer, one head, batch size 1) stored at 4 bytes per score, with dimension 1024 for the FLOP counts. Full-model memory scales with layers, heads, and batch size.
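The table's entries follow from two small formulas; a throwaway sketch that reproduces them (assuming 4-byte score storage and d = 1024, which matches the figures above):

```python
def attn_costs(n, d=1024, bytes_per_score=4):
    """Score-matrix memory (bytes) and QK^T matmul FLOPs for one head."""
    mem = n * n * bytes_per_score   # the n x n score matrix
    flops = 2 * n * n * d           # one multiply-add per (query, key, dim)
    return mem, flops

mem, flops = attn_costs(1 << 20)    # 1M tokens
print(f"{mem / 2**40:.0f} TB, {flops / 1e15:.0f} PFLOPs")  # → 4 TB, 2 PFLOPs
```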

A100 80GB: Can hold a handful of full attention matrices at 64K, but full-model full attention does not fit.
H100 80GB: Similar constraint unless using multi-GPU sharding or sparse/streaming attention.

FlashAttention (from the FlashAttention chapter) solves the memory problem through chunking—but compute is still O(n²).

This chapter explores techniques for scaling beyond what a single GPU can handle.

23.2 FlashAttention Evolution

23.2.1 FlashAttention-1 Recap

Core insight: Stream through K/V in chunks, never materialize the full n×n matrix.

Memory: O(n) instead of O(n²)
Compute: Still O(n²), but IO-efficient

23.2.2 FlashAttention-2 Improvements

FA2 (2023) improved parallelism and work partitioning:

1. Better Parallelism

FA1 parallelizes over batch and heads. FA2 also parallelizes over the sequence dimension:

FA1: Each thread block handles all of Q for one head
     Limited parallelism for long sequences

FA2: Thread blocks handle Q chunks + parallelize within sequence
     Better GPU utilization for long sequences

2. Reduced Non-Matmul FLOPs

FA1 has overhead from online softmax bookkeeping. FA2 reduces this:

# FA1: More register pressure from bookkeeping
# FA2: Fused and streamlined operations

# Result: 2x speedup over FA1 for long sequences

3. Work Partitioning

FA2 changes the loop order for better memory access:

FA1: Outer loop over K/V blocks, inner loop over Q blocks
     K/V loaded once, Q loaded repeatedly

FA2: Outer loop over Q blocks, inner loop over K/V blocks
     Q loaded once, K/V streamed
     Better for causal attention (can skip future K/V)

23.2.3 FlashAttention-3 (Hopper)

FA3 (2024) exploits H100-specific features:

1. Tensor Memory Accelerator (TMA)

Hardware for async memory transfers:

A100: Software-managed shared memory loading
H100: TMA handles loads automatically, overlaps with compute

Result: Better compute/memory overlap

2. FP8 Support

Native FP8 attention for 2x speedup with minimal quality loss.

3. Warp Specialization

Different warps do different work (producers vs consumers):

Traditional: All warps do same work in lockstep
FA3: Some warps load data, others compute
     Pipelined execution hides latency

Performance gains:

Sequence length 16K, head dim 128:
  FA2 on A100: 1.0x baseline
  FA2 on H100: 1.5x (faster hardware)
  FA3 on H100: 2.5x (software + hardware)

23.3 Sequence Parallelism

When a single GPU can’t fit the full sequence, split it across GPUs.

23.3.1 The Problem

For very long sequences, even FlashAttention hits limits:

Activations per layer (transformer, FP32):
  = batch × seq × hidden × 4 bytes
  = 1 × 100K × 8192 × 4
  = 3.2 GB per layer

80-layer model: 256 GB activations
Plus KV cache, gradients: Doesn't fit in 80GB

23.3.2 Naive Sequence Parallelism

Split the sequence across GPUs:

Sequence: [token_0, ..., token_99999] (100K tokens)

GPU 0: [token_0, ..., token_24999]
GPU 1: [token_25000, ..., token_49999]
GPU 2: [token_50000, ..., token_74999]
GPU 3: [token_75000, ..., token_99999]

Problem: Attention is all-to-all. Token 50000 needs to attend to token 0.

23.3.3 Ring Attention

Key insight: Stream K/V around a ring of GPUs while computing attention incrementally.

Ring Attention for 4 GPUs:

Initial state:
  GPU 0: Q_0, K_0, V_0
  GPU 1: Q_1, K_1, V_1
  GPU 2: Q_2, K_2, V_2
  GPU 3: Q_3, K_3, V_3

Step 1: Each GPU computes local attention + sends K,V to next GPU
  GPU 0: Attend(Q_0, K_0, V_0) → partial_0
         Send K_0, V_0 → GPU 1
  GPU 1: Attend(Q_1, K_1, V_1) → partial_1
         Send K_1, V_1 → GPU 2
  ...

Step 2: Receive K,V from previous GPU, continue attention
  GPU 0: Receive K_3, V_3
         Attend(Q_0, K_3, V_3) → update partial_0
  GPU 1: Receive K_0, V_0
         Attend(Q_1, K_0, V_0) → update partial_1
  ...

After N-1 steps: Each GPU has attended to all K,V

The magic: Uses FlashAttention’s online softmax trick. Partial results can be combined as new K/V arrive.

import torch
import torch.distributed as dist

def ring_attention(Q_local, K_local, V_local, world_size, rank):
    """
    Ring attention: stream K/V around the ring while computing attention.

    Assumes flash_attention_with_lse returns (output, log-sum-exp) for
    the local chunk, as in the FlashAttention chapter.
    """
    # Initialize with local attention
    O, lse = flash_attention_with_lse(Q_local, K_local, V_local)

    # Buffers for incoming K/V
    K_recv = torch.empty_like(K_local)
    V_recv = torch.empty_like(V_local)

    for step in range(world_size - 1):
        send_rank = (rank + 1) % world_size
        recv_rank = (rank - 1) % world_size

        # Async send K/V to the next GPU while receiving from the previous one
        send_k = dist.isend(K_local, send_rank)
        send_v = dist.isend(V_local, send_rank)
        dist.recv(K_recv, recv_rank)
        dist.recv(V_recv, recv_rank)
        send_k.wait()
        send_v.wait()

        # Compute attention with the received K/V
        O_new, lse_new = flash_attention_with_lse(Q_local, K_recv, V_recv)

        # Online update (same as FlashAttention chunking)
        O, lse = online_softmax_update(O, lse, O_new, lse_new)

        # Swap buffers for the next iteration
        K_local, K_recv = K_recv, K_local
        V_local, V_recv = V_recv, V_local

    return O

def online_softmax_update(O1, lse1, O2, lse2):
    """
    Combine two partial attention results.

    lse = log(sum(exp(scores))) - the log-sum-exp for normalization.
    O carries a trailing head-dim axis that lse lacks, hence the unsqueeze.
    """
    # New log-sum-exp (numerically stable)
    lse_max = torch.maximum(lse1, lse2)
    lse_new = lse_max + torch.log(
        torch.exp(lse1 - lse_max) + torch.exp(lse2 - lse_max)
    )

    # Reweight and combine outputs
    w1 = torch.exp(lse1 - lse_new).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse_new).unsqueeze(-1)
    O_new = w1 * O1 + w2 * O2

    return O_new, lse_new
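The merge is easy to check end-to-end: splitting the keys in two, attending to each half, and merging must reproduce monolithic attention. A self-contained sketch, where the naive `attention_with_lse` stands in for `flash_attention_with_lse`:

```python
import torch

def attention_with_lse(Q, K, V):
    """Naive stand-in for flash_attention_with_lse: full softmax
    attention plus the per-query log-sum-exp."""
    scores = Q @ K.T / Q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)          # [n_q]
    return torch.softmax(scores, dim=-1) @ V, lse  # [n_q, d], [n_q]

def merge(O1, lse1, O2, lse2):
    """online_softmax_update for 2-D O / 1-D lse tensors."""
    lse_new = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse_new).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse_new).unsqueeze(-1)
    return w1 * O1 + w2 * O2, lse_new

Q, K, V = torch.randn(5, 8), torch.randn(12, 8), torch.randn(12, 8)
O_full, _ = attention_with_lse(Q, K, V)
O1, l1 = attention_with_lse(Q, K[:6], V[:6])
O2, l2 = attention_with_lse(Q, K[6:], V[6:])
O_merged, _ = merge(O1, l1, O2, l2)
assert torch.allclose(O_full, O_merged, atol=1e-5)
```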

23.3.4 Communication Analysis

Ring attention communication pattern:

Per step: send 2 × seq_local × head_dim × num_heads × bytes_per_element
          (the factor 2 covers K and V)
Steps: world_size - 1
Total: 2 × (N-1)/N × seq × head_dim × num_heads × bytes_per_element

For 100K sequence, 4 GPUs, 8192 hidden, FP16:
  = 2 * 3/4 * 100K * 8192 * 2 bytes
  = 2.4 GB total communication

At 400 GB/s NVLink: 6 ms
Compared to compute: Often overlapped with attention

Key advantage: Communication is pipelined with compute. While computing attention on current K/V, receiving next K/V.
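The figures above come from straightforward arithmetic; a throwaway helper (the bandwidth figure is nominal):

```python
def ring_comm(seq, n_gpus, hidden, bytes_per=2, bw_gbs=400):
    """Bytes each GPU sends over a full ring pass, and time at bw_gbs.
    The factor 2 covers K and V; hidden = head_dim * num_heads."""
    per_step = 2 * (seq // n_gpus) * hidden * bytes_per
    total = per_step * (n_gpus - 1)
    return total, total / (bw_gbs * 1e9)

total, t = ring_comm(100_000, 4, 8192)
print(f"{total / 1e9:.2f} GB in {t * 1e3:.1f} ms")  # → 2.46 GB in 6.1 ms
```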

23.3.5 Causal Attention Optimization

For causal (autoregressive) attention, tokens only attend to earlier tokens:

Full attention:              Causal attention:
[× × × ×]                    [× - - -]
[× × × ×]                    [× × - -]
[× × × ×]                    [× × × -]
[× × × ×]                    [× × × ×]

Ring attention optimization: Skip unnecessary K/V blocks.

GPU 0 (tokens 0-24999):
  - Attends to: K_0 only (causal)

GPU 3 (tokens 75000-99999):
  - Attends to: K_0, K_1, K_2, K_3 (all previous)

Result: ~50% less communication for causal attention

23.4 Context Parallelism (Megatron-LM)

Megatron’s approach to long sequences combines sequence parallelism with tensor parallelism.

23.4.1 Ulysses: Sequence Parallel Attention

Split the sequence dimension, and use all-to-all so that each GPU computes full-sequence attention for a subset of heads:

def ulysses_attention(Q, K, V, seq_parallel_group):
    """
    Ulysses-style sequence parallel attention.

    Inputs are local sequence chunks with all heads:
        [batch, seq/P, heads, head_dim]
    all_to_all (a wrapper around dist.all_to_all) trades the sequence
    shard for a head shard, so each GPU sees the full sequence for
    heads/P of the heads.
    """
    # All-to-all: [batch, seq/P, heads, dim] → [batch, seq, heads/P, dim]
    Q_h = all_to_all(Q, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    K_h = all_to_all(K, scatter_dim=2, gather_dim=1, group=seq_parallel_group)
    V_h = all_to_all(V, scatter_dim=2, gather_dim=1, group=seq_parallel_group)

    # Full-sequence attention on the local subset of heads
    O_h = flash_attention(Q_h, K_h, V_h)

    # All-to-all back: [batch, seq, heads/P, dim] → [batch, seq/P, heads, dim]
    O_local = all_to_all(O_h, scatter_dim=1, gather_dim=2, group=seq_parallel_group)

    return O_local

Tradeoff vs Ring Attention:

  • Ulysses: collective communication up front (simpler), then purely local compute
  • Ring: pipelined communication + compute (more complex, better overlap)

23.4.2 Combining Parallelism Dimensions

Real systems use multiple dimensions:

Example: 1M token training on 128 GPUs

Tensor Parallel: 8-way (within node)
Sequence Parallel: 4-way (across 4 nodes)
Data Parallel: 4-way (replicas)

8 × 4 × 4 = 128 GPUs

Each sequence-parallel rank sees: 1M / 4 = 250K local tokens
Attention across those 250K-token chunks still requires ring
attention (or Ulysses all-to-alls) within each replica

23.5 Hierarchical Attention

For extremely long contexts, hierarchical approaches reduce complexity.

23.5.1 Block-wise Attention

Divide sequence into blocks, full attention within blocks, sparse across blocks:

Block size: 4096 tokens
Sequence: 1M tokens = 244 blocks

Within block: Full O(n²) attention → O(4096²) = 16M ops per block
Across blocks: Summary vectors or sparse patterns

Total: O(n * block_size) instead of O(n²)
     = O(1M * 4096) = 4B ops
     vs O(1M²) = 1T ops

Examples: Longformer, BigBird, LongT5
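A minimal sketch of the within-block part, assuming single-head tensors and omitting the cross-block summary path entirely:

```python
import torch

def blockwise_attention(Q, K, V, block=4):
    """Toy block-local attention: full softmax attention inside each
    block, nothing across blocks (summary/global tokens omitted)."""
    n, d = Q.shape
    out = torch.empty_like(Q)
    for s in range(0, n, block):
        q, k, v = Q[s:s+block], K[s:s+block], V[s:s+block]
        scores = q @ k.T / d ** 0.5
        out[s:s+block] = torch.softmax(scores, dim=-1) @ v
    return out
```

Each block's output is exactly full attention restricted to that block, which is what Longformer-style patterns combine with global tokens.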

23.5.2 Retrieval-Augmented Attention

Don’t attend to everything—retrieve relevant context:

def retrieval_attention(query, context_db, k=100):
    """
    Attend only to retrieved relevant context.
    """
    # Retrieve top-k relevant chunks
    relevant_chunks = context_db.search(query, k=k)

    # Concatenate and attend
    context = torch.cat(relevant_chunks)
    output = attention(query, context, context)

    return output

Tradeoff: Loses some information but enables arbitrary context length.

23.6 KV Cache for Long Contexts

Inference with long contexts has unique challenges.

23.6.1 The KV Cache Problem

LLaMA-2-70B KV cache (80 layers, 8 KV heads via GQA, head_dim 128, FP16):
  = 2 (K and V) × layers × kv_heads × head_dim × seq × 2 bytes
  = 2 × 80 × 8 × 128 × 100K × 2
  ≈ 33 GB per 100K-token sequence

Even a handful of concurrent long sequences exceeds a single 80 GB GPU
before accounting for model weights and activations.
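KV-cache size follows directly from the model config; a small helper (note the `kv_heads` factor, which grouped-query attention makes much smaller than the query-head count):

```python
def kv_cache_bytes(layers, kv_heads, head_dim, seq, bytes_per=2):
    """Bytes for one sequence's KV cache: K and V across all layers."""
    return 2 * layers * kv_heads * head_dim * seq * bytes_per

# LLaMA-2-70B-style config: 80 layers, 8 KV heads (GQA), head_dim 128
gb = kv_cache_bytes(80, 8, 128, 100_000) / 1e9
print(f"{gb:.1f} GB per 100K-token sequence")  # → 32.8 GB per 100K-token sequence
```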

23.6.2 KV Cache Compression

H2O (Heavy-Hitter Oracle): Keep only important tokens.

def h2o_kv_cache(K, V, attention_scores, budget):
    """
    Keep only high-attention tokens in the cache.

    K, V: [batch, seq, dim]; attention_scores: [batch, queries, seq].
    """
    # Cumulative attention each key token receives (sum over queries)
    importance = attention_scores.sum(dim=-2)  # [batch, seq]

    # Keep the top-`budget` tokens per batch element
    keep = importance.topk(budget, dim=-1).indices  # [batch, budget]
    K_compressed = K.gather(1, keep.unsqueeze(-1).expand(-1, -1, K.shape[-1]))
    V_compressed = V.gather(1, keep.unsqueeze(-1).expand(-1, -1, V.shape[-1]))

    return K_compressed, V_compressed

StreamingLLM: Keep first tokens (attention sinks) + recent window.

def streaming_llm_cache(K, V, window_size, sink_size=4):
    """
    Keep attention sinks + sliding window.
    """
    K_sink = K[:, :sink_size]
    V_sink = V[:, :sink_size]

    K_window = K[:, -window_size:]
    V_window = V[:, -window_size:]

    return torch.cat([K_sink, K_window], dim=1), \
           torch.cat([V_sink, V_window], dim=1)

23.6.3 Quantized KV Cache

Reduce precision for KV cache:

FP16 KV cache: ~33 GB per 100K-token sequence (LLaMA-2-70B)
INT8 KV cache: ~16 GB
INT4 KV cache: ~8 GB

Quality impact: Usually <1% perplexity degradation
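A minimal symmetric per-row INT8 scheme for cached tensors, as a sketch; production systems use finer-grained scales and outlier handling:

```python
import torch

def quantize_kv(x):
    """Symmetric per-row INT8 quantization: x ≈ q * scale."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).to(torch.int8)
    return q, scale

def dequantize_kv(q, scale):
    return q.float() * scale

K = torch.randn(100, 128)
q, scale = quantize_kv(K)
K_hat = dequantize_kv(q, scale)
# Worst-case rounding error is half a quantization step per element
assert (K - K_hat).abs().max() <= scale.max() / 2 + 1e-6
```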

23.7 Practical Considerations

23.7.1 Positional Encodings for Long Context

Without positional encodings, self-attention is permutation-equivariant—the same output for any ordering of tokens. For sequences, we need position to matter. Positional encodings break this symmetry intentionally.

But for long contexts, the choice of positional encoding becomes critical.

23.7.1.1 The Problem with Traditional Approaches

Learned positions:
  - Fixed vocabulary of positions (e.g., 0-2047)
  - Cannot extrapolate beyond training length
  - Must retrain for longer contexts

Sinusoidal positions:
  - Can represent any position mathematically
  - But attention patterns learned on positions 0-2047
    don't generalize to position 100,000
  - The model hasn't learned what "far away" means

The fundamental issue: traditional position encodings encode absolute position. But attention really cares about relative position—how far apart are two tokens?

23.7.1.2 Rotary Position Embeddings (RoPE)

RoPE solves this elegantly using rotation matrices. The key insight: rotations naturally encode relative position through their composition properties.

The Mathematical Foundation

A 2D rotation matrix has the form:

\[R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}\]

Rotation matrices have special properties that make them perfect for positional encoding:

Property               Mathematical Form                       Why It Enables RoPE
─────────────────────────────────────────────────────────────────────────────────────
Orthogonality          \(R^T R = I\)                           Preserves vector norms and dot products
Relative composition   \(R_m^T R_n = R_{n-m}\)                 Relative position from absolute
Commutativity          \(R_\theta R_\phi = R_{\theta+\phi}\)   Positions compose cleanly

The crucial property is the second one. When we compute attention between positions \(m\) and \(n\):

\[(R_m q)^T (R_n k) = q^T R_m^T R_n k = q^T R_{n-m} k\]

The absolute positions \(m\) and \(n\) vanish—only the relative position \(n-m\) remains. This is exactly what we want for attention.
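This identity is easy to verify numerically with 2D rotation matrices:

```python
import math
import torch

def rot(theta):
    """2D rotation matrix R(theta)."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]])

m, n, freq = 3.0, 7.0, 0.1      # two absolute positions and one frequency
q, k = torch.randn(2), torch.randn(2)

# (R_m q) . (R_n k) == q . (R_{n-m} k): absolute positions cancel
lhs = (rot(m * freq) @ q) @ (rot(n * freq) @ k)
rhs = q @ (rot((n - m) * freq) @ k)
assert torch.allclose(lhs, rhs, atol=1e-6)
```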

Applying Rotation to Embeddings

RoPE treats consecutive pairs of embedding dimensions as 2D vectors and rotates each pair:

import torch

def rope_embedding(x, positions, base=10000):
    """
    Apply Rotary Position Embedding to input tensor.

    x: [batch, seq, heads, head_dim]
    positions: [seq] - position indices
    """
    head_dim = x.shape[-1]

    # Different frequency for each dimension pair
    # Early pairs = high frequency (fine detail); later pairs = low frequency
    dim_pairs = head_dim // 2
    freqs = 1.0 / (base ** (torch.arange(0, dim_pairs).float() / dim_pairs))

    # Rotation angles: position × frequency → [seq, dim_pairs]
    angles = positions.float().unsqueeze(-1) * freqs.unsqueeze(0)

    # View the embedding as pairs: [batch, seq, heads, dim_pairs, 2]
    x_pairs = x.view(*x.shape[:-1], dim_pairs, 2)

    # Broadcast over batch and heads: [seq, 1, dim_pairs]
    cos = torch.cos(angles).unsqueeze(1)
    sin = torch.sin(angles).unsqueeze(1)

    # Rotation: [x, y] → [x·cos - y·sin, x·sin + y·cos]
    x_rot = torch.stack([
        x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
        x_pairs[..., 0] * sin + x_pairs[..., 1] * cos
    ], dim=-1)

    return x_rot.view(*x.shape)

Multi-Frequency Encoding

Each dimension pair uses a different rotation frequency:

Dimension pair 0:     θ = position × (1 / base^0)      = position
Dimension pair 1:     θ = position × (1 / base^(1/p))
Dimension pair 2:     θ = position × (1 / base^(2/p))
...
Dimension pair p - 1: θ ≈ position × (1 / base)        = position / base

(p = head_dim / 2, the number of dimension pairs)

Low dimensions:  Fast rotation, fine position detail
High dimensions: Slow rotation, coarse position info

This is analogous to Fourier features or the sinusoidal encoding—but applied through rotations that preserve the relative-position property.

Why RoPE Works Better for Long Contexts

The relative-position formulation helps RoPE extrapolate:

Traditional absolute position:
  - Train on positions 0-4096
  - Position 10000 is "unknown" to the model

RoPE relative position:
  - Train on relative distances -4096 to +4096
  - At position 10000, nearby tokens still have
    familiar relative distances
  - Long-range attention may degrade, but
    local attention remains well-calibrated

23.7.1.3 RoPE Extensions for Extreme Lengths

Even RoPE degrades beyond training length. Several extensions address this:

NTK-Aware Scaling

Key insight: don’t just interpolate positions—adjust the frequency base.

def ntk_aware_scaling(base, scale_factor, dim):
    """
    NTK-aware interpolation: adjust base frequency for length extension.

    Instead of: position → position / scale_factor (interpolation)
    Do:         base → base × scale_factor^(dim/(dim-2))
    """
    # This preserves relative position resolution better
    # than naive position interpolation
    alpha = scale_factor ** (dim / (dim - 2))
    return base * alpha

The intuition: NTK (Neural Tangent Kernel) theory suggests the frequency spectrum matters. Naive interpolation compresses high frequencies too aggressively; NTK-aware scaling maintains the frequency distribution better.

YaRN (Yet another RoPE extensioN)

YaRN combines multiple techniques:

  1. NTK-aware interpolation for the base frequency
  2. Attention scaling to compensate for longer contexts
  3. Dimension-dependent interpolation (different scaling per frequency)

def yarn_scale_factor(dim_idx, dim, scale, alpha=1, beta=32):
    """
    YaRN: dimension-dependent interpolation factor.

    Low-frequency dimensions: interpolate more aggressively
    High-frequency dimensions: preserve more (local info)

    alpha and beta are dimension-index thresholds for the blend ramp.
    """
    # Linear ramp from 0 to 1 as the dimension index crosses [alpha, beta]
    ramp = (dim_idx - alpha) / (beta - alpha)
    ramp = min(1.0, max(0.0, ramp))

    # Blend between no scaling (ramp=0) and full interpolation (ramp=1)
    return 1 / (1 - ramp + ramp * scale)

ALiBi (Attention with Linear Biases)

A different approach: don’t modify embeddings, add position bias directly to attention scores.

import math
import torch

def alibi_attention(Q, K, V, num_heads):
    """
    ALiBi: add linear position bias to attention.

    Q, K, V: [batch, heads, seq, head_dim]
    """
    seq_len, head_dim = Q.shape[-2], Q.shape[-1]

    # Relative position matrix: relative_pos[i, j] = j - i
    positions = torch.arange(seq_len)
    relative_pos = positions.unsqueeze(0) - positions.unsqueeze(1)

    # Different slope per head (geometric sequence)
    slopes = 2 ** (-8 / num_heads * torch.arange(1, num_heads + 1))

    # Attention scores
    scores = Q @ K.transpose(-1, -2) / math.sqrt(head_dim)

    # Subtract a bias that grows with distance
    # (encourages attending to nearby tokens; this symmetric |Δ| form is a
    # bidirectional variant — the original ALiBi applies the bias causally)
    bias = slopes.view(1, num_heads, 1, 1) \
         * relative_pos.abs().view(1, 1, seq_len, seq_len)
    scores = scores - bias

    return torch.softmax(scores, dim=-1) @ V

ALiBi’s advantages:

  • No learned parameters for position
  • Naturally handles any length (bias is mathematically defined)
  • Different heads specialize in different ranges

Choosing Your Approach

Approach    Extrapolation    Compute Cost    Compatibility
─────────────────────────────────────────────────────────
Learned     None             Cheap           N/A beyond train length
Sinusoidal  Poor             Cheap           Degrades badly
RoPE        Moderate         Cheap           Good local, degrades long
RoPE+NTK    Good             Cheap           Minor fine-tuning helps
YaRN        Very Good        Cheap           Best with fine-tuning
ALiBi       Excellent        Cheap           No fine-tuning needed

Note: Connection to Symmetry

RoPE is a beautiful example of exploiting rotation symmetry (see Chapter: Symmetry). The SO(2) rotation group has exactly the properties needed for relative positional encoding:

  • Orthogonality preserves the geometry of attention
  • Group composition (\(R_m^T R_n = R_{n-m}\)) naturally encodes relative position
  • The continuous nature of rotation enables smooth interpolation

This is why rotation matrices specifically—not arbitrary transformations—work so well. The algebraic structure of the rotation group matches the structure of what we want to compute.

23.7.2 Memory-Efficient Training

Long-context training typically needs the following (the function and config names below are schematic, not a specific library's API):

# Gradient checkpointing (mandatory for long contexts)
model = gradient_checkpoint(model, checkpoint_every_n_layers=2)

# Activation offloading
config = ActivationOffloadConfig(
    offload_to='cpu',  # or 'nvme'
    prefetch=True
)

# Mixed precision with careful loss scaling
scaler = GradScaler(init_scale=2**10)  # Start lower for stability

23.7.3 Choosing Your Approach

Sequence Length    Technique
─────────────────────────────────────────────
<16K              Standard FlashAttention
16K-64K           FlashAttention-2/3
64K-256K          Ring Attention (multi-GPU)
256K-1M           Ring Attention + KV compression
>1M               Hierarchical/Retrieval + Ring Attention

23.8 Connections

FlashAttention: Foundation—online softmax enables ring attention.

Distributed: Ring attention is a distributed attention pattern.

Inference: KV cache compression essential for long-context serving.

MoE: MoE + long context = complex distributed setup.

23.9 Key Takeaways

  1. FlashAttention evolves: FA2/FA3 bring 2-4x improvements through better parallelism and hardware utilization.

  2. Ring attention enables scale: Stream K/V around GPU ring while computing, limited by communication bandwidth.

  3. Online softmax is key: Same trick enables both chunked attention and distributed attention.

  4. Causal helps: Autoregressive models skip half the communication.

  5. KV cache dominates inference: Compression and quantization are essential for long-context serving.

  6. No free lunch: All techniques trade something—compute, quality, or complexity.

Note: Try It Yourself

The accompanying notebook lets you:

  • Implement ring attention from scratch
  • Visualize communication patterns
  • Compare FA1 vs FA2 vs FA3 performance
  • Experiment with KV cache compression

Notebook support for this chapter is in progress. For now, run the examples locally and benchmark context-length trade-offs on your hardware.

23.10 Further Reading

  • Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”
  • Shah et al. (2024). “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision”
  • Liu et al. (2023). “Ring Attention with Blockwise Transformers for Near-Infinite Context”
  • Zhang et al. (2023). “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models”
  • Xiao et al. (2023). “Efficient Streaming Language Models with Attention Sinks”
  • Su et al. (2021). “RoFormer: Enhanced Transformer with Rotary Position Embedding”
  • Peng et al. (2023). “YaRN: Efficient Context Window Extension of Large Language Models”
  • Press et al. (2022). “Train Short, Test Long: Attention with Linear Biases Enables Input Length Generalization” (ALiBi)
  • bloc97 (2023). “NTK-Aware Scaled RoPE allows LLaMA models to have extended context” (Reddit post that introduced NTK-aware scaling)