10 Sparsity
The Art of Not Computing
The most powerful optimization: don’t do the work at all.
A matrix is 90% zeros. Shouldn’t we be 10× faster?
Usually not. And understanding why reveals deep truths about hardware.
The systematic study of sparse computation began with solving large systems of linear equations arising from physical simulations. Engineers noticed that most coefficients were zero—the matrix structure mirrored the physical locality of interactions.
The key insight: exploiting sparsity requires structure, not just zeros. Random sparsity doesn’t help; structured sparsity (banded, block-diagonal, graph-based) enables specialized algorithms. This lesson applies directly to modern ML: structured sparsity (e.g., 2:4 patterns for tensor cores) wins where random sparsity fails.
10.1 The Sparsity Paradox
Sparsity is everywhere:
- Neural network weights after pruning
- Attention patterns (most tokens don’t attend to most other tokens)
- Graph adjacency matrices
- Recommender systems (users rate few items)
- Gradient updates (many parameters barely change)
Mathematically, 90% sparsity means 90% fewer operations. But in practice:
| Sparsity | Theoretical Speedup | Actual Speedup |
|---|---|---|
| 50% | 2× | 1.1× |
| 90% | 10× | 1.8× |
| 95% | 20× | 2.3× |
| 99% | 100× | 3.5× |
The gap is enormous. Where does the performance go?
10.2 Why Hardware Hates Random Sparsity
Modern hardware is built for predictable computation. Sparse computation is unpredictable.
10.2.1 The Memory Access Problem
Dense matrix multiply:
A[i,j] = B[i,:] @ C[:,j]
       = B[i,0]*C[0,j] + B[i,1]*C[1,j] + ... + B[i,K-1]*C[K-1,j]
Memory access pattern: sequential along rows and columns. Predictable. Prefetchable.
Sparse matrix multiply:
A[i,j] = sum(B[i,k] * C[k,j] for k where B[i,k] != 0)
Memory access pattern: wherever the non-zeros happen to be. Random. Unprefetchable.
Dense access pattern: Sparse access pattern:
┌─────────────────────┐ ┌─────────────────────┐
│█ █ █ █ █ █ █ █ █ █ │ │█ █ █ │
│█ █ █ █ █ █ █ █ █ █ │ │ █ █ │
│█ █ █ █ █ █ █ █ █ █ │ │ █ █ │
│█ █ █ █ █ █ █ █ █ █ │ │ █ █ │
└─────────────────────┘ └─────────────────────┘
Sequential → Fast Random → Slow
10.2.2 The Indexing Overhead
Dense operations don’t need indices. Element (i, j) is at base + i*stride + j.
Sparse formats need indices:
# CSR format (Compressed Sparse Row)
class CSRMatrix:
    def __init__(self):
        self.data = [...]     # Non-zero values
        self.indices = [...]  # Column index of each value
        self.indptr = [...]   # Where each row starts
# For a 1000×1000 matrix with 1% non-zeros:
# Dense: 1,000,000 values
# Sparse: 10,000 values + 10,000 column indices + 1,001 row pointers
#       = 21,001 numbers (but we also have indexing overhead)
The indexing itself requires memory accesses and compute.
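To make the indexing cost concrete, here is a minimal pure-Python sketch of a CSR matrix-vector product. The helper names `dense_to_csr` and `csr_matvec` are illustrative and not taken from any library. The point to notice is that every multiply needs an extra index load (`indices[p]`) followed by a data-dependent read of `x`.

```python
def dense_to_csr(rows):
    """Convert a list-of-lists dense matrix to CSR arrays (data, indices, indptr)."""
    data, indices, indptr = [], [], [0]
    for row in rows:
        for j, value in enumerate(row):
            if value != 0:
                data.append(value)
                indices.append(j)
        indptr.append(len(data))  # row i occupies data[indptr[i]:indptr[i + 1]]
    return data, indices, indptr


def csr_matvec(data, indices, indptr, x):
    """Compute y = A @ x for A stored in CSR form."""
    y = [0.0] * (len(indptr) - 1)
    for i in range(len(y)):
        for p in range(indptr[i], indptr[i + 1]):
            # Two extra loads per multiply: the column index, then x at that index
            y[i] += data[p] * x[indices[p]]
    return y


A = [
    [5, 0, 0, 0],
    [0, 0, 3, 0],
    [0, 0, 0, 0],
    [1, 0, 0, 2],
]
data, indices, indptr = dense_to_csr(A)
y = csr_matvec(data, indices, indptr, [1.0, 2.0, 3.0, 4.0])
print(data, indices, indptr, y)
```

The empty third row shows up only as a repeated entry in `indptr`, which is why CSR needs one more row pointer than there are rows.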
10.2.3 The Parallelism Problem
GPUs execute in lockstep (SIMT). In a warp of 32 threads:
Dense: All threads do the same operation on different data
Thread 0: A[0,0] * B[0,0]
Thread 1: A[0,1] * B[1,0]
...
Thread 31: A[0,31] * B[31,0]
→ All threads active
Sparse: Each thread might have different work (or no work)
Thread 0: A[0,5] * B[5,0] (if A[0,5] != 0)
Thread 1: skip (if A[0,1] == 0)
Thread 2: A[0,7] * B[7,0] (if A[0,7] != 0)
...
→ Many threads idle
Sparse patterns cause warp divergence and load imbalance.
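The load-imbalance effect can be estimated with a toy model. The sketch below assumes one thread per matrix row and that a warp finishes only when its slowest lane finishes; both are simplifications for illustration, and `warp_utilization` is a hypothetical helper, not a profiler metric.

```python
import random


def warp_utilization(nnz_per_thread):
    """Fraction of lane-steps doing useful work when a warp runs in lockstep."""
    steps = max(nnz_per_thread)  # the warp waits for its slowest lane
    if steps == 0:
        return 1.0
    return sum(nnz_per_thread) / (steps * len(nnz_per_thread))


random.seed(0)
WARP_SIZE, ROW_LENGTH, DENSITY = 32, 256, 0.10

# Dense: every thread processes a full row
dense_work = [ROW_LENGTH] * WARP_SIZE

# Random sparsity: each element of a thread's row is kept with probability DENSITY
sparse_work = [
    sum(random.random() < DENSITY for _ in range(ROW_LENGTH))
    for _ in range(WARP_SIZE)
]

dense_util = warp_utilization(dense_work)
sparse_util = warp_utilization(sparse_work)
print(f"dense: {dense_util:.2f}, random sparse: {sparse_util:.2f}")
```

Even with the same expected work per row, the spread in non-zero counts leaves some lanes idle while the busiest lane finishes.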
10.3 Structured Sparsity: The Compromise
Random sparsity is hard. Structured sparsity is tractable.
The key insight: constrain where zeros can appear, so hardware can predict and optimize.
10.3.1 Block Sparsity
Require zeros to occur in blocks:
Random sparsity: Block sparsity (4×4 blocks):
┌─────────────────────┐ ┌─────────────────────┐
│█ █ █ │ │████ ████ │
│ █ █ │ │████ ████ │
│ █ █ │ │████ ████ │
│ █ █ │ │████ ████ │
│ █ █ │ │ │
│ █ █ │ │ │
│ █ █ │ │ ████████ │
│ █ █ │ │ ████████ │
└─────────────────────┘ └─────────────────────┘
With block sparsity:
- Memory access is contiguous within blocks
- No indexing within blocks
- Load balancing is easier (work per block is uniform)
import torch
def block_sparse_matmul(A_blocks, A_indices, B, n_rows, block_size=64):
    """
    A is stored as a list of dense blocks + their (block_row, block_col) positions.
    n_rows is the number of rows of the full matrix A.
    """
    result = torch.zeros(n_rows, B.shape[1], dtype=B.dtype, device=B.device)
    for block, (i, j) in zip(A_blocks, A_indices):
        # Each block is a dense tile
        row_start = i * block_size
        row_end = row_start + block_size
        col_start = j * block_size
        col_end = col_start + block_size
        # Dense matmul for this block
        result[row_start:row_end] += block @ B[col_start:col_end]
    return result
Block sparsity achieves ~80% of the theoretical speedup (vs. ~20% for random sparsity at the same density).
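A small NumPy sketch shows where the saving in index overhead comes from. The helper `to_block_sparse` is hypothetical, and the matrix is synthetic (whole 4×4 tiles are either kept or zeroed). It stores one `(i, j)` pair per kept tile, whereas an element-wise format such as CSR stores one column index per non-zero.

```python
import numpy as np


def to_block_sparse(A, block_size):
    """Split A into block_size×block_size tiles and keep only tiles with a non-zero."""
    n_rows, n_cols = A.shape
    blocks, indices = [], []
    for i in range(n_rows // block_size):
        for j in range(n_cols // block_size):
            tile = A[i * block_size:(i + 1) * block_size,
                     j * block_size:(j + 1) * block_size]
            if np.any(tile):
                blocks.append(tile.copy())
                indices.append((i, j))  # one index pair per block, not per element
    return blocks, indices


rng = np.random.default_rng(0)
block_size, grid = 4, 4                   # 16×16 matrix made of 4×4 tiles
keep = rng.random((grid, grid)) < 0.25    # keep roughly a quarter of the tiles
keep[0, 0] = True                         # guarantee at least one stored block
A = rng.standard_normal((16, 16)) * np.kron(keep, np.ones((block_size, block_size)))

blocks, indices = to_block_sparse(A, block_size)
B = rng.standard_normal((16, 8))

# One dense tile product per stored block
result = np.zeros((16, 8))
for tile, (i, j) in zip(blocks, indices):
    result[i * block_size:(i + 1) * block_size] += tile @ B[j * block_size:(j + 1) * block_size]

element_indices = int(np.count_nonzero(A))  # element-wise format: one index per non-zero
block_indices = len(indices)                # block format: one (i, j) pair per tile
print(element_indices, block_indices)
```

With 4×4 tiles the block format needs 16× fewer index entries, and each stored tile is an ordinary dense matmul.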
10.3.2 2:4 Structured Sparsity
NVIDIA’s Ampere GPUs introduced hardware support for “2:4” sparsity:
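Constraint: In every group of 4 consecutive elements, at least 2 must be zero (at most 2 non-zeros).
Typical 2:4 patterns with exactly two non-zeros (█ = nonzero, ░ = zero):
██░░ █░█░ █░░█ ░██░ ░█░█ ░░██
Invalid:
███░ (only 1 zero)
████ (no zeros)
Groups with 3 or 4 zeros also satisfy the constraint, but the hardware still stores 2 values per group, so the extra zeros bring no further speedup.
Because 2 of every 4 values are always stored, the effective sparsity is exactly 50%, with hardware acceleration: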
Dense FP16 (A100 peak):      Tensor Core throughput: 312 TFLOPS
2:4 Sparse FP16 (A100 peak): Tensor Core throughput: 624 TFLOPS (2× peak throughput)
The hardware knows the pattern: 2 non-zeros per 4 elements. It can skip the zeros efficiently.
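As an illustration of how a dense row can be brought into this pattern, the sketch below applies simple magnitude pruning: in each group of four it keeps the two largest-magnitude values and records their positions. The function name and the `(values, positions)` layout are assumptions for illustration. The real hardware metadata encoding is different, although each position does fit in 2 bits.

```python
def prune_and_compress_2_4(row):
    """Keep the two largest-magnitude values in each group of four.

    Returns (pruned_row, values, positions), where positions are the
    within-group indices (0..3) of the kept values.
    """
    assert len(row) % 4 == 0, "row length must be a multiple of 4"
    pruned, values, positions = [], [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        by_magnitude = sorted(range(4), key=lambda k: abs(group[k]), reverse=True)
        keep = sorted(by_magnitude[:2])  # kept positions, in original order
        pruned.extend(group[k] if k in keep else 0.0 for k in range(4))
        values.extend(group[k] for k in keep)
        positions.extend(keep)
    return pruned, values, positions


row = [0.9, -0.1, 0.05, -1.2, 0.3, 0.4, -0.2, 0.1]
pruned, values, positions = prune_and_compress_2_4(row)
print(pruned)
print(values, positions)
```

The compressed form holds half as many values as the dense row plus a small fixed amount of position metadata per group, which is what lets the hardware skip the zeros without a general index lookup.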
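# PyTorch 2.1+ supports 2:4 (semi-structured) sparsity
import torch
from torch.sparse import to_sparse_semi_structured
# to_sparse_semi_structured expects the input to already follow the 2:4 pattern,
# so zero two of every four consecutive elements first (fixed mask for illustration)
mask = torch.tensor([0, 0, 1, 1], device='cuda', dtype=torch.bool).tile((1024, 256))
dense = torch.randn(1024, 1024, device='cuda', dtype=torch.float16) * mask
# Convert the pruned dense tensor to the compressed 2:4 format
sparse_24 = to_sparse_semi_structured(dense)
# Matmul uses accelerated 2:4 kernels
other_matrix = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
result = torch.mm(sparse_24, other_matrix)  # up to ~2× faster on Ampere+
10.3.3 The Accuracy Trade-off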
Forcing structure into sparsity costs accuracy:
| Sparsity Type | Achievable Sparsity | Accuracy Drop (ResNet-50) |
|---|---|---|
| Unstructured | 90% | 1-2% |
| Block (16×16) | 75% | 1-2% |
| Block (64×64) | 50% | 1-2% |
| 2:4 Structured | 50% | <1% |
Unstructured sparsity is most flexible but least accelerable. 2:4 is most accelerable but constrains sparsity to exactly 50%.
10.4 Investigation: Mixture of Experts
What if, instead of making the weights sparse, we made their use conditional on the input?
Mixture of Experts (MoE): Different inputs use different parts of the network.
10.4.1 The Architecture
Standard feed-forward:
Input → [FFN] → Output
Every token uses the full FFN
MoE feed-forward:
Input → [Router] → selects Expert(s)
→ [Expert 1] ─┐
→ [Expert 2] ├→ Weighted sum → Output
→ [Expert 3] │
→ ... │
→ [Expert N] ─┘
The router is a small network that decides which expert(s) to use for each input.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
    def __init__(self, dim, num_experts=8, top_k=2):
        super().__init__()
        # Each expert is a full FFN: dim -> 4*dim -> dim
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim),
            )
            for _ in range(num_experts)
        ])
        self.gate = nn.Linear(dim, num_experts)
        self.top_k = top_k
    def forward(self, x):
        # x: [batch, seq, dim]
        batch, seq, dim = x.shape
        # Compute routing probabilities
        router_logits = self.gate(x)  # [batch, seq, num_experts]
        router_probs = F.softmax(router_logits, dim=-1)
        # Select top-k experts per token
        top_k_probs, top_k_indices = router_probs.topk(self.top_k, dim=-1)
        # Renormalize so the selected weights sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Compute output (simplified, actual impl batches by expert)
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            expert_idx = top_k_indices[:, :, k]  # [batch, seq]
            weight = top_k_probs[:, :, k:k+1]    # [batch, seq, 1]
            for e in range(len(self.experts)):
                mask = (expert_idx == e)         # tokens whose k-th choice is expert e
                if mask.any():
                    expert_out = self.experts[e](x[mask])  # [n_tokens, dim]
                    output[mask] += weight[mask] * expert_out
        return output
10.4.2 The Efficiency Win
Mixtral 8x7B:
- 8 experts, uses top-2 per token
- 46.7B total parameters
- 12.9B active parameters per token (2 of the 8 expert FFNs, plus the attention and embedding weights that every token uses)
- Performs like a ~70B dense model
Parameter efficiency:
Dense 70B: 70B params active, 70B params total
Mixtral: 12.9B params active, 46.7B params total
Active params: 5.4× fewer
Total params: 0.67× as many
Performance: Similar to dense 70B
MoE is a form of conditional computation: the model dynamically decides what work to do.
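The figures above can be unpacked with a little arithmetic. The sketch assumes a simple two-part model of the parameter count (shared weights used by every token, plus equally sized experts), which is an approximation for illustration rather than an official breakdown.

```python
TOTAL_B = 46.7    # total parameters in billions, from the figures above
ACTIVE_B = 12.9   # active parameters per token in billions
NUM_EXPERTS, TOP_K = 8, 2

# Assumed model:
#   total  = shared + NUM_EXPERTS * per_expert
#   active = shared + TOP_K * per_expert
per_expert_b = (TOTAL_B - ACTIVE_B) / (NUM_EXPERTS - TOP_K)
shared_b = ACTIVE_B - TOP_K * per_expert_b

active_fraction = ACTIVE_B / TOTAL_B
dense_70b_ratio = 70.0 / ACTIVE_B

print(f"per expert: {per_expert_b:.2f}B, shared: {shared_b:.2f}B")
print(f"active fraction: {active_fraction:.1%}")
print(f"dense 70B uses {dense_70b_ratio:.1f}x more active parameters")
```

Under this model each expert accounts for roughly 5.6B parameters and about 1.6B are shared, so a token touches a little over a quarter of the stored weights.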
10.4.3 The Challenges
MoE isn’t free:
Load balancing: If all tokens route to the same expert, that expert becomes a bottleneck. Needs auxiliary losses to encourage balance.
Memory: All experts must be loaded, even if only 2 are active per token. 46.7B parameters in memory for 12.9B active.
Communication: In distributed training, experts live on different GPUs. Routing requires all-to-all communication.
Inference latency: For single-token generation, only 2 experts are used—but all must be resident. Memory bandwidth, not compute, often dominates.
The MoE memory paradox:
Throughput (batch processing): MoE wins
- Tokens spread across experts
- Effective compute scales with batch
Latency (single token): Dense wins
- All params in memory, few active
- Memory-bound: loading 46.7B to use 12.9B
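Returning to the load-balancing challenge listed above: one common remedy is an auxiliary loss in the style of the Switch Transformer (Fedus et al., 2021), which multiplies the fraction of tokens sent to each expert by the mean router probability for that expert. The NumPy sketch below uses top-1 routing and omits the scaling coefficient for brevity; the function name and the test data are illustrative assumptions.

```python
import numpy as np


def load_balancing_loss(router_probs, expert_index):
    """
    router_probs: [tokens, num_experts] softmax outputs of the router
    expert_index: [tokens] expert chosen for each token (top-1 for simplicity)
    Returns num_experts * sum_i f_i * P_i, which equals 1.0 for uniform routing.
    """
    num_experts = router_probs.shape[1]
    # f_i: fraction of tokens dispatched to expert i
    f = np.bincount(expert_index, minlength=num_experts) / len(expert_index)
    # P_i: mean router probability assigned to expert i
    P = router_probs.mean(axis=0)
    return num_experts * float(np.sum(f * P))


num_experts, tokens = 8, 64

# Perfectly balanced routing: uniform probabilities, tokens spread evenly
uniform_probs = np.full((tokens, num_experts), 1.0 / num_experts)
uniform_choice = np.arange(tokens) % num_experts

# Collapsed routing: every token goes to expert 0 with probability 1
collapsed_probs = np.zeros((tokens, num_experts))
collapsed_probs[:, 0] = 1.0
collapsed_choice = np.zeros(tokens, dtype=int)

balanced = load_balancing_loss(uniform_probs, uniform_choice)
collapsed = load_balancing_loss(collapsed_probs, collapsed_choice)
print(balanced, collapsed)
```

The loss is lowest when routing is uniform and grows toward `num_experts` as routing collapses onto a single expert, which is the behavior the router is penalized for.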
10.5 Sparsity in Attention
Attention is naturally sparse. Most tokens don’t strongly attend to most other tokens.
10.5.1 Empirical Attention Patterns
Analysis of trained transformers shows:
Typical attention pattern (one head):
               Key position →
     1  2  3  4  5  6  7  8  9 10 11 12
   ┌────────────────────────────────────┐
 1 │ █                                  │
 2 │ █  █                               │
 3 │ █  ░  █                            │
 4 │ █  ░  ░  █                         │   █ = high attention
 5 │ █  ░  ░  ░  █                      │   ░ = low attention
 6 │ █  ░  ░  ░  ░  █                   │   (blank = ~zero)
 7 │ █  ░  ░  ░  ░  ░  █                │
 8 │ █  ░  ░  ░  ░  ░  ░  █             │
 9 │ █  ░  ░  ░  ░  ░  ░  ░  █          │
10 │ █  ░  ░  ░  ░  ░  ░  ░  ░  █       │
11 │ █  ░  ░  ░  ░  ░  ░  ░  ░  ░  █    │
12 │ █  ░  ░  ░  ░  ░  ░  ░  ░  ░  ░  █ │
   └────────────────────────────────────┘
Query position ↓
Pattern: Attend to (1) self, (2) first token, (3) nearby tokens
Most of the matrix is near-zero.
This suggests we could skip most attention computations.
10.5.2 Sparse Attention Variants
Sliding window: Only attend within a local window.
import math
import torch
import torch.nn.functional as F
def sliding_window_attention(Q, K, V, window_size=256):
    # Q, K, V: [seq_len, d]
    seq_len, d = Q.shape
    output = torch.zeros_like(Q)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = i + 1  # Causal: attend only to the current and earlier positions
        q = Q[i:i+1]
        k = K[start:end]
        v = V[start:end]
        attn = F.softmax(q @ k.T / math.sqrt(d), dim=-1)
        output[i] = (attn @ v).squeeze(0)
    return output
Complexity: O(n × window) instead of O(n²)
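To put numbers on that complexity claim, the short sketch below counts query-key scores for full causal attention and for the sliding window, using the same window convention as the function above (query `i` sees keys `max(0, i - window_size)` through `i`). The sequence length and window size are example values.

```python
def causal_score_count(seq_len):
    """Number of query-key scores computed by full causal attention."""
    return seq_len * (seq_len + 1) // 2


def sliding_window_score_count(seq_len, window_size):
    """Scores computed when query i sees keys max(0, i - window_size) .. i."""
    return sum(min(i, window_size) + 1 for i in range(seq_len))


seq_len, window_size = 16_384, 256
full = causal_score_count(seq_len)
windowed = sliding_window_score_count(seq_len, window_size)
reduction = full / windowed
print(f"full causal: {full:,}  windowed: {windowed:,}  reduction: {reduction:.1f}x")
```

At 16K tokens with a 256-token window the score count drops by roughly 32×, and the gap widens linearly as the sequence grows.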
Sparse patterns: Combine local + global + random.
Longformer/BigBird patterns:
- Local window attention (nearby tokens)
- Global tokens (attend to/from everywhere)
- Random attention (a few random connections)
Sparse attention pattern:
Global Local Random
↓ ↓ ↓
████████████████████████ ← Global tokens (attend everywhere)
████████████████████████
██████ ← Local window
██████
██████
██████ █ ← Random connections
██████
██████ █
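The three ingredients can be combined into a single boolean mask. The NumPy sketch below is a simplified, bidirectional illustration in the spirit of the Longformer/BigBird patterns; the function name, parameter names, and sizes are assumptions, and real implementations use block-wise layouts rather than a dense mask.

```python
import numpy as np


def sparse_attention_mask(seq_len, window, num_global, num_random, seed=0):
    """Boolean [seq_len, seq_len] mask: True where query i may attend to key j."""
    rng = np.random.default_rng(seed)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    mask = np.abs(i - j) <= window   # local band around the diagonal
    mask[:num_global, :] = True      # global tokens attend everywhere
    mask[:, :num_global] = True      # every token attends to the global tokens
    for q in range(seq_len):         # a few random keys per query
        mask[q, rng.choice(seq_len, size=num_random, replace=False)] = True
    return mask


mask = sparse_attention_mask(seq_len=512, window=4, num_global=2, num_random=3)
density = mask.mean()
print(f"fraction of scores computed: {density:.3f}")
```

For these example sizes only a few percent of the full score matrix is kept, while every token remains reachable from every other within two hops via the global tokens.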
10.5.3 The FlashAttention Insight
FlashAttention (see the Associativity and FlashAttention chapters) takes a different approach: don’t make attention sparse, but compute it exactly with better memory access.
For many applications, FlashAttention + dense attention beats sparse attention:
- No approximation error
- Simpler implementation
- Hardware-friendly access patterns
Sparse attention shines for very long sequences (>16K tokens) where even FlashAttention’s O(n²) compute is prohibitive.
10.6 From Sparsity to Structure
Here’s a counterintuitive insight: sometimes the best alternative to a sparse matrix isn’t a denser sparse matrix—it’s a structured dense matrix.
10.6.1 The Core Insight
Random sparsity fails because hardware can’t predict the pattern. But what if we replace a large matrix with a smaller structured computation that achieves similar expressiveness?
Approach 1: Sparse matrix
- 1000×1000 matrix, 90% zeros
- 100K non-zeros + indices
- Random access pattern → slow
Approach 2: Structured matrix
- Two 1000×32 matrices (UV^T, rank-32)
- 64K values, no indices
- Sequential access pattern → fast
- Often more expressive than random sparsity!
The key: structure enables algorithmic speedups, not just zero-skipping.
10.6.2 Structured Matrix Families
Several matrix families have structure that enables efficient multiplication:
10.6.2.1 Low-Rank Matrices (Separability)
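Many matrices can be approximated well as \(A \approx UV^T\), where \(U\) and \(V\) are thin \(n \times r\) matrices with \(r \ll n\). How good the approximation is depends on how quickly the singular values of \(A\) decay.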
def low_rank_multiply(U, V, x):
    """
    Compute (UV^T)x efficiently.
    A is n×n, but stored as U (n×r) and V (n×r)
    """
    # Instead of: A @ x → O(n²)
    # Do: U @ (V.T @ x) → O(nr)
    return U @ (V.T @ x)
This is the separability property (see Chapter: Separability). LoRA exploits exactly this for efficient fine-tuning.
When does low-rank work? When the matrix has structure—correlations between rows or columns. Neural network weight matrices often have low effective rank after training.
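A quick way to see this is a truncated SVD on a matrix with low effective rank. The NumPy sketch below builds a synthetic matrix that is rank 8 plus a little noise (an assumption made for the demo, not a claim about any particular network), factors it, and compares the dense and factored products.

```python
import numpy as np

rng = np.random.default_rng(0)
n, true_rank, r = 256, 8, 8

# Synthetic matrix: exactly rank 8, plus small noise
A = rng.standard_normal((n, true_rank)) @ rng.standard_normal((true_rank, n))
A += 1e-3 * rng.standard_normal((n, n))

# Truncated SVD gives the best rank-r approximation in the Frobenius norm
U_full, s, Vt = np.linalg.svd(A, full_matrices=False)
U = U_full[:, :r] * s[:r]   # n×r, singular values folded into U
V = Vt[:r].T                # n×r

x = rng.standard_normal(n)
y_dense = A @ x             # about n² multiply-adds
y_low_rank = U @ (V.T @ x)  # about 2·n·r multiply-adds

relative_error = np.linalg.norm(y_dense - y_low_rank) / np.linalg.norm(y_dense)
flop_ratio = (n * n) / (2 * n * r)
print(f"relative error: {relative_error:.2e}, multiply-add reduction: {flop_ratio:.0f}x")
```

When the singular values drop off sharply, as they do here by construction, the factored product is nearly exact at a fraction of the cost. If the spectrum were flat, the same truncation would lose most of the signal.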
10.6.2.2 Butterfly Matrices
Butterfly matrices factor a matrix into \(O(\log n)\) sparse structured layers:
n×n butterfly matrix = Product of log(n) sparse factors
Each factor:
- Block-diagonal structure
- 2×2 blocks mixed with permutations
- Only O(n) non-zeros per factor
Total: O(n log n) multiply instead of O(n²)
def butterfly_multiply(factors, x):
    """
    Multiply by a butterfly-factored matrix.
    factors: list of log(n) sparse block-diagonal matrices
    """
    result = x
    for factor in factors:
        # Each factor is block-diagonal with 2×2 blocks
        # O(n) operations per factor
        result = factor @ result
    return result
# Total: O(n log n)
The FFT is a butterfly matrix! Butterfly structure appears naturally in:
- Signal processing (FFT, DCT)
- Efficient linear layers (Monarch matrices)
- Replacing attention heads (Hyena, M2-BERT)
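To make the O(n log n) claim concrete, here is a NumPy sketch that applies a butterfly product without ever materializing a dense factor. The storage layout (`twiddles[level]` holding one 2×2 block per pair of elements at distance `2**level`) is an assumption chosen for illustration. As a check, using the block `[[1, 1], [1, -1]]` at every level should reproduce a Walsh-Hadamard transform, a close relative of the FFT.

```python
import numpy as np


def butterfly_apply(twiddles, x):
    """Apply log2(n) butterfly factors to x using O(n) work per factor.

    twiddles[level] has shape (n // 2, 2, 2): one 2×2 block for each pair of
    elements that sit 2**level apart.
    """
    n = x.shape[0]
    y = np.asarray(x, dtype=float)
    for level, tw in enumerate(twiddles):
        stride = 1 << level
        groups = n // (2 * stride)
        # View y so that paired elements share the (group, offset) coordinates
        y = y.reshape(groups, 2, stride)
        tw = tw.reshape(groups, stride, 2, 2)
        # out[g, a, s] = sum_b tw[g, s, a, b] * y[g, b, s]
        y = np.einsum('gsab,gbs->gas', tw, y).reshape(n)
    return y


n = 16
levels = int(np.log2(n))
hadamard_block = np.array([[1.0, 1.0], [1.0, -1.0]])
twiddles = [np.tile(hadamard_block, (n // 2, 1, 1)) for _ in range(levels)]

# Materialize the implied dense matrix column by column, for checking only
M = np.stack([butterfly_apply(twiddles, e) for e in np.eye(n)], axis=1)
num_parameters = sum(tw.size for tw in twiddles)
print(num_parameters, n * n)
```

The product of four very sparse factors is a fully dense 16×16 matrix, yet it is described by 2·n·log₂(n) = 128 numbers instead of 256, and the gap grows with n.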
10.6.2.3 Monarch Matrices
Monarch matrices use a clever factorization:
\[M = P L P^\top R\]
where \(L\) and \(R\) are block-diagonal and \(P\) is a fixed permutation (a reshape-and-transpose of the vector).
Monarch matrix structure:
1. Block-diagonal multiply (R: parallel small matrices)
2. Permute (Pᵀ)
3. Block-diagonal multiply (L: parallel small matrices)
4. Permute output (P)
Complexity: O(n√n) vs O(n²) for dense
Hardware: Block-diagonal = batched small GEMMs = fast
Monarch matrices achieve:
- 2× speedup over dense for language model FFN layers
- Minimal quality loss (<1% perplexity)
- Hardware-friendly operations (block GEMMs, no irregular access)
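A minimal NumPy sketch of the permute and block-diagonal pipeline follows. It assumes the two-factor form from the Monarch paper (Dao et al., 2022), \(M = P L P^\top R\) with \(n = m^2\), \(m\) blocks of size \(m \times m\) per factor, and \(P\) the reshape-transpose permutation. The function names are illustrative, and the dense matrix is built only to check the result.

```python
import numpy as np


def monarch_multiply(L, R, x):
    """Compute P L Pᵀ R x, where L and R hold m blocks of size m×m each."""
    m = L.shape[0]
    z = x.reshape(m, m)                 # split x into m chunks of length m
    z = np.einsum('kij,kj->ki', R, z)   # block-diagonal R: block k acts on chunk k
    z = z.T                             # permutation (reshape-transpose)
    z = np.einsum('kij,kj->ki', L, z)   # block-diagonal L
    z = z.T                             # permutation back
    return z.reshape(-1)


def block_diag(blocks):
    """Dense block-diagonal matrix from an (m, m, m) stack of blocks."""
    m = blocks.shape[0]
    out = np.zeros((m * m, m * m))
    for k in range(m):
        out[k * m:(k + 1) * m, k * m:(k + 1) * m] = blocks[k]
    return out


rng = np.random.default_rng(0)
m = 4
n = m * m
L = rng.standard_normal((m, m, m))
R = rng.standard_normal((m, m, m))
x = rng.standard_normal(n)

# Dense reference: (P v)[i] = v[perm[i]] for the reshape-transpose permutation
perm = np.arange(n).reshape(m, m).T.reshape(-1)
P = np.eye(n)[perm]
M_dense = P @ block_diag(L) @ P.T @ block_diag(R)

y_fast = monarch_multiply(L, R, x)
y_dense = M_dense @ x
num_parameters = L.size + R.size
print(num_parameters, n * n)
```

Both einsum calls are batched small matrix products, which is the reason the structure maps well onto GEMM hardware. The parameter count is 2·n·√n (128 here) against n² (256) for a dense matrix.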
10.6.2.4 Circulant and Toeplitz Matrices
Matrices with shift structure can use FFT:
def circulant_multiply(first_column, x):
    """
    Multiply by circulant matrix using FFT.
    Circulant: each row is previous row shifted by 1
    """
    # Circulant = F^{-1} diag(F c) F
    # where F is the DFT matrix, c is first column
    c_fft = torch.fft.fft(first_column)
    x_fft = torch.fft.fft(x)
    result_fft = c_fft * x_fft
    return torch.fft.ifft(result_fft).real
# O(n log n) instead of O(n²)
Circulant structure appears in convolutions (which is why FFT-based convolution works).
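The identity behind the FFT shortcut can be checked numerically. The sketch below is a NumPy version of the same idea: it builds the explicit circulant matrix from its first column (the helper `circulant_matrix` is written here for illustration) and compares the dense product against the FFT route.

```python
import numpy as np


def circulant_matrix(first_column):
    """Dense circulant matrix: column j is the first column rotated down by j."""
    n = len(first_column)
    return np.stack([np.roll(first_column, j) for j in range(n)], axis=1)


def circulant_multiply(first_column, x):
    """Multiply by the circulant matrix via FFT: O(n log n) instead of O(n²)."""
    return np.fft.ifft(np.fft.fft(first_column) * np.fft.fft(x)).real


rng = np.random.default_rng(0)
n = 64
c = rng.standard_normal(n)
x = rng.standard_normal(n)

y_dense = circulant_matrix(c) @ x
y_fft = circulant_multiply(c, x)
max_abs_diff = float(np.max(np.abs(y_dense - y_fft)))
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two results agree to floating-point precision, and the FFT route never stores more than the n entries of the first column.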
10.6.3 When Structure Beats Sparsity
Here’s a practical comparison:
| Matrix Type | Size | Storage | Multiply Cost | Hardware Efficiency |
|---|---|---|---|---|
| Dense | n×n | n² | O(n²) | 100% (baseline) |
| Random 90% sparse | n×n | ~0.2n² | O(0.1n²) theory | 20-30% of theoretical |
| Block sparse (50%) | n×n | 0.5n² | O(0.5n²) | 70-80% of theoretical |
| 2:4 structured | n×n | 0.5n² | O(0.5n²) | 100% (hardware support) |
| Low-rank (r=32) | n×n | 64n | O(64n) | 100% (dense GEMMs) |
| Butterfly | n×n | O(n log n) | O(n log n) | ~80% (structured sparse) |
| Monarch | n×n | O(n√n) | O(n√n) | ~90% (block GEMMs) |
The pattern: Structured matrices achieve their theoretical speedup because hardware can predict and optimize the access pattern.
10.6.4 Structured FFN Layers
Feed-forward layers hold roughly two-thirds of a standard transformer’s parameters, which makes structured replacements compelling:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MonarchFFN(nn.Module):
    """
    Simplified FFN layer in the spirit of Monarch factorization:
    a fixed permutation followed by block-diagonal projections.
    Standard FFN: 8n² parameters (n×4n up-projection + 4n×n down-projection)
    This layer: 8·n·block_size parameters, i.e. O(n√n) when block_size ≈ √n
    """
    def __init__(self, dim, block_size=32):
        super().__init__()
        assert dim % block_size == 0, "dim must be divisible by block_size"
        n_blocks = dim // block_size
        # Block-diagonal matrices
        self.blocks_up = nn.Parameter(
            torch.randn(n_blocks, block_size, block_size * 4)
        )
        self.blocks_down = nn.Parameter(
            torch.randn(n_blocks, block_size * 4, block_size)
        )
        # Permutation of the feature dimension (fixed here; could be learned)
        self.register_buffer('perm', torch.randperm(dim))
    def forward(self, x):
        batch, seq, dim = x.shape
        n_blocks, block_size, _ = self.blocks_up.shape
        # Apply permutation, then reshape for block operations
        x = x[..., self.perm].view(batch, seq, n_blocks, block_size)
        # Block-diagonal up-projection: [b, s, n, i] x [n, i, o] -> [b, s, n, o]
        x = torch.einsum('bsni,nio->bsno', x, self.blocks_up)
        # Activation
        x = F.gelu(x)
        # Block-diagonal down-projection: [b, s, n, o] x [n, o, i] -> [b, s, n, i]
        x = torch.einsum('bsno,noi->bsni', x, self.blocks_down)
        # Inverse permutation restores the original feature order
        x = x.reshape(batch, seq, dim)
        return x[..., torch.argsort(self.perm)]
10.6.5 The Separability Connection
Structured matrices are deeply connected to separability (from Factoring):
| Technique | Factorization | Separability Pattern |
|---|---|---|
| Low-rank | \(A = UV^T\) | Separates rows from columns |
| Butterfly | \(A = B_1 B_2 \cdots B_k\) | Separates scales (like FFT) |
| Monarch | \(A = P L P^\top R\) (block-diagonal \(L, R\); permutation \(P\)) | Separates blocks + permutation |
| Tensor decomposition | \(A = \sum_r a_r \otimes b_r\) | Separates dimensions |
The meta-lesson: when sparse doesn’t work, look for separable structure.
10.6.6 When to Use Each Approach
Problem                          Best Approach
────────────────────────────────────────────────────────────────
Natural sparsity (pruned nets)   2:4 structured sparsity
Conditional compute              MoE (dynamic sparsity)
Compressing large weights        Low-rank (LoRA, SVD)
Replacing dense layers           Monarch or butterfly
Convolution-like operations      Circulant (FFT)
Long-range dependencies          Structured attention (not sparse)
Sparsity is about skipping zeros. Structure is about exploiting patterns.
Random zeros give no pattern to exploit. Structured matrices—even fully dense ones—give patterns that enable algorithmic shortcuts.
This is why 2:4 sparsity works (hardware knows the pattern) and why Monarch matrices can be faster than 90% sparse matrices (predictable block structure beats unpredictable random access).
The question isn’t “how many zeros?” but “what structure enables efficient algorithms?”
10.7 The Sparsity-Hardware Co-Design
The lesson from sparsity: algorithms and hardware must be designed together.
| Algorithm Wants | Hardware Needs | Compromise |
|---|---|---|
| Skip arbitrary zeros | Predictable patterns | Block or 2:4 sparsity |
| Prune anywhere | Load balance | Structured pruning |
| Different paths/token | Same instruction/warp | MoE with batched routing |
| Sparse attention | Dense tensor ops | Fixed sparse patterns |
The pattern: constrain flexibility for predictability.
10.8 Key Takeaways
Random sparsity rarely helps: Hardware can’t exploit unpredictable skip patterns. 90% sparsity often yields less than a 2× speedup.
Structure enables acceleration: Block sparsity, 2:4 sparsity, and MoE all trade flexibility for predictability—and hardware rewards that.
MoE is conditional computation: Instead of sparse weights, choose which weights to use. Scales capacity without scaling compute.
Attention is empirically sparse: But exploiting it is hard. Dense FlashAttention often beats sparse approximations.
Structured matrices beat random sparsity: Butterfly, Monarch, and low-rank matrices achieve theoretical speedups because their patterns are predictable. The question isn’t “how many zeros?” but “what structure enables efficient algorithms?”
Co-design is essential: The best sparsity pattern depends on your hardware. NVIDIA 2:4 only matters on Ampere+.
10.9 The Meta-Lesson
All three optimization properties we’ve seen—associativity (chunking), separability (factoring), and sparsity (skipping)—share a theme:
Find structure that aligns with hardware.
- Associativity aligns with parallel reduction
- Separability aligns with GEMM efficiency
- Structured sparsity aligns with predictable memory access
The property tells you what’s mathematically possible. The hardware tells you what’s practically fast.
10.10 Further Reading
- Hoefler et al. (2021). “Sparsity in Deep Learning: Pruning and growth for efficient inference and training in neural networks”
- Fedus et al. (2021). “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”
- Mishra et al. (2021). “Accelerating Sparse Deep Neural Networks” (NVIDIA 2:4 sparsity)
- Beltagy et al. (2020). “Longformer: The Long-Document Transformer”
- Dao et al. (2022). “Monarch: Expressive Structured Matrices for Efficient and Accurate Training”
- Dao et al. (2019). “Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations”
- Fu et al. (2023). “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”