Sequence Parallelism from Decomposability
When sequences grow to millions of tokens, even a single attention computation won't fit in memory. Sequence parallelism exploits the decomposability of attention to split along the sequence dimension.
The Question: Attention is O(S²) in sequence length. For S=1M tokens, that's 10^12 attention scores. How do we compute this when no single GPU can hold the attention matrix?
Why Sequence Parallelism?¶
As models process longer contexts—documents, codebases, video—the sequence length \(S\) becomes the bottleneck.
Memory Scaling with Sequence Length¶
For a transformer layer:
Activation memory per layer:

\[
M_{\text{act}} \approx 2 \cdot b \cdot S \cdot H \;+\; b \cdot n_h \cdot S^2
\]
Where:
- First term: Input and output activations (\(2 \times b \times S \times H\))
- Second term: Attention matrix (\(b \times n_h \times S \times S\))
The attention matrix scales as \(O(S^2)\).
Example: \(S = 128K\), \(b = 1\), \(n_h = 32\), fp16. The attention-matrix term alone is

\[
b \cdot n_h \cdot S^2 \cdot 2\ \text{bytes} = 1 \cdot 32 \cdot (131{,}072)^2 \cdot 2 \approx 1.1\ \text{TB}.
\]

No single GPU can hold this.
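To make the scaling concrete, here is a rough back-of-the-envelope helper. It is a sketch that counts only the two terms above; the function name and the example hidden size are illustrative, not from the formula.

```python
def attention_activation_bytes(seq_len: int, hidden: int, n_heads: int,
                               batch: int = 1, bytes_per_elem: int = 2) -> dict:
    """Rough per-layer activation memory: input/output activations + attention matrix."""
    io_acts = 2 * batch * seq_len * hidden * bytes_per_elem             # 2 * b * S * H
    attn_matrix = batch * n_heads * seq_len * seq_len * bytes_per_elem  # b * n_h * S^2
    return {"io_activations_GB": io_acts / 1e9,
            "attention_matrix_GB": attn_matrix / 1e9}

# S = 128K, b = 1, n_h = 32, fp16: the attention matrix alone is ~1100 GB
print(attention_activation_bytes(seq_len=128 * 1024, hidden=4096, n_heads=32))
```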
The Sequence Dimension¶
Unlike batch or hidden dimensions, the sequence dimension appears in:
- Attention: \(Q K^T\) has shape \((S \times S)\)
- Layer normalization: Statistics computed per-token (independent)
- Feedforward: Applied per-token (independent)
- Positional encoding: Position-dependent
Only attention creates cross-token dependencies.
Two Flavors of Sequence Parallelism¶
1. Megatron Sequence Parallelism¶
Reduces memory for LayerNorm and Dropout activations by distributing across the sequence dimension within tensor parallelism.
Target: Activation memory outside attention.
Communication: AllGather before attention, ReduceScatter after.
2. Context Parallelism (Ring Attention / Ulysses)¶
Distributes the attention computation itself across the sequence dimension.
Target: The \(O(S^2)\) attention matrix.
Communication: Ring of P2P (Ring Attention) or AlltoAll (Ulysses).
Let's examine each in detail.
Megatron Sequence Parallelism¶
Korthikanti et al. (2022) introduced sequence parallelism for memory efficiency.
The Memory Problem¶
In tensor parallelism, certain operations are replicated:
- Before LayerNorm: activation shape \((b, S, H)\), replicated on all TP ranks
- After LayerNorm: same shape, still replicated
- Before Dropout: replicated
- After Dropout: replicated
With TP degree \(T\), this wastes \((T-1) \cdot b \cdot S \cdot H\) memory.
The Solution: Distribute Sequence¶
Split the sequence dimension across TP ranks, so each rank holds a contiguous \((b, S/T, H)\) shard of these activations.
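A minimal sketch of the sharding itself, assuming \(S\) divides evenly by the TP degree (`tp_rank` and `tp_size` are placeholders for however the framework exposes the TP group):

```python
import torch

def shard_sequence(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Keep only this rank's slice of the sequence dimension of a (b, S, H) tensor."""
    b, s, h = x.shape
    assert s % tp_size == 0, "pad the sequence so it divides evenly across TP ranks"
    return x.chunk(tp_size, dim=1)[tp_rank].contiguous()  # (b, S/T, H)
```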
Communication Pattern¶
Before a column-parallel layer (each rank needs the full sequence): AllGather along the sequence dimension.

After a row-parallel layer (each rank holds a partial sum): ReduceScatter, which sums across ranks and re-shards along the sequence dimension.
┌─────────────┐
│ LayerNorm │ Sequence-parallel
│ (b, S/T, H)│
└──────┬──────┘
│
┌──────▼──────┐
│ AllGather │
│ (b, S, H) │
└──────┬──────┘
│
┌──────▼──────┐
│ Col-Parallel│ Tensor-parallel
│ Linear │
└──────┬──────┘
│
┌──────▼──────┐
│ Row-Parallel│
│ Linear │
└──────┬──────┘
│
┌──────▼──────┐
│ReduceScatter│
│ (b, S/T, H)│
└──────┬──────┘
│
┌──────▼──────┐
│ Dropout │ Sequence-parallel
│ (b, S/T, H)│
└─────────────┘
Memory Savings¶
Without sequence parallelism, every TP rank holds the full activation in these regions:

\[
M_{\text{rank}} = b \cdot S \cdot H
\]

With sequence parallelism, each rank holds only its sequence shard:

\[
M_{\text{rank}} = \frac{b \cdot S \cdot H}{T}
\]
Savings factor: \(T\) (the tensor parallelism degree).
Communication Analysis¶
Each transformer layer incurs (per rank, ring-style collectives):
- 1 AllGather before attention
- 1 ReduceScatter after attention
- 1 AllGather before FFN
- 1 ReduceScatter after FFN
Total per layer (per rank): \(4 \cdot \frac{T-1}{T} \cdot b \cdot S \cdot H \cdot \text{sizeof(dtype)}\)
This is the same volume as tensor parallelism's AllReduce operations, just restructured.
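To see why, recall that a ring AllReduce over \(N\) elements is itself implemented as a ReduceScatter followed by an AllGather, each moving \(\frac{T-1}{T} \cdot N\) elements per rank:

\[
\underbrace{2 \cdot \frac{T-1}{T} \cdot b \cdot S \cdot H}_{\text{one TP AllReduce}}
=
\underbrace{\frac{T-1}{T} \cdot b \cdot S \cdot H}_{\text{AllGather}}
+
\underbrace{\frac{T-1}{T} \cdot b \cdot S \cdot H}_{\text{ReduceScatter}}
\]

Plain tensor parallelism issues two AllReduces per layer in the forward pass (one after attention, one after the FFN), which matches the four AllGather/ReduceScatter collectives listed above.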
Implementation¶
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class SequenceParallelLayerNorm(nn.Module):
"""LayerNorm operating on sequence-parallel inputs."""
def __init__(self, hidden_size: int, tp_group):
super().__init__()
self.tp_group = tp_group
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.eps = 1e-5
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (batch, seq_local, hidden)
# LayerNorm is per-token, so no cross-rank communication needed
return F.layer_norm(x, (x.size(-1),), self.weight, self.bias, self.eps)
def sequence_parallel_attention(q, k, v, tp_group):
"""Attention with sequence-parallel input/output."""
# Input: (batch, seq_local, hidden)
# AllGather to get full sequence
q_full = all_gather_sequence(q, tp_group) # (batch, seq, hidden)
k_full = all_gather_sequence(k, tp_group)
v_full = all_gather_sequence(v, tp_group)
# Standard attention on full sequence
output = attention(q_full, k_full, v_full) # (batch, seq, hidden)
# Shard back to local sequence chunk.
# This pedagogical helper is a pure scatter (no reduction).
output_local = scatter_sequence_chunk(output, tp_group)
return output_local # (batch, seq_local, hidden)
def all_gather_sequence(x: torch.Tensor, group) -> torch.Tensor:
"""AllGather along sequence dimension."""
world_size = dist.get_world_size(group)
gathered = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(gathered, x, group=group)
return torch.cat(gathered, dim=1) # Concat along seq dimension
def scatter_sequence_chunk(x: torch.Tensor, group) -> torch.Tensor:
"""Shard tensor along sequence dimension (no reduction)."""
world_size = dist.get_world_size(group)
seq_len = x.size(1)
if seq_len % world_size != 0:
raise ValueError("Sequence length must be divisible by world size (pad or use uneven splits).")
local_seq = seq_len // world_size
# Split into chunks
chunks = x.split(local_seq, dim=1)
rank = dist.get_rank(group)
return chunks[rank].contiguous()
The Decomposability of Attention¶
For true sequence parallelism—partitioning the attention computation itself—we need to understand how attention can be decomposed.
Standard Attention¶
The challenge: softmax normalizes across the entire key sequence:

\[
\text{Attention}(Q, K, V)_i = \sum_j \frac{e^{\,q_i \cdot k_j / \sqrt{d}}}{\sum_{j'} e^{\,q_i \cdot k_{j'} / \sqrt{d}}}\, v_j
\]

We can't compute the denominator without seeing all keys.
Online Softmax¶
Milakov & Gimelshein (2018) showed softmax can be computed incrementally.
The insight: Track the running maximum and sum.
For numerical stability:

\[
\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
\]

where \(m = \max_j x_j\).

Incremental update: Given chunks \(A\) and \(B\) with states \((m_A, s_A, o_A)\) and \((m_B, s_B, o_B)\):

\[
m_{AB} = \max(m_A, m_B), \qquad
s_{AB} = s_A\, e^{m_A - m_{AB}} + s_B\, e^{m_B - m_{AB}}
\]

Output update:

\[
o_{AB} = o_A\, e^{m_A - m_{AB}} + o_B\, e^{m_B - m_{AB}}
\]

with the final attention output obtained as \(o / s\) once all chunks are merged.
Associativity of Online Softmax¶
Theorem: Online softmax combination is associative.
Let \((m, s, o)\) represent the state (max, sum, output). The combination operation \(\oplus\):

\[
(m_1, s_1, o_1) \oplus (m_2, s_2, o_2) =
\left(m,\; s_1 e^{m_1 - m} + s_2 e^{m_2 - m},\; o_1 e^{m_1 - m} + o_2 e^{m_2 - m}\right),
\quad m = \max(m_1, m_2)
\]

is associative:

\[
(A \oplus B) \oplus C = A \oplus (B \oplus C)
\]
Proof sketch: The max operation is associative. The sum and output updates are weighted averages with weights determined by exponentiated differences from the global max. The final result depends only on all inputs, not on combination order. \(\square\)
This associativity enables distributed computation.
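A tiny numerical check of the \(\oplus\) rule for two chunks (a standalone sketch; the three-chunk associativity verification appears in the exercises):

```python
import torch

def combine(state1, state2):
    """(m, s, o) ⊕ (m, s, o) exactly as defined above."""
    (m1, s1, o1), (m2, s2, o2) = state1, state2
    m = torch.maximum(m1, m2)
    a1, a2 = torch.exp(m1 - m), torch.exp(m2 - m)
    return m, s1 * a1 + s2 * a2, o1 * a1 + o2 * a2

x = torch.randn(10)      # attention logits for a single query
v = torch.randn(10, 4)   # value vectors
states = []
for xi, vi in ((x[:5], v[:5]), (x[5:], v[5:])):  # two key/value chunks
    m = xi.max()
    e = torch.exp(xi - m)
    states.append((m, e.sum(), e @ vi))
m, s, o = combine(*states)
print(torch.allclose(o / s, torch.softmax(x, dim=0) @ v))  # True
```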
Ring Attention¶
Liu et al. (2023) introduced Ring Attention for extremely long sequences.
The Core Idea¶
Each GPU holds:
- Query chunk \(Q_i\): Local queries (never moves)
- KV buffer: Key-value pairs that rotate around the ring
Initial:
GPU 0: Q₀, K₀, V₀
GPU 1: Q₁, K₁, V₁
GPU 2: Q₂, K₂, V₂
GPU 3: Q₃, K₃, V₃
After step 1 (K, V rotate):
GPU 0: Q₀, K₃, V₃
GPU 1: Q₁, K₀, V₀
GPU 2: Q₂, K₁, V₁
GPU 3: Q₃, K₂, V₂
The Algorithm¶
def ring_attention(Q_local, K_local, V_local, ring_group):
"""
Ring Attention: Compute attention over distributed sequence.
Args:
Q_local: Local query chunk (batch, seq_local, heads, dim)
K_local: Local key chunk (batch, seq_local, heads, dim)
V_local: Local value chunk (batch, seq_local, heads, dim)
ring_group: Process group for ring communication
Returns:
Output: Attention output for local queries
"""
world_size = dist.get_world_size(ring_group)
rank = dist.get_rank(ring_group)
# Initialize output accumulator with online softmax state
batch, seq_local, heads, dim = Q_local.shape
output = torch.zeros(batch, seq_local, heads, dim, device=Q_local.device)
max_scores = torch.full((batch, seq_local, heads, 1), float('-inf'),
device=Q_local.device)
sum_exp = torch.zeros(batch, seq_local, heads, 1, device=Q_local.device)
# Current K, V buffers (will rotate)
K_current = K_local.clone()
V_current = V_local.clone()
# Buffers for async communication
K_recv = torch.empty_like(K_current)
V_recv = torch.empty_like(V_current)
for step in range(world_size):
# Start async receive from previous rank
if step < world_size - 1:
src = (rank - 1) % world_size
recv_k = dist.irecv(K_recv, src=src, group=ring_group)
recv_v = dist.irecv(V_recv, src=src, group=ring_group)
        # Which global positions do the local queries and the current KV chunk cover?
        q_offset = rank * seq_local
        kv_offset = ((rank - step) % world_size) * seq_local
        # Assuming causal attention: a KV chunk that lies entirely in the future of
        # all local queries contributes nothing and is skipped (see "Optimization:
        # Skip Future Chunks" below). Drop this check for bidirectional attention.
        if kv_offset <= q_offset:
            # Compute local attention scores
            # Q_local @ K_current^T -> (batch, seq_local, heads, seq_local)
            scores = torch.einsum('bqhd,bkhd->bqhk', Q_local, K_current)
            scores = scores / math.sqrt(dim)
            if kv_offset == q_offset:
                # Diagonal chunk: apply a triangular mask within the chunk
                # (get_causal_mask is defined later in this chapter)
                causal_mask = get_causal_mask(q_offset, kv_offset, seq_local)
                scores = scores.masked_fill(
                    causal_mask.view(1, seq_local, 1, seq_local).to(scores.device),
                    float('-inf'))
            # Online softmax update
            chunk_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(max_scores, chunk_max)
            # Rescale previous sum and output
            scale_old = torch.exp(max_scores - new_max)
            scale_new = torch.exp(chunk_max - new_max)
            exp_scores = torch.exp(scores - chunk_max)
            chunk_sum = exp_scores.sum(dim=-1, keepdim=True)
            # Update running sum
            sum_exp = sum_exp * scale_old + chunk_sum * scale_new
            # Update output
            chunk_output = torch.einsum('bqhk,bkhd->bqhd', exp_scores, V_current)
            output = output * scale_old + chunk_output * scale_new
            max_scores = new_max
# Send K, V to next rank (async)
if step < world_size - 1:
dst = (rank + 1) % world_size
send_k = dist.isend(K_current, dst=dst, group=ring_group)
send_v = dist.isend(V_current, dst=dst, group=ring_group)
# Wait for receive and swap buffers
recv_k.wait()
recv_v.wait()
K_current, K_recv = K_recv, K_current
V_current, V_recv = V_recv, V_current
# Wait for send to complete
send_k.wait()
send_v.wait()
# Final normalization
output = output / sum_exp
return output
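A sketch of how this kernel might be launched once a process group exists. Everything here is illustrative: it assumes `torch.distributed` has been initialized with one rank per GPU (e.g. via `torchrun`) and that the global sequence was already sharded along dim 1 before entering the layer.

```python
# Hypothetical launch: torchrun --nproc_per_node=4 ring_demo.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

b, S, n_heads, d = 2, 8192, 16, 64
seq_local = S // world  # each rank holds one contiguous sequence shard

q = torch.randn(b, seq_local, n_heads, d, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = ring_attention(q, k, v, ring_group=dist.group.WORLD)
assert out.shape == (b, seq_local, n_heads, d)
```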
Communication Pattern¶
Each step:
- Send \(K_i, V_i\) to next rank: \(2 \cdot S/P \cdot H\) elements
- Receive from previous rank: same
Total communication per attention layer (per rank, over \(P - 1\) steps):

\[
2 \cdot \frac{P-1}{P} \cdot S \cdot H \ \text{elements}
\]
Critical feature: Communication overlaps with computation.
Compute-Communication Overlap¶
While computing attention with current K, V:
- Simultaneously send current K, V to next rank
- Simultaneously receive next K, V from previous rank
Overlap efficiency:

\[
\text{Overlap} = \min\left(1, \frac{T_{\text{compute}}}{T_{\text{comm}}}\right)
\]

Where:

- \(T_{\text{compute}}\): attention compute time per ring step
- \(T_{\text{comm}}\): time to transfer one K, V chunk to the next rank
For large sequences, compute dominates and overlap is nearly perfect.
Memory Analysis¶
Peak memory per GPU:
- Query: \(b \cdot \frac{S}{P} \cdot H\)
- Two KV buffers (double buffering): \(2 \cdot 2 \cdot b \cdot \frac{S}{P} \cdot H\)
- Attention scores (one chunk): \(b \cdot n_h \cdot \frac{S}{P} \cdot \frac{S}{P}\)
- Output accumulator: \(b \cdot \frac{S}{P} \cdot H\)
Total:

\[
M_{\text{peak}} \approx 6 \cdot b \cdot \frac{S}{P} \cdot H \;+\; b \cdot n_h \cdot \left(\frac{S}{P}\right)^2
\]
The quadratic term is now \((S/P)^2\) instead of \(S^2\) — a factor of \(P^2\) reduction.
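A rough per-GPU peak estimator using the same accounting (illustrative only; real implementations add kernel workspace and framework overhead, and Flash Attention removes the scores term entirely):

```python
def ring_attention_peak_bytes(S, H, n_heads, P, batch=1, bytes_per_elem=2):
    """Approximate peak attention memory per GPU for Ring Attention (no Flash)."""
    s_local = S // P
    q = batch * s_local * H
    kv_buffers = 2 * 2 * batch * s_local * H          # K and V, double buffered
    scores = batch * n_heads * s_local * s_local      # one local chunk of scores
    out = batch * s_local * H
    return (q + kv_buffers + scores + out) * bytes_per_elem

# S = 128K, H = 4096, 32 heads, P = 8, bf16 -> roughly 18 GB, dominated by the scores chunk
print(ring_attention_peak_bytes(S=128 * 1024, H=4096, n_heads=32, P=8) / 1e9, "GB")
```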
Ulysses: AlltoAll Sequence Parallelism¶
Jacobs et al. (2023) proposed DeepSpeed-Ulysses as an alternative to Ring Attention.
The Approach¶
Instead of rotating K, V through a ring, use AlltoAll to redistribute:
- Initial state: Each GPU has local Q, K, V for sequence chunk
- AlltoAll on K: Redistribute so each GPU has all K for some heads
- AlltoAll on V: Same redistribution
- Local attention: Compute attention with full sequence for subset of heads
- AlltoAll on output: Redistribute back to sequence-parallel layout
Before AlltoAll (sequence-parallel):
GPU 0: Q[0:S/P], K[0:S/P], V[0:S/P] for all heads
GPU 1: Q[S/P:2S/P], K[S/P:2S/P], V[S/P:2S/P] for all heads
After AlltoAll (head-parallel):
GPU 0: Q[0:S], K[0:S], V[0:S] for heads 0 : n_h/P
GPU 1: Q[0:S], K[0:S], V[0:S] for heads n_h/P : 2·n_h/P
Implementation¶
def ulysses_attention(Q, K, V, sp_group):
"""
Ulysses sequence parallelism using AlltoAll.
Args:
Q, K, V: Local chunks (batch, seq_local, heads, dim)
sp_group: Sequence parallel process group
Returns:
Output attention for local sequence chunk
"""
world_size = dist.get_world_size(sp_group)
batch, seq_local, heads, dim = Q.shape
# Reshape for AlltoAll: split heads dimension
# (batch, seq_local, heads, dim) -> (batch, seq_local, P, heads/P, dim)
heads_per_rank = heads // world_size
Q = Q.view(batch, seq_local, world_size, heads_per_rank, dim)
K = K.view(batch, seq_local, world_size, heads_per_rank, dim)
V = V.view(batch, seq_local, world_size, heads_per_rank, dim)
# AlltoAll: exchange sequence chunks for head chunks
# After: each rank has full sequence for subset of heads
Q = all_to_all(Q, dim_scatter=2, dim_gather=1, group=sp_group)
K = all_to_all(K, dim_scatter=2, dim_gather=1, group=sp_group)
V = all_to_all(V, dim_scatter=2, dim_gather=1, group=sp_group)
# Now shape is (batch, seq_full, heads_local, dim)
seq_full = seq_local * world_size
Q = Q.view(batch, seq_full, heads_per_rank, dim)
K = K.view(batch, seq_full, heads_per_rank, dim)
V = V.view(batch, seq_full, heads_per_rank, dim)
# Standard attention on full sequence (for local heads)
output = flash_attention(Q, K, V) # (batch, seq_full, heads_local, dim)
    # Reshape for reverse AlltoAll: dim 1 indexes the sequence chunk to send back
    output = output.view(batch, world_size, seq_local, heads_per_rank, dim)
    # AlltoAll: exchange head chunks for sequence chunks (full head set gathered on dim 3)
    output = all_to_all(output, dim_scatter=1, dim_gather=3, group=sp_group)
# Reshape back
output = output.view(batch, seq_local, heads, dim)
return output
def all_to_all(x, dim_scatter, dim_gather, group):
"""AlltoAll with specified scatter and gather dimensions."""
world_size = dist.get_world_size(group)
# Split along scatter dimension
splits = x.chunk(world_size, dim=dim_scatter)
splits = [s.contiguous() for s in splits]
# AlltoAll
output_splits = [torch.empty_like(splits[0]) for _ in range(world_size)]
dist.all_to_all(output_splits, splits, group=group)
# Concatenate along gather dimension
return torch.cat(output_splits, dim=dim_gather)
Communication Analysis¶
AlltoAll volume (each direction, per rank, per tensor):

\[
\frac{P-1}{P} \cdot \frac{S}{P} \cdot H = \frac{P-1}{P^2} \cdot S \cdot H \ \text{elements}
\]

Per attention layer: 4 AlltoAll operations (Q, K, V in; output out).

Total:

\[
4 \cdot \frac{P-1}{P^2} \cdot S \cdot H \ \text{elements per rank}
\]
Ring vs Ulysses Comparison¶
| Aspect | Ring Attention | Ulysses |
|---|---|---|
| Communication | P2P in ring | AlltoAll |
| Volume | \(2 \cdot \frac{P-1}{P} \cdot S \cdot H\) | \(4 \cdot \frac{P-1}{P^2} \cdot S \cdot H\) |
| Overlap | Yes (compute + comm) | Limited |
| Memory | KV buffers | No extra buffers |
| Best for | Long seq, P2P fast | Few ranks, AlltoAll fast |
When to use Ring: Many sequence parallel ranks, can overlap.
When to use Ulysses: Few ranks (2-8), high AlltoAll bandwidth (NVLink).
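A quick sanity check of the volume column above for whatever configuration you have in mind (per-rank element counts only; which scheme wins in practice also depends on overlap and link bandwidth):

```python
def ring_volume(S, H, P):
    """Per-rank elements sent by Ring Attention: K and V over P-1 steps."""
    return 2 * (P - 1) / P * S * H

def ulysses_volume(S, H, P):
    """Per-rank elements sent by Ulysses: 4 AlltoAll ops (Q, K, V in, output out)."""
    return 4 * (P - 1) / (P * P) * S * H

for P in (4, 8, 16, 64):
    ratio = ring_volume(1 << 20, 8192, P) / ulysses_volume(1 << 20, 8192, P)
    print(f"P={P:3d}: Ring moves {ratio:.1f}x more data than Ulysses")  # ratio = P/2
```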
Flash Attention Integration¶
Both Ring and Ulysses benefit from Flash Attention.
Flash Attention Recap¶
Dao et al. (2022) compute attention in tiles without materializing the full \(S \times S\) matrix:
def flash_attention_forward(Q, K, V, block_size=64):
"""
Flash Attention: memory-efficient attention using tiling.
"""
batch, seq_q, heads, dim = Q.shape
seq_k = K.shape[1]
output = torch.zeros_like(Q)
    max_scores = torch.full((batch, seq_q, heads, 1), float('-inf'), device=Q.device)
    sum_exp = torch.zeros((batch, seq_q, heads, 1), device=Q.device)
# Tile over K, V
for k_start in range(0, seq_k, block_size):
k_end = min(k_start + block_size, seq_k)
K_block = K[:, k_start:k_end]
V_block = V[:, k_start:k_end]
# Tile over Q
for q_start in range(0, seq_q, block_size):
q_end = min(q_start + block_size, seq_q)
Q_block = Q[:, q_start:q_end]
# Compute attention scores for this tile
scores = torch.einsum('bqhd,bkhd->bqhk', Q_block, K_block)
scores = scores / math.sqrt(dim)
# Online softmax update (same as Ring Attention)
block_max = scores.max(dim=-1, keepdim=True).values
new_max = torch.maximum(max_scores[:, q_start:q_end], block_max)
scale_old = torch.exp(max_scores[:, q_start:q_end] - new_max)
scale_new = torch.exp(block_max - new_max)
exp_scores = torch.exp(scores - block_max)
block_sum = exp_scores.sum(dim=-1, keepdim=True)
# Update accumulators
sum_exp[:, q_start:q_end] = (
sum_exp[:, q_start:q_end] * scale_old +
block_sum * scale_new
)
block_out = torch.einsum('bqhk,bkhd->bqhd', exp_scores, V_block)
output[:, q_start:q_end] = (
output[:, q_start:q_end] * scale_old +
block_out * scale_new
)
max_scores[:, q_start:q_end] = new_max
# Final normalization
output = output / sum_exp
return output
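A small correctness check against the reference formulation, useful when experimenting with the block size (CPU, float32; sizes and tolerances are illustrative):

```python
torch.manual_seed(0)
Q = torch.randn(2, 128, 4, 32)
K = torch.randn(2, 96, 4, 32)
V = torch.randn(2, 96, 4, 32)

ref_scores = torch.einsum('bqhd,bkhd->bqhk', Q, K) / math.sqrt(Q.shape[-1])
reference = torch.einsum('bqhk,bkhd->bqhd', torch.softmax(ref_scores, dim=-1), V)

tiled = flash_attention_forward(Q, K, V, block_size=32)
print(torch.allclose(tiled, reference, atol=1e-5))  # True
```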
Ring Attention with Flash Attention¶
The inner loop of Ring Attention can use Flash Attention for the local computation:
def ring_flash_attention(Q_local, K_local, V_local, ring_group):
"""Ring Attention using Flash Attention for each step."""
world_size = dist.get_world_size(ring_group)
# Initialize accumulators
output, max_scores, sum_exp = init_accumulators(Q_local)
K_current, V_current = K_local.clone(), V_local.clone()
for step in range(world_size):
# Start async communication (overlapped)
comm_handle = start_ring_comm(K_current, V_current, ring_group)
# Use Flash Attention kernel for this chunk
# Returns (output_chunk, max_chunk, sum_chunk)
chunk_out, chunk_max, chunk_sum = flash_attention_with_state(
Q_local, K_current, V_current
)
# Online softmax merge
output, max_scores, sum_exp = merge_attention_state(
output, max_scores, sum_exp,
chunk_out, chunk_max, chunk_sum
)
# Complete communication
K_current, V_current = complete_ring_comm(comm_handle)
return output / sum_exp
Hybrid Context Parallelism¶
For very long sequences, combine techniques.
Hierarchical Ring¶
Use multiple rings at different levels:
Inter-node ring (slow network):
[Node 0] ←→ [Node 1] ←→ [Node 2] ←→ [Node 3]
Intra-node ring (NVLink):
GPU0 ←→ GPU1 ←→ GPU2 ←→ GPU3 (within each node)
Combined Strategies¶
class HybridContextParallel:
"""Combine Ulysses (intra-node) with Ring (inter-node)."""
def __init__(self, local_group, global_ring_group):
self.local_group = local_group # GPUs within node
self.ring_group = global_ring_group # Across nodes
def forward(self, Q, K, V):
# Step 1: Ulysses within node (fast AlltoAll via NVLink)
Q, K, V = ulysses_qkv_exchange(Q, K, V, self.local_group)
# Step 2: Ring across nodes (overlapped P2P)
output = ring_attention(Q, K, V, self.ring_group)
# Step 3: Ulysses to restore layout
output = ulysses_output_exchange(output, self.local_group)
return output
Memory Comparison¶
For sequence length \(S\), hidden size \(H\), heads \(n_h\), parallel degree \(P\):
| Method | Attention Matrix Memory | KV Memory |
|---|---|---|
| Standard | \(O(n_h \cdot S^2)\) | \(O(S \cdot H)\) |
| Flash Attention | \(O(n_h \cdot B^2)\) (block size \(B\)) | \(O(S \cdot H)\) |
| Ring Attention | \(O(n_h \cdot (S/P)^2)\) | \(O(S \cdot H / P)\) |
| Ulysses | \(O((n_h/P) \cdot S^2)\) | \(O(S \cdot H)\) |
Ring Attention reduces both dimensions by \(P\); best for very long sequences.
Ulysses reduces head dimension by \(P\); keeps full sequence per GPU.
Causal Masking Considerations¶
Causal attention masks out future tokens: \(\text{Mask}[i, j] = 1\) if \(j \leq i\).
Ring Attention Causality¶
When rotating K, V, we must track which positions they represent:
def get_causal_mask(q_offset, k_offset, seq_local):
"""Create causal mask for Ring Attention step."""
q_positions = torch.arange(q_offset, q_offset + seq_local)
k_positions = torch.arange(k_offset, k_offset + seq_local)
# Mask: True where k_pos > q_pos (future positions)
mask = k_positions.unsqueeze(0) > q_positions.unsqueeze(1)
return mask
Optimization: Skip Future Chunks¶
If all keys in a chunk are "future" relative to all queries, skip computation entirely:
def should_compute_chunk(q_start, q_end, k_start, k_end):
"""Check if KV chunk has any valid (non-future) positions."""
# All keys are future if k_start > q_end
return k_start <= q_end
This can reduce computation by up to 50% for causal attention.
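A small sketch that applies this check across the whole ring schedule and counts how much work survives (chunk indices stand in for token positions; the per-GPU breakdown appears again in the exercises):

```python
def ring_causal_work_fraction(P: int) -> float:
    """Fraction of ring steps that still need computation under causal masking."""
    computed = 0
    for rank in range(P):                  # which Q chunk this GPU owns
        for step in range(P):
            kv_chunk = (rank - step) % P   # which KV chunk it holds at this step
            if should_compute_chunk(q_start=rank, q_end=rank,
                                    k_start=kv_chunk, k_end=kv_chunk):
                computed += 1
    return computed / (P * P)

print(ring_causal_work_fraction(8))  # 0.5625 -> on average 4.5 of 8 steps per GPU
```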
Implementation: Complete Sequence Parallel Layer¶
class SequenceParallelTransformerLayer(nn.Module):
"""Complete transformer layer with sequence parallelism."""
def __init__(self, config, sp_group, tp_group):
super().__init__()
self.sp_group = sp_group
self.tp_group = tp_group
self.sp_size = dist.get_world_size(sp_group)
# Layer components
self.ln1 = SequenceParallelLayerNorm(config.hidden_size, sp_group)
self.attention = SequenceParallelAttention(config, sp_group, tp_group)
self.ln2 = SequenceParallelLayerNorm(config.hidden_size, sp_group)
self.ffn = SequenceParallelFFN(config, sp_group, tp_group)
def forward(self, x):
# x: (batch, seq_local, hidden) - sequence-parallel input
# Pre-norm attention
residual = x
x = self.ln1(x)
x = self.attention(x) # Handles AllGather/ReduceScatter internally
x = residual + x
# Pre-norm FFN
residual = x
x = self.ln2(x)
x = self.ffn(x) # Handles AllGather/ReduceScatter internally
x = residual + x
return x
class SequenceParallelAttention(nn.Module):
"""Attention with context parallelism (Ring or Ulysses)."""
def __init__(self, config, sp_group, tp_group, method='ring'):
super().__init__()
self.sp_group = sp_group
self.tp_group = tp_group
self.method = method
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // config.num_heads
# Projections (tensor-parallel)
self.qkv_proj = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
tp_group
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
tp_group
)
def forward(self, x):
batch, seq_local, hidden = x.shape
# Project to Q, K, V
qkv = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape for attention
q = q.view(batch, seq_local, self.num_heads, self.head_dim)
k = k.view(batch, seq_local, self.num_heads, self.head_dim)
v = v.view(batch, seq_local, self.num_heads, self.head_dim)
# Context-parallel attention
if self.method == 'ring':
output = ring_attention(q, k, v, self.sp_group)
else: # ulysses
output = ulysses_attention(q, k, v, self.sp_group)
# Reshape and project
output = output.view(batch, seq_local, hidden)
output = self.out_proj(output)
return output
Exercises¶
- Memory calculation: For a model with \(H = 4096\), \(n_h = 32\), \(S = 256K\), batch 1, bf16:
  - Calculate attention matrix memory without sequence parallelism
  - Calculate with \(P = 8\) Ring Attention
  - What's the reduction factor?
Solution
Given:
- Hidden size: \(H = 4096\)
- Number of heads: \(n_h = 32\)
- Sequence length: \(S = 256K = 262,144\)
- Batch size: \(b = 1\)
- Data type: bf16 (2 bytes)
Attention matrix memory without sequence parallelism:
The attention matrix has shape \((b, n_h, S, S)\):

\[
1 \cdot 32 \cdot (262{,}144)^2 \cdot 2\ \text{bytes} \approx 4.4\ \text{TB}
\]
This is impossible to fit on any single GPU.
With \(P = 8\) Ring Attention:
Each GPU processes its local queries against one KV chunk at a time, so the attention matrix per GPU has shape \((b, n_h, S/P, S/P)\) with \(S/P = 32{,}768\):

\[
1 \cdot 32 \cdot (32{,}768)^2 \cdot 2\ \text{bytes} \approx 68.7\ \text{GB}
\]
This fits in an 80GB H100!
Reduction factor:

\[
\frac{4.4\ \text{TB}}{68.7\ \text{GB}} = P^2 = 64
\]
Summary:
| Configuration | Attention Matrix Memory | Fits in 80GB? |
|---|---|---|
| No SP | 4.4 TB | No |
| Ring (\(P=8\)) | 68.7 GB | Yes |
Note: Ring Attention also needs KV buffers (\(4 \times S/P \times H \times 2\ \text{bytes} \approx 1\) GB with double buffering), which easily fits.
- Communication volume: Compare Ring Attention and Ulysses communication volume for \(S = 1M\), \(H = 8192\), \(P = 16\). Which uses less bandwidth?
Solution
Given:
- Sequence length: \(S = 1M = 1,048,576\)
- Hidden size: \(H = 8192\)
- Parallelism degree: \(P = 16\)
- Assume bf16 (2 bytes per element)
Ring Attention communication volume:
Ring Attention rotates K and V tensors through the ring. Each GPU sends its local K, V to the next GPU in each step.
Local K, V each have shape \((S/P, H)\), i.e. \(65{,}536 \times 8{,}192\) elements. Per step, each GPU sends K and V:

\[
2 \cdot \frac{S}{P} \cdot H \cdot 2\ \text{bytes} = 2 \cdot 65{,}536 \cdot 8{,}192 \cdot 2 \approx 2.15\ \text{GB}
\]

Total communication over \(P - 1 = 15\) steps:

\[
15 \cdot 2.15\ \text{GB} \approx 32.2\ \text{GB per GPU}
\]
Ulysses communication volume:
Ulysses performs AlltoAll in two phases:

1. Before attention: redistribute Q, K, V from \((S/P, H)\) to \((S, H/P)\)
2. After attention: redistribute the output back
Each AlltoAll moves a fraction \(\frac{P-1}{P}\) of each GPU's local shard. Volume per AlltoAll, per tensor:

\[
\frac{P-1}{P} \cdot \frac{S}{P} \cdot H \cdot 2\ \text{bytes} = \frac{15}{16} \cdot 65{,}536 \cdot 8{,}192 \cdot 2 \approx 1.01\ \text{GB}
\]

For Q, K, V in: \(3 \times 1.01 \approx 3.02\) GB. Output AlltoAll: another \(\approx 1.01\) GB.

Total Ulysses:

\[
4 \cdot \frac{P-1}{P^2} \cdot S \cdot H \cdot 2\ \text{bytes} \approx 4.03\ \text{GB per GPU}
\]
Comparison:
| Method | Total Volume | Per-Step Latency |
|---|---|---|
| Ring Attention | 32.2 GB | 15 P2P transfers |
| Ulysses | 4.03 GB | 2 AlltoAll collectives |
Ulysses uses 8× less bandwidth in total volume. However:
- Ring: Communication overlaps with compute (can hide latency)
- Ulysses: AlltoAll is blocking but lower total volume
When to choose each:
- Ring: When compute >> communication, good P2P bandwidth
- Ulysses: When P is small, AlltoAll is fast (NVLink within node)
- Overlap efficiency: If attention compute takes 10ms and KV transfer takes 2ms per Ring step, what fraction of communication is hidden? With \(P = 8\) steps, what's the effective communication overhead?
Solution
Given:
- Attention compute time: \(T_{\text{compute}} = 10\) ms per step
- KV transfer time: \(T_{\text{comm}} = 2\) ms per step
- Number of steps: \(P = 8\)
Ring Attention overlap model:
In Ring Attention, communication is overlapped with compute using double buffering:

- While computing attention on the current KV chunk, transfer the next KV chunk
- Communication is fully hidden if \(T_{\text{comm}} \leq T_{\text{compute}}\)
Fraction of communication hidden:

\[
\min\left(1, \frac{T_{\text{compute}}}{T_{\text{comm}}}\right) = \min\left(1, \frac{10}{2}\right) = 1
\]

Since \(T_{\text{comm}} = 2\) ms \(< T_{\text{compute}} = 10\) ms, all communication is hidden behind compute.
Total execution time:

With perfect overlap, the total time is dominated by compute, plus the initial transfer (which, in this model, can't be overlapped):

\[
T_{\text{overlap}} = T_{\text{comm}} + P \cdot T_{\text{compute}} = 2 + 8 \cdot 10 = 82\ \text{ms}
\]

Without overlap:

\[
T_{\text{serial}} = P \cdot (T_{\text{compute}} + T_{\text{comm}}) = 8 \cdot 12 = 96\ \text{ms}
\]

Effective communication overhead:

\[
\frac{82 - 80}{82} \approx 2.4\%
\]
Summary:
| Metric | Value |
|---|---|
| Communication hidden | 100% |
| Total time (with overlap) | 82 ms |
| Total time (no overlap) | 96 ms |
| Speedup from overlap | 1.17× |
| Effective overhead | 2.4% |
Key insight: When compute time exceeds communication time, Ring Attention achieves near-perfect overlap. The only exposed communication is the initial KV fetch before the first compute step.
When overlap breaks down:
If \(T_{\text{comm}} > T_{\text{compute}}\), some communication is exposed: each step waits an extra \(T_{\text{comm}} - T_{\text{compute}}\), and only a fraction \(T_{\text{compute}} / T_{\text{comm}}\) of the communication is hidden.
- Causal optimization: For Ring Attention with \(P = 8\) and causal masking, how many of the 8 steps can be skipped (on average) due to all-future chunks?
Solution
Causal masking in Ring Attention:
With causal masking, position \(i\) can only attend to positions \(j \leq i\). When K, V chunks contain only future tokens relative to all Q tokens, the entire computation can be skipped.
Setup with \(P = 8\):
Divide sequence into 8 chunks: \(C_0, C_1, \ldots, C_7\)
GPU \(r\) holds queries from chunk \(C_r\). After \(k\) rotations, it receives K, V from chunk \(C_{(r-k) \mod 8}\).
When can we skip?
GPU \(r\) can skip step \(k\) if all tokens in \(C_{(r-k) \mod 8}\) are in the future relative to all tokens in \(C_r\).
This happens when \((r - k) \mod 8 > r\), meaning the KV chunk index is greater than the Q chunk index.
Analysis per GPU:
| GPU (Q chunk) | KV chunks received | Skippable chunks |
|---|---|---|
| GPU 0 (\(C_0\)) | \(C_0 \to C_7 \to C_6 \to \ldots \to C_1\) | \(C_1, C_2, \ldots, C_7\) (7 skippable) |
| GPU 1 (\(C_1\)) | \(C_1 \to C_0 \to C_7 \to \ldots \to C_2\) | \(C_2, C_3, \ldots, C_7\) (6 skippable) |
| GPU 2 (\(C_2\)) | \(C_2 \to C_1 \to C_0 \to \ldots \to C_3\) | \(C_3, C_4, \ldots, C_7\) (5 skippable) |
| GPU 3 (\(C_3\)) | ... | 4 skippable |
| GPU 4 (\(C_4\)) | ... | 3 skippable |
| GPU 5 (\(C_5\)) | ... | 2 skippable |
| GPU 6 (\(C_6\)) | ... | 1 skippable |
| GPU 7 (\(C_7\)) | \(C_7 \to C_6 \to \ldots \to C_0\) | 0 skippable |
Average skippable steps per GPU:

\[
\frac{7 + 6 + 5 + 4 + 3 + 2 + 1 + 0}{8} = \frac{28}{8} = 3.5
\]
Work reduction:
Without optimization: \(P = 8\) steps per GPU. With causal skip: \(8 - 3.5 = 4.5\) steps on average, a \(3.5 / 8 = 43.75\%\) reduction in work.
General formula for \(P\) GPUs:

Average skippable steps:

\[
\frac{1}{P} \sum_{r=0}^{P-1} (P - 1 - r) = \frac{P - 1}{2}
\]
For \(P = 8\): \(\frac{8-1}{2} = 3.5\) ✓
Important caveat:
The diagonal chunk (\(C_r\) for GPU \(r\)) requires partial computation—tokens within the chunk still have causal relationships. This is handled by applying a triangular mask within the diagonal block.
| Metric | Value |
|---|---|
| Total steps without optimization | 8 |
| Average skippable steps | 3.5 |
| Average required steps | 4.5 |
| Work reduction | 43.75% |
| Theoretical speedup | 1.78× |
- Hybrid design: Design a sequence parallelism strategy for 64 GPUs (8 nodes × 8 GPUs) with \(S = 2M\) tokens. Propose group configurations and estimate memory per GPU.
Solution
Given:
- Total GPUs: 64 (8 nodes × 8 GPUs per node)
- Sequence length: \(S = 2M = 2,097,152\) tokens
- Assume: \(H = 8192\), \(n_h = 64\) heads, bf16
Topology considerations:
- Intra-node: 8 GPUs connected via NVLink (900 GB/s)
- Inter-node: GPUs connected via InfiniBand (~400 Gb/s = 50 GB/s)
Hybrid strategy: Ulysses intra-node + Ring inter-node
| Level | Parallelism | Degree | Communication |
|---|---|---|---|
| Intra-node | Ulysses | 8 | Fast AlltoAll via NVLink |
| Inter-node | Ring | 8 | P2P via InfiniBand |
How it works:
- Divide the 64 GPUs into 8 ring groups (one per intra-node GPU position, spanning the 8 nodes)
- Each ring group has 8 GPUs, one per node
- Within each node, use Ulysses to redistribute sequence ↔ heads
- Across nodes, use Ring to rotate KV chunks
Sequence distribution: with 64-way sequence parallelism, each GPU holds \(S / 64 = 32{,}768\) tokens (32K).
Memory estimation per GPU:
Attention matrix memory:

With 64-way SP, attention is computed on \((S/64) \times (S/64)\) chunks:

\[
b \cdot n_h \cdot \left(\frac{S}{64}\right)^2 \cdot 2\ \text{bytes} = 1 \cdot 64 \cdot (32{,}768)^2 \cdot 2 \approx 137\ \text{GB}
\]

Wait: this doesn't fit in 80GB! We need Flash Attention.
With Flash Attention (no materialized attention matrix):
Q, K, V per GPU:

\[
3 \cdot \frac{S}{64} \cdot H \cdot 2\ \text{bytes} = 3 \cdot 32{,}768 \cdot 8{,}192 \cdot 2 \approx 1.6\ \text{GB}
\]

KV buffers for Ring (double buffering, K and V each buffered twice):

\[
4 \cdot \frac{S}{64} \cdot H \cdot 2\ \text{bytes} \approx 2.1\ \text{GB}
\]
Total memory per GPU (with Flash Attention):
| Component | Memory |
|---|---|
| Q, K, V local | 1.6 GB |
| KV ring buffers | 2.1 GB |
| Output | 0.5 GB |
| Flash workspace | ~0.5 GB |
| Total attention | ~5 GB |
This fits easily in 80GB, leaving room for model weights and activations.
Communication analysis:
Ulysses (intra-node):
- AlltoAll volume: \(3 \times (S/64) \times H \times 2 \times 7/8 = 1.4\) GB
- Time at 900 GB/s: ~1.5 ms
Ring (inter-node):
- Per step: \((K+V) = 2 \times (S/8) \times H\) elements → \(2 \times (S/8) \times H \times 2 = 8.6\) GB (bf16)
- 7 steps at 50 GB/s: \(7 \times 8.6 / 50 \approx 1.2\) s
Alternative: Pure Ring with 64-way
- 63 steps of rotating KV
- Much higher latency but simpler
- Works if compute dominates
Recommended configuration:
Intra-node: Ulysses (P=8)
- Fast AlltoAll on NVLink
- Redistributes S/8 → S, H → H/8

Inter-node: Ring (P=8)
- Overlapped P2P on InfiniBand
- Each node rotates as a unit

Memory per GPU: ~5 GB for attention. Sequence per GPU: 32K tokens. Total context: 2M tokens.

| Configuration | Pros | Cons |
|---|---|---|
| Hybrid Ulysses+Ring | Best of both worlds | Complex implementation |
| Pure Ring-64 | Simple | 63 steps, high latency |
| Pure Ulysses-64 | Low volume | AlltoAll across nodes is slow |

- Online softmax verification: Implement and verify that combining three chunks \((A, B, C)\) as \((A \oplus B) \oplus C\) gives the same result as \(A \oplus (B \oplus C)\) for the attention operation.
Solution
Online softmax state representation:
For each chunk, we track a tuple \((m, s, o)\):

- \(m\): maximum logit seen so far
- \(s\): sum of exponentials (scaled by the current max)
- \(o\): weighted output (scaled by the current max)
Combination operator \(\oplus\):
Given states \((m_1, s_1, o_1)\) and \((m_2, s_2, o_2)\):

\[
m_{\text{new}} = \max(m_1, m_2), \qquad
s_{\text{new}} = s_1 e^{m_1 - m_{\text{new}}} + s_2 e^{m_2 - m_{\text{new}}}, \qquad
o_{\text{new}} = o_1 e^{m_1 - m_{\text{new}}} + o_2 e^{m_2 - m_{\text{new}}}
\]
Implementation:
import torch
import numpy as np
def create_chunk_state(logits, values):
"""
Compute (m, s, o) state for a chunk.
Args:
logits: [seq_q, seq_kv] attention logits
values: [seq_kv, d] value vectors
Returns:
m: [seq_q] max logits
s: [seq_q] sum of exp(logits - m)
o: [seq_q, d] weighted sum of values
"""
m = logits.max(dim=-1, keepdim=True).values # [seq_q, 1]
exp_logits = torch.exp(logits - m) # [seq_q, seq_kv]
s = exp_logits.sum(dim=-1, keepdim=True) # [seq_q, 1]
o = exp_logits @ values # [seq_q, d]
return m.squeeze(-1), s.squeeze(-1), o
def combine_states(state1, state2):
"""
Combine two (m, s, o) states using the associative operator.
"""
m1, s1, o1 = state1
m2, s2, o2 = state2
m_new = torch.maximum(m1, m2)
# Correction factors
alpha1 = torch.exp(m1 - m_new).unsqueeze(-1)
alpha2 = torch.exp(m2 - m_new).unsqueeze(-1)
s_new = s1 * alpha1.squeeze(-1) + s2 * alpha2.squeeze(-1)
o_new = o1 * alpha1 + o2 * alpha2
return m_new, s_new, o_new
def finalize_output(state):
"""Convert (m, s, o) state to final attention output."""
m, s, o = state
return o / s.unsqueeze(-1)
# Test associativity
torch.manual_seed(42)
seq_q, seq_kv, d = 4, 6, 8
# Create Q, K, V
Q = torch.randn(seq_q, d)
K = torch.randn(seq_kv, d)
V = torch.randn(seq_kv, d)
# Split K, V into 3 chunks
chunk_size = seq_kv // 3
K_chunks = [K[i*chunk_size:(i+1)*chunk_size] for i in range(3)]
V_chunks = [V[i*chunk_size:(i+1)*chunk_size] for i in range(3)]
# Compute logits for each chunk
logits_A = Q @ K_chunks[0].T
logits_B = Q @ K_chunks[1].T
logits_C = Q @ K_chunks[2].T
# Create states for each chunk
state_A = create_chunk_state(logits_A, V_chunks[0])
state_B = create_chunk_state(logits_B, V_chunks[1])
state_C = create_chunk_state(logits_C, V_chunks[2])
# Test (A ⊕ B) ⊕ C
state_AB = combine_states(state_A, state_B)
state_AB_C = combine_states(state_AB, state_C)
output_left = finalize_output(state_AB_C)
# Test A ⊕ (B ⊕ C)
state_BC = combine_states(state_B, state_C)
state_A_BC = combine_states(state_A, state_BC)
output_right = finalize_output(state_A_BC)
# Ground truth: standard attention
logits_full = Q @ K.T
attn_weights = torch.softmax(logits_full, dim=-1)
output_standard = attn_weights @ V
# Verify
print("Max diff (A⊕B)⊕C vs A⊕(B⊕C):",
(output_left - output_right).abs().max().item())
print("Max diff vs standard attention:",
(output_left - output_standard).abs().max().item())
Output:
Verification results:
| Comparison | Max Absolute Difference |
|---|---|
| \((A \oplus B) \oplus C\) vs \(A \oplus (B \oplus C)\) | \(\boxed{0.0}\) (exact) |
| Chunked vs Standard | ~\(10^{-7}\) (numerical precision) |
Why it works (mathematical proof):
The combination operator is associative because:
- Max is associative: \(\max(\max(a, b), c) = \max(a, \max(b, c))\).
- Weighted sums combine correctly: the exponential rescaling ensures sums and outputs are always referenced to the global maximum.
- Softmax decomposition: for any partition of the keys,

\[
\text{softmax}([x_A; x_B; x_C]) = \frac{[e^{x_A}; e^{x_B}; e^{x_C}]}{e^{m_A} s_A + e^{m_B} s_B + e^{m_C} s_C}
\]

This can be computed incrementally by tracking \((m, s, o)\) and combining associatively.
- Flash + Ring: Modify the Ring Attention algorithm to use Flash Attention internally. What are the memory implications of the nested tiling?
Solution
Flash Attention inside Ring Attention:
Ring Attention and Flash Attention both use online softmax, making them naturally composable:
- Ring: Tiles across GPUs (distributed KV chunks)
- Flash: Tiles within GPU (SRAM-sized blocks)
Nested tiling structure:
Ring Level (distributed):
For each KV chunk from ring:
Flash Level (local):
For each Q block that fits in SRAM:
For each KV block from current chunk:
Compute partial attention
Update (m, s, o) state
Implementation:
import torch
import torch.distributed as dist
def flash_attention_forward(q, k, v, block_size=256):
"""
Flash Attention with online softmax.
Returns (output, m, lse) for Ring integration.
Args:
q: [seq_q, d]
k: [seq_kv, d]
v: [seq_kv, d]
Returns:
        output: [seq_q, d] unnormalized weighted values (divide by lse at the end)
        m: [seq_q] running max logit per query
        lse: [seq_q] softmax denominator per query (sum of exp(scores - m))
"""
seq_q, d = q.shape
seq_kv = k.shape[0]
# Initialize online state
m = torch.full((seq_q,), float('-inf'), device=q.device)
lse = torch.zeros(seq_q, device=q.device)
output = torch.zeros(seq_q, d, device=q.device)
# Process K, V in blocks (simulating SRAM tiling)
for j in range(0, seq_kv, block_size):
k_block = k[j:j+block_size] # [block, d]
v_block = v[j:j+block_size] # [block, d]
# Compute attention scores for this block
scores = q @ k_block.T # [seq_q, block]
# Online softmax update
m_block = scores.max(dim=-1).values # [seq_q]
m_new = torch.maximum(m, m_block)
# Rescale previous accumulator
alpha = torch.exp(m - m_new)
lse = lse * alpha
# Add new contributions
exp_scores = torch.exp(scores - m_new.unsqueeze(-1))
lse = lse + exp_scores.sum(dim=-1)
# Update output
output = output * alpha.unsqueeze(-1)
output = output + exp_scores @ v_block
m = m_new
return output, m, lse
def ring_flash_attention(q_local, k_local, v_local, group):
"""
Ring Attention using Flash Attention as inner kernel.
Args:
q_local: Local Q chunk [seq_local, d]
k_local: Local K chunk [seq_local, d]
v_local: Local V chunk [seq_local, d]
group: Process group for ring communication
"""
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
device = q_local.device
seq_local, d = q_local.shape
# Initialize ring state (using online softmax representation)
m_global = torch.full((seq_local,), float('-inf'), device=device)
lse_global = torch.zeros(seq_local, device=device)
output_global = torch.zeros(seq_local, d, device=device)
# Double buffering for KV
k_recv = torch.empty_like(k_local)
v_recv = torch.empty_like(v_local)
k_curr, v_curr = k_local, v_local
for step in range(world_size):
# Async receive next KV (except last step)
if step < world_size - 1:
src = (rank - 1) % world_size
dst = (rank + 1) % world_size
recv_k = dist.irecv(k_recv, src=src, group=group)
recv_v = dist.irecv(v_recv, src=src, group=group)
send_k = dist.isend(k_curr, dst=dst, group=group)
send_v = dist.isend(v_curr, dst=dst, group=group)
# Flash Attention on current KV chunk
output_chunk, m_chunk, lse_chunk = flash_attention_forward(
q_local, k_curr, v_curr
)
# Combine with global state (online softmax merge)
m_new = torch.maximum(m_global, m_chunk)
alpha_global = torch.exp(m_global - m_new)
alpha_chunk = torch.exp(m_chunk - m_new)
lse_global = lse_global * alpha_global + lse_chunk * alpha_chunk
output_global = (output_global * alpha_global.unsqueeze(-1) +
output_chunk * alpha_chunk.unsqueeze(-1))
m_global = m_new
# Wait for communication and swap buffers
if step < world_size - 1:
recv_k.wait()
recv_v.wait()
send_k.wait()
send_v.wait()
k_curr, k_recv = k_recv, k_curr
v_curr, v_recv = v_recv, v_curr
# Final normalization
output_final = output_global / lse_global.unsqueeze(-1)
return output_final
Memory analysis:
| Component | Standard Ring | Flash + Ring |
|---|---|---|
| Attention matrix | \((S/P)^2\) per chunk | 0 (never materialized) |
| Q per GPU | \(S/P \times d\) | \(S/P \times d\) |
| K, V buffers | \(4 \times S/P \times d\) | \(4 \times S/P \times d\) |
| Flash workspace | 0 | \(O(\text{block\_size} \times d)\) |
| Output accumulator | \(S/P \times d\) | \(S/P \times d\) |
Memory implications of nested tiling:
For \(S = 1M\), \(P = 8\), \(d = 128\), \(B = 256\):

- Q, K, V local: \(3 \times 128K \times 128 \times 2 = 98\) MB
- KV buffers: \(2 \times 128K \times 128 \times 2 = 65\) MB
- Flash workspace: ~1 MB
- Output: \(128K \times 128 \times 2 = 33\) MB

Total: ~200 MB (vs 4+ GB without Flash)

Key benefits:

| Benefit | Explanation |
|---|---|
| \(O(S/P)\) memory | No \((S/P)^2\) attention matrix |
| IO efficiency | Flash reduces HBM traffic |
| Composability | Both use online softmax |
| Numerical stability | Log-sum-exp trick throughout |

Memory reduction:

- Without Flash: \(M = O((S/P)^2)\) for the attention matrix
- With Flash: \(M = O(S/P)\), linear in the local sequence

For \(S/P = 128K\):

- Without Flash: \(128K^2 \times 2 = 32\) GB (doesn't fit!)
- With Flash: ~200 MB

\[
\text{Reduction} = \frac{(S/P)^2}{(S/P) \cdot d} = \frac{S/P}{d} = \frac{128K}{128} = 1024\times
\]
Knobs and Trade-offs¶
| Knob | Primary Effect | Cost |
|---|---|---|
| Sequence-parallel degree | Reduces activation memory | More AllGather/ReduceScatter |
| Ring vs AlltoAll (context) | Better overlap vs simplicity | Topology sensitivity |
| Flash + tiling | Lower memory and IO | More kernel complexity |
| Causal masking shortcuts | Fewer chunks | More edge-case logic |
Key Takeaways¶
- Two types of sequence parallelism:
  - Megatron-style: Reduces LayerNorm/Dropout activation memory
  - Context parallelism: Distributes the attention computation itself
- Online softmax enables decomposition: Attention can be computed incrementally with associative state updates.
- Ring Attention: Rotate K, V in a ring; communication overlaps with compute.
- Ulysses: AlltoAll to redistribute sequence vs heads; simpler but less overlap.
- Memory scales as \(O((S/P)^2)\): Ring Attention reduces the attention matrix by a factor of \(P^2\).
- Flash Attention integration: Use Flash as the inner kernel for memory efficiency.
- Causal masking optimization: Skip chunks where all keys are in the future.
- Choose based on topology: Ring for many ranks with good P2P; Ulysses for few ranks with fast AlltoAll.