36 Capstone: The Framework in Your Hands
Three Problems, Six Properties, One Method
You’ve learned the properties. You’ve studied the case studies.
Now we face three problems you haven’t seen optimized before — spanning ML, systems, and scientific computing. For each, we apply the framework from first principles and see which properties unlock the optimization.
The goal isn’t to memorize these solutions. It’s to internalize the process so you can do this yourself.
36.1 The Method
Every investigation follows the same loop:
- Understand the computation
- Profile to find bottlenecks
- Audit properties — systematically check all six
- Derive the optimization from the matching property
- Implement and measure
- Iterate until satisfied
We’ll apply this to three problems of increasing breadth:
| Problem | Domain | Primary Properties | Key Insight |
|---|---|---|---|
| Mamba scan | ML inference | Associativity + Locality | Sequential recurrence → parallel scan |
| Streaming top-K | Data analytics | Associativity + Redundancy | Exact ranking from partial heaps |
| N-body simulation | Scientific computing | Separability + Locality | Far-field approximation via multipole |
36.2 Problem 1: Mamba Scan (ML Inference)
This problem applies everything we’ve learned to optimizing Mamba-style state space models for inference.
36.3 Understanding Mamba
Mamba [1] is a state space model (SSM) that offers an alternative to attention. Its key advantage: linear complexity in sequence length, compared to attention’s quadratic.
36.3.1 The Core Recurrence
At its heart, Mamba computes a recurrence:
```python
def mamba_scan_naive(x, A, B, C, delta):
    """
    x:     (batch, seq_len, d_model) - input sequence
    A:     (d_model, d_state)        - state transition (before discretization)
    B:     (batch, seq_len, d_state) - input projection
    C:     (batch, seq_len, d_state) - output projection
    delta: (batch, seq_len, d_model) - time step
    Returns: (batch, seq_len, d_model) - output
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]
    h = torch.zeros(batch, d_model, d_state)  # Hidden state
    outputs = []
    for t in range(seq_len):
        # Discretize A using delta: A_bar = exp(delta * A)
        A_bar = torch.exp(delta[:, t, :, None] * A)  # (batch, d_model, d_state)
        # Update hidden state: h = A_bar * h + B * x
        h = A_bar * h + B[:, t, None, :] * x[:, t, :, None]
        # Output: y = C * h
        y = (C[:, t, None, :] * h).sum(dim=-1)  # (batch, d_model)
        outputs.append(y)
    return torch.stack(outputs, dim=1)
```

36.3.2 The Problem

This sequential scan is slow. For a sequence of length N:

- N sequential steps (no parallelism)
- O(d_model × d_state) operations per step
- Total: O(N × d_model × d_state)
The sequential nature is the bottleneck. Can we do better?
36.4 Step 1: Profile the Baseline
Before optimizing, measure:
```python
import time

import torch

def benchmark_mamba(seq_len, d_model=2048, d_state=16, batch=1):
    x = torch.randn(batch, seq_len, d_model, device='cuda')
    A = torch.randn(d_model, d_state, device='cuda')
    B = torch.randn(batch, seq_len, d_state, device='cuda')
    C = torch.randn(batch, seq_len, d_state, device='cuda')
    delta = torch.randn(batch, seq_len, d_model, device='cuda').abs()
    # Warmup
    for _ in range(3):
        _ = mamba_scan_naive(x, A, B, C, delta)
    torch.cuda.synchronize()
    # Benchmark
    start = time.perf_counter()
    for _ in range(10):
        _ = mamba_scan_naive(x, A, B, C, delta)
    torch.cuda.synchronize()
    ms = (time.perf_counter() - start) / 10 * 1000
    return ms

# Profile across sequence lengths
for seq_len in [256, 512, 1024, 2048, 4096]:
    ms = benchmark_mamba(seq_len)
    tokens_per_sec = seq_len / (ms / 1000)
    print(f"seq_len={seq_len:4d}: {ms:7.2f}ms ({tokens_per_sec:,.0f} tok/s)")
```

Typical output:

```
seq_len= 256:   12.34ms (20,746 tok/s)
seq_len= 512:   24.56ms (20,847 tok/s)
seq_len=1024:   49.12ms (20,847 tok/s)
seq_len=2048:   98.23ms (20,847 tok/s)
seq_len=4096:  196.45ms (20,847 tok/s)
```
Linear scaling with sequence length—as expected for sequential scan. But 20K tokens/sec is slow for inference.
36.4.1 Identifying the Bottleneck
Use a profiler to see where time goes:
```python
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    mamba_scan_naive(x, A, B, C, delta)

print(prof.key_averages().table(sort_by="cuda_time_total"))
```

The profile reveals:

- Many small kernel launches (one per timestep)
- Low GPU utilization
- Memory-bound operations
Diagnosis: The sequential loop prevents parallelism. Each timestep launches new kernels, creating overhead.
36.5 Step 2: Identify Applicable Properties
Let’s check each property from our framework:
36.5.1 Associativity ✓
The recurrence \(h_t = \bar{A}_t h_{t-1} + B_t x_t\) is associative when each step is viewed as a matrix operation: composing two steps yields another step of the same form, and that composition can be regrouped freely.
For linear recurrences, we can use the parallel scan algorithm:
Sequential: h₁ → h₂ → h₃ → h₄ → h₅ → h₆ → h₇ → h₈
(8 sequential steps)
Parallel scan:
Step 1: (h₁,h₂) (h₃,h₄) (h₅,h₆) (h₇,h₈) [4 parallel]
Step 2: (h₁..h₄) (h₅..h₈) [2 parallel]
Step 3: (h₁..h₈) [1]
(3 = log₂(8) steps)
This is the same associativity that enables parallel prefix sum!
36.5.2 Locality ✓
The state \(h\) is small (d_model × d_state), but we’re streaming through a long sequence. We can tile the computation to keep working sets in fast memory.
36.5.3 Separability (Partial)
The state dimension d_state is chosen to be small (16-64). This is a design choice exploiting separability—the “effective” dynamics live in a low-dimensional space.
36.5.4 Sparsity ✗
The state transitions are dense. Sparsity doesn’t obviously apply here.
36.5.5 Redundancy (Partial)
During generation, we cache the hidden state h. This is KV-cache style redundancy exploitation.
36.6 Step 3: Derive the Optimization
The key insight is associativity enables parallel scan.
36.6.1 The Parallel Scan Algorithm
Rewrite the recurrence in matrix form:
\[ \begin{pmatrix} h_t \\ 1 \end{pmatrix} = \begin{pmatrix} A_t & B_t x_t \\ 0 & 1 \end{pmatrix} \begin{pmatrix} h_{t-1} \\ 1 \end{pmatrix} \]
Let \(M_t = \begin{pmatrix} A_t & B_t x_t \\ 0 & 1 \end{pmatrix}\).
Then the full computation is:
\[h_N = M_N \cdot M_{N-1} \cdot ... \cdot M_1 \cdot h_0\]
Matrix multiplication is associative! We can compute this product in any order.
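That claim is easy to check numerically. A NumPy sketch with synthetic data (`lift` is a helper introduced here to build the augmented matrix \(M_t\)):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 3

def lift(A, b):
    """Represent the affine step h -> A h + b as a (d+1)x(d+1) matrix."""
    M = np.eye(d + 1)
    M[:d, :d] = A
    M[:d, d] = b
    return M

steps = [(0.5 * rng.normal(size=(d, d)), rng.normal(size=d)) for _ in range(8)]
h0 = rng.normal(size=d)

# Sequential recurrence
h = h0.copy()
for A, b in steps:
    h = A @ h + b

# The lifted matrix product gives the same state, in any grouping
Ms = [lift(A, b) for A, b in steps]
v = np.linalg.multi_dot(Ms[::-1] + [np.append(h0, 1.0)])
assert np.allclose(v[:d], h)
```

`multi_dot` is free to pick any parenthesization; the assertion holds regardless, which is exactly the freedom the parallel scan exploits.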
36.6.2 Parallel Prefix Product
```python
def parallel_scan_mamba(x, A, B, C, delta):
    """Parallel scan version of Mamba."""
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]
    # Discretize
    A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (batch, seq, d_model, d_state)
    B_x = B.unsqueeze(2) * x.unsqueeze(-1)      # (batch, seq, d_model, d_state)
    # Pack into "elements" for parallel scan.
    # Each element: (A_bar_t, B_x_t)
    # Associative op: (A1, B1) ⊕ (A2, B2) = (A2 * A1, A2 * B1 + B2)
    h = associative_scan(
        lambda a, b: (b[0] * a[0], b[0] * a[1] + b[1]),
        (A_bar, B_x)
    )
    # Output projection
    y = (C.unsqueeze(2) * h).sum(dim=-1)
    return y
```

The `associative_scan` operation computes all prefix products in O(log N) parallel steps instead of N sequential steps.
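PyTorch has no built-in `associative_scan`, so the code above assumes one. For completeness, here is a minimal reference implementation using the Hillis-Steele doubling scheme: O(N log N) work, O(log N) sequential steps. A correctness sketch, not tuned for performance:

```python
import torch

def associative_scan(op, elems):
    """Inclusive scan along dim=1 (Hillis-Steele doubling).

    `op(a, b)` combines an earlier element `a` with a later element `b`;
    it must be associative. O(N log N) work, O(log N) sequential steps.
    """
    A, b = elems
    seq_len = A.shape[1]
    shift = 1
    while shift < seq_len:
        # Combine each position with the partial result `shift` steps earlier
        A_new, b_new = op(
            (torch.roll(A, shifts=shift, dims=1), torch.roll(b, shifts=shift, dims=1)),
            (A, b),
        )
        # Positions < shift have no predecessor at this distance: keep them
        mask = (torch.arange(seq_len, device=A.device) >= shift).view(
            1, -1, *([1] * (A.dim() - 2))
        )
        A = torch.where(mask, A_new, A)
        b = torch.where(mask, b_new, b)
        shift *= 2
    return b

# Check against the sequential recurrence h_t = 0.5 * h_{t-1} + 1, h_0 = 0
op = lambda a, b: (b[0] * a[0], b[0] * a[1] + b[1])
h = associative_scan(op, (torch.full((1, 4), 0.5), torch.ones(1, 4)))
assert torch.allclose(h, torch.tensor([[1.0, 1.5, 1.75, 1.875]]))
```

On a GPU, each `while` iteration is one batched elementwise kernel over the whole sequence, which is what turns the O(N) sequential loop into O(log N) launches.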
36.6.3 GPU-Friendly Implementation
For GPUs, we need to:

1. Tile the scan to fit in shared memory
2. Fuse the discretization and scan
3. Avoid materializing large intermediates
```python
# Conceptual structure of the optimized kernel
@triton.jit
def mamba_scan_kernel(
    x_ptr, A_ptr, B_ptr, C_ptr, delta_ptr, output_ptr,
    seq_len, d_model, d_state,
    BLOCK_SIZE: tl.constexpr
):
    # Each block handles a chunk of the sequence
    block_id = tl.program_id(0)
    # Load chunk into shared memory
    # ...
    # Parallel scan within block
    # ...
    # Combine with previous blocks (inter-block scan)
    # ...
    # Write output
    # ...
```

36.7 Step 4: Implementation
The real Mamba implementation uses a custom CUDA kernel. Here’s a simplified version in Triton:
```python
import triton
import triton.language as tl

@triton.jit
def mamba_chunk_scan_kernel(
    # Input pointers
    x_ptr, A_ptr, B_ptr, C_ptr, delta_ptr,
    # Output pointer
    y_ptr,
    # Dimensions
    batch, seq_len, d_model, d_state,
    # Strides
    stride_x_batch, stride_x_seq, stride_x_d,
    # Block sizes
    BLOCK_SEQ: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Chunked scan for Mamba (simplified sketch; elided address math shown as ...)."""
    # Get program IDs
    pid_batch = tl.program_id(0)
    pid_d = tl.program_id(1)
    # Offset calculations
    d_offset = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_offset < d_model
    # Initialize hidden state
    h = tl.zeros((BLOCK_D, d_state), dtype=tl.float32)
    # Process the sequence in chunks
    for chunk_start in range(0, seq_len, BLOCK_SEQ):
        chunk_end = min(chunk_start + BLOCK_SEQ, seq_len)
        for t in range(chunk_start, chunk_end):
            # Load inputs for this timestep
            x_offset = pid_batch * stride_x_batch + t * stride_x_seq + d_offset
            x_t = tl.load(x_ptr + x_offset, mask=d_mask)
            # Load delta, B, C
            delta_t = tl.load(delta_ptr + ...)
            B_t = tl.load(B_ptr + ...)
            C_t = tl.load(C_ptr + ...)
            # Discretize A
            A_bar = tl.exp(delta_t[:, None] * tl.load(A_ptr + ...))
            # Update hidden state
            h = A_bar * h + B_t[None, :] * x_t[:, None]
            # Compute output
            y_t = tl.sum(C_t[None, :] * h, axis=1)
            # Store output
            tl.store(y_ptr + ..., y_t, mask=d_mask)

# Wrapper
def mamba_scan_optimized(x, A, B, C, delta):
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]
    y = torch.empty_like(x)
    BLOCK_SEQ = 64
    BLOCK_D = 64
    grid = (batch, triton.cdiv(d_model, BLOCK_D))
    mamba_chunk_scan_kernel[grid](
        x, A, B, C, delta, y,
        batch, seq_len, d_model, d_state,
        x.stride(0), x.stride(1), x.stride(2),
        BLOCK_SEQ=BLOCK_SEQ, BLOCK_D=BLOCK_D,
    )
    return y
```

36.8 Step 5: Measure Improvement
```python
def benchmark(fn, iters=10):
    """Timing helper: warmup, then averaged wall-clock milliseconds."""
    for _ in range(3):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000

def compare_implementations():
    seq_len = 2048
    d_model = 2048
    d_state = 16
    batch = 1
    x = torch.randn(batch, seq_len, d_model, device='cuda')
    A = torch.randn(d_model, d_state, device='cuda')
    B = torch.randn(batch, seq_len, d_state, device='cuda')
    C = torch.randn(batch, seq_len, d_state, device='cuda')
    delta = torch.randn(batch, seq_len, d_model, device='cuda').abs()
    # Benchmark naive
    naive_ms = benchmark(lambda: mamba_scan_naive(x, A, B, C, delta))
    # Benchmark optimized
    opt_ms = benchmark(lambda: mamba_scan_optimized(x, A, B, C, delta))
    print(f"Naive: {naive_ms:.2f}ms")
    print(f"Optimized: {opt_ms:.2f}ms")
    print(f"Speedup: {naive_ms/opt_ms:.1f}x")

compare_implementations()
```

Expected results:

```
Naive: 98.23ms
Optimized: 6.42ms
Speedup: 15.3x
```
36.9 Step 6: Iterate
The first optimization rarely captures all the performance. Let’s profile again:
```python
with torch.profiler.profile(...) as prof:
    mamba_scan_optimized(x, A, B, C, delta)
```

New bottlenecks might appear:

- Memory bandwidth (can we fuse more?)
- Warp divergence (can we restructure conditionals?)
- Register pressure (can we reduce the working set?)
36.9.1 Further Optimizations
1. Fuse with preceding/following operations:
```python
# Instead of: y = mamba_scan(x); z = linear(y)
# Fuse:       z = mamba_scan_and_linear(x, W)
```

2. Use tensor cores for the state update: If d_state is large enough (e.g., 64), tensor cores can accelerate the matrix operations.

3. Optimize memory layout: Transpose for coalesced access patterns on the critical path.
36.10 Problem 2: Streaming Top-K (Data Analytics)
Domain: You’re building a real-time analytics system. Billions of events per hour, and you need the top-1000 events by score — continuously updated.
36.10.1 The Naive Approach
```python
def top_k_naive(events, k=1000):
    """Sort everything, take top K."""
    return sorted(events, key=lambda e: e.score, reverse=True)[:k]
```

For 1 billion events: sorting is O(n log n) ≈ 30 billion comparisons. At 1 GHz, that's 30 seconds. You need answers in milliseconds.
36.10.2 Property Audit
| Property | Applies? | How? |
|---|---|---|
| Associativity | Yes | top-K of top-K chunks = global top-K |
| Separability | No | No factorization structure |
| Sparsity | Partial | Most elements are irrelevant (not top-K) |
| Locality | Partial | Streaming access pattern |
| Redundancy | Yes | We don’t need exact order, just the top K |
| Symmetry | No | No invariance structure |
36.10.3 The Key Insight: Associativity of Top-K
If you split the data into chunks and take the top-K of each chunk, the top-K of those chunk-level top-Ks is the global top-K. This works because merging two top-K sets is itself an associative operation:

```python
def merge_top_k(top_k_a, top_k_b, k):
    """Merge two top-K sets. Associative!"""
    combined = top_k_a + top_k_b  # 2K elements
    combined.sort(reverse=True)   # Sort 2K elements
    return combined[:k]           # Take top K

# merge(merge(A, B), C) == merge(A, merge(B, C))  ✓ Associative!
```

36.10.4 The Optimization
```python
import heapq

def streaming_top_k(event_stream, k=1000):
    """Process a billion events with O(K) memory."""
    # Min-heap of size K: tracks the K largest seen so far.
    # (Sketch assumes unique scores; ties would compare event objects.)
    heap = []
    for event in event_stream:
        if len(heap) < k:
            heapq.heappush(heap, (event.score, event))
        elif event.score > heap[0][0]:  # Bigger than smallest in top-K?
            heapq.heapreplace(heap, (event.score, event))
    return sorted(heap, reverse=True)
```

Performance: O(n log K) time, O(K) memory. For n=1B, K=1000: ~10 billion ops vs 30 billion for sorting. But the real win is streaming: we never materialize the full dataset.
Parallelization (from associativity): Each core maintains its own top-K heap. Merge heaps pairwise in O(K log K) — the merge is associative.
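The chunk-and-merge claim is easy to verify end to end; a minimal sketch over plain scores (helper names are illustrative):

```python
import heapq
import random
from functools import reduce

def top_k(scores, k):
    """Top-k of one chunk, largest first."""
    return heapq.nlargest(k, scores)

def merge_top_k(a, b, k):
    """Associative merge of two top-k lists."""
    return heapq.nlargest(k, a + b)

random.seed(0)
scores = [random.random() for _ in range(10_000)]
chunks = [scores[i:i + 1000] for i in range(0, len(scores), 1000)]

# Per-chunk top-k (one chunk per core), merged pairwise,
# equals the global top-k
partials = [top_k(c, 100) for c in chunks]
merged = reduce(lambda a, b: merge_top_k(a, b, 100), partials)
assert merged == top_k(scores, 100)
```

Because the merge is associative, the `reduce` could just as well be a balanced tree of pairwise merges across cores.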
Redundancy exploitation: We can use approximate top-K (e.g., Count-Min Sketch) when exact ranking isn’t needed, trading accuracy for 10-100× speed.
36.10.5 What the Framework Revealed
Without the framework, you might reach for “use a database” or “sort in MapReduce.” The property audit reveals that associativity (of the merge operation) enables streaming and parallelization, while redundancy (most events are irrelevant) means we only need O(K) memory. This is the same associativity pattern as FlashAttention’s online softmax — different domain, same algebra.
36.11 Problem 3: N-Body Simulation (Scientific Computing)
Domain: Simulate gravitational interactions between N = 1,000,000 particles. Each particle exerts force on every other. The naive approach: O(N²) = 10¹² force calculations per timestep.
36.11.1 The Naive Approach
```python
def nbody_naive(positions, masses):
    """O(N²) all-pairs force calculation (per unit mass of particle i)."""
    N = len(positions)
    forces = np.zeros_like(positions)
    for i in range(N):
        for j in range(N):
            if i != j:
                r = positions[j] - positions[i]
                dist = np.linalg.norm(r)
                forces[i] += masses[j] * r / (dist**3 + 1e-10)
    return forces
```

At 1M particles: 10¹² pair interactions per step. Even at 10 TFLOPS on a GPU, that is seconds of arithmetic per timestep, and this pure-Python loop is orders of magnitude slower still. Simulations need thousands of steps. Intractable.
36.11.2 Property Audit
| Property | Applies? | How? |
|---|---|---|
| Associativity | Yes | Force is a sum — can be tree-reduced |
| Separability | Yes | Far-field interactions are low-rank |
| Sparsity | Partial | Near-field forces dominate; far-field is “smooth” |
| Locality | Yes | Nearby particles interact strongly; distant ones weakly |
| Redundancy | Yes | Far-field forces can be approximated |
| Symmetry | Yes | Newton’s third law: F(i→j) = -F(j→i) |
Five out of six properties apply! This is a property-rich problem.
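One of those properties pays off before any algorithmic change: Newton's third law halves the pairwise work. A sketch (note this computes true forces, so `masses[i]` appears, unlike the per-unit-mass naive version above):

```python
import numpy as np

def nbody_symmetric(positions, masses):
    """All-pairs forces; each pair computed once via Newton's third law."""
    N = len(positions)
    forces = np.zeros_like(positions)
    for i in range(N):
        for j in range(i + 1, N):
            r = positions[j] - positions[i]
            dist = np.linalg.norm(r)
            f = masses[i] * masses[j] * r / (dist**3 + 1e-10)
            forces[i] += f
            forces[j] -= f  # F(j→i) = -F(i→j)
    return forces
```

This halves the constant but not the O(N²); the algorithmic win comes from the next insight.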
36.11.3 The Key Insight: Separability + Locality
The gravitational potential from a distant cluster of particles can be approximated by a multipole expansion — a low-rank approximation (separability!) that’s accurate when the cluster is far away (locality!).
Close particles (distance < threshold):
Compute exact O(1) interaction per pair
Distant cluster of M particles:
Approximate entire cluster as one "super-particle"
(multipole expansion: monopole + dipole + quadrupole + ...)
Cost: O(p) per interaction instead of O(M), where p = expansion order
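The far-field claim can be checked numerically before building any tree: summarizing a distant cluster by its total mass at its center of mass (the monopole term) already reproduces the exact sum to high accuracy. A NumPy sketch with synthetic data, using per-unit-mass forces as in the naive version:

```python
import numpy as np

rng = np.random.default_rng(0)

# A tight cluster of M particles, far from a test particle at the origin
M = 500
center = np.array([100.0, 0.0, 0.0])
positions = center + rng.normal(scale=1.0, size=(M, 3))
masses = rng.uniform(1.0, 2.0, size=M)

# Exact: sum of M pairwise contributions
r = positions  # the test particle sits at the origin
d = np.linalg.norm(r, axis=1, keepdims=True)
exact = (masses[:, None] * r / d**3).sum(axis=0)

# Monopole: one "super-particle" at the cluster's center of mass
com = (masses[:, None] * positions).sum(axis=0) / masses.sum()
approx = masses.sum() * com / np.linalg.norm(com)**3

rel_err = np.linalg.norm(exact - approx) / np.linalg.norm(exact)
assert rel_err < 1e-2  # error scales like (cluster size / distance)^2
```

One O(1) evaluation replaces 500 interactions, at a relative error governed by how far away the cluster is; higher-order multipole terms shrink it further.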
This is the Barnes-Hut algorithm (using a tree) or the Fast Multipole Method (FMM) using a more sophisticated hierarchy:
```python
def nbody_fast(positions, masses, theta=0.5):
    """O(N log N) via Barnes-Hut tree."""
    # 1. Build spatial tree (octree in 3D)
    tree = build_octree(positions, masses)  # O(N log N)
    # 2. For each particle, walk the tree
    forces = np.zeros_like(positions)
    for i in range(len(positions)):
        forces[i] = tree_force(tree, positions[i], theta)
    return forces

def tree_force(node, pos, theta):
    """Compute force from tree node on particle at pos."""
    r = node.center_of_mass - pos
    dist = np.linalg.norm(r)
    if node.is_leaf or (node.size / dist < theta):
        # FAR ENOUGH (or a leaf): use the multipole approximation (separability!)
        return node.total_mass * r / (dist**3 + 1e-10)
    else:
        # TOO CLOSE: recurse into children (locality!)
        return sum(tree_force(child, pos, theta) for child in node.children)
```

Complexity: O(N log N) — from O(N²). For N=1M, that's 20M operations instead of 10¹² — a 50,000× speedup.
36.11.4 What the Framework Revealed
This problem uses four properties simultaneously:
- Separability (Tier 1): Far-field interactions are low-rank → multipole expansion
- Locality (Tier 2): Near-field needs exact computation, far-field allows approximation → tree structure
- Symmetry (Tier 2): Newton’s third law halves the work → F(i→j) = -F(j→i)
- Associativity (Tier 1): Force summation can be parallelized across tree branches
The FMM is regularly cited as one of the most important algorithms of the 20th century. It wasn’t discovered by trying tricks — it was discovered by recognizing that gravitational interaction has separable structure at long range. The same property that enables LoRA (low-rank weight updates) enables the Fast Multipole Method (low-rank force approximation). The algebra is the same; the domain is different.
This is what the book’s thesis predicts: if you understand the property, you can transfer the technique across domains. A researcher who understands separability in the context of LoRA has the conceptual tools to understand FMM — and vice versa.
36.12 The Methodology in Review
Let’s trace what we did across all three problems:
| Problem | Bottleneck | Key Property | Technique | Speedup |
|---|---|---|---|---|
| Mamba scan | Sequential loop | Associativity | Parallel scan | ~15× |
| Streaming top-K | O(n log n) sort | Associativity + Redundancy | Streaming heap + parallel merge | ~3-10× |
| N-body | O(N²) all-pairs | Separability + Locality | Multipole expansion / tree | ~50,000× |
The framework guided us differently for each problem:
- Mamba: Associativity was the single key enabler; other properties improved constants
- Top-K: Associativity enabled parallelism; redundancy enabled streaming (don’t store everything)
- N-body: The richest case — separability, locality, symmetry, and associativity all contributed, with separability providing the algorithmic breakthrough
36.13 Generalizing the Approach
This process works for any optimization problem:
┌─────────────────────────────────────────────────────────────┐
│ THE OPTIMIZATION LOOP │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. UNDERSTAND │
│ ├── What does the computation do? │
│ ├── What are the data dependencies? │
│ └── What are the dimensions/sizes? │
│ │
│ 2. MEASURE │
│ ├── Profile end-to-end │
│ ├── Identify the bottleneck │
│ └── Is it compute-bound or memory-bound? │
│ │
│ 3. ANALYZE PROPERTIES │
│ ├── Associativity? → Chunking, parallelization │
│ ├── Separability? → Factorization, low-rank │
│ ├── Locality? → Tiling, fusion │
│ ├── Sparsity? → Skip computation │
│ ├── Redundancy? → Quantization, caching │
│ └── Symmetry? → Weight sharing, FFT │
│ │
│ 4. DESIGN │
│ ├── Which property addresses the bottleneck? │
│ ├── What transformation exploits it? │
│ └── How do we implement it efficiently? │
│ │
│ 5. IMPLEMENT │
│ ├── Start simple, verify correctness │
│ ├── Optimize incrementally │
│ └── Use appropriate tools (Triton, CUDA, torch.compile) │
│ │
│ 6. MEASURE AGAIN │
│ ├── Did we improve? │
│ ├── What's the new bottleneck? │
│ └── Is there more to gain? │
│ │
│ 7. ITERATE │
│ └── Return to step 3 with new bottleneck │
│ │
└─────────────────────────────────────────────────────────────┘
36.14 Key Takeaways
Understand before optimizing: Know what the computation does and why.
Measure to find the bottleneck: Don’t guess—profile.
Check all six properties systematically: Which ones apply to your problem?
The property tells you the technique: Associativity → parallel scan. Locality → tiling. The framework guides the solution.
Implement incrementally: Start simple, verify, then optimize.
Iterate: The first optimization reveals the next bottleneck.
Know when to stop: Diminishing returns are real. Ship when good enough.
36.15 Exercises
Property Identification: Attention is O(n²). Which properties would you exploit to make it O(n)? (Hint: sparse attention patterns exploit which property?)
Derivation: Given that addition is associative, derive the parallel prefix sum algorithm. What’s the depth (number of sequential steps) for n elements?
Application: You have a model with many small matrix multiplies in sequence. Which property suggests an optimization? What would you try?
Measurement: Profile a computation in your own codebase. What’s the bottleneck? Which property might help?
Exercise 1: Property Identification
To reduce attention from O(n²) to O(n), you can exploit:
- Sparsity - Sparse attention patterns (e.g., local windows, strided patterns, Big Bird):
- Instead of attending to all n tokens, each token attends to only O(1) or O(log n) others
- Examples: Longformer (local + global), Sparse Transformer (strided patterns)
- Separability (Low-rank approximation):
- Linformer projects keys/values to lower dimension: O(n × d × k) where k << n
- Performer uses random features to approximate softmax: O(n × d)
- Locality:
- FlashAttention exploits locality by tiling Q, K, V into blocks that fit in SRAM
- Doesn’t change asymptotic complexity but dramatically improves constants
- Associativity (for state-space alternatives):
- Mamba replaces attention with a linear recurrence that can be parallelized
- The recurrence is associative → parallel scan → O(log n) sequential depth with near-linear total work
The most successful approaches combine multiple properties: FlashAttention uses locality + associativity of softmax normalization; Mamba uses associativity + separability (low-rank state).
Exercise 2: Parallel Prefix Sum Derivation
Given n elements \([a_0, a_1, ..., a_{n-1}]\), compute prefix sums \([s_0, s_1, ..., s_{n-1}]\) where \(s_i = \sum_{j=0}^{i} a_j\).
Sequential: \(s_0 = a_0\), \(s_i = s_{i-1} + a_i\) — O(n) sequential steps.
Parallel (Blelloch scan):
Up-sweep (reduce): Build a tree of partial sums
Level 0: a₀ a₁ a₂ a₃ a₄ a₅ a₆ a₇
Level 1: a₀ (a₀+a₁) a₂ (a₂+a₃) a₄ (a₄+a₅) a₆ (a₆+a₇)
Level 2: a₀ (a₀+a₁) a₂ (a₀..a₃) a₄ (a₄+a₅) a₆ (a₄..a₇)
Level 3: a₀ (a₀+a₁) a₂ (a₀..a₃) a₄ (a₄+a₅) a₆ (a₀..a₇)
Down-sweep: Distribute results back down
Replace root with 0, propagate: left child gets parent, right gets parent + old left
Depth: \(2 \log_2(n)\) parallel steps (log n for up-sweep, log n for down-sweep)
Work: O(n) total operations, but only O(log n) sequential steps
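The two sweeps can be written down directly. A minimal sequential sketch of the Blelloch scan (the inner `for` loops are exactly the levels that would run in parallel; the length must be a power of two):

```python
def blelloch_scan(a):
    """Exclusive prefix sum; len(a) must be a power of two."""
    a = list(a)
    n = len(a)
    # Up-sweep: build partial sums bottom-up (each level is parallel)
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            a[i] += a[i - step]
        step *= 2
    # Down-sweep: replace root with 0, then push prefixes back down
    # (left child gets parent; right child gets parent + old left)
    a[n - 1] = 0
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            a[i - step], a[i] = a[i], a[i] + a[i - step]
        step //= 2
    return a

print(blelloch_scan([1, 2, 3, 4, 5, 6, 7, 8]))
# [0, 1, 3, 6, 10, 15, 21, 28]
```

Each `while` iteration is one parallel level, giving the 2 log₂(n) depth derived above.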
Exercise 3: Many Small Matrix Multiplies
Property: Associativity (and locality)
Optimizations to try:
Batched GEMM: Combine small matmuls into a single batched operation
```python
# Instead of: for W in weights: x = x @ W
# Use batched: x = torch.bmm(x_batched, W_stacked)
```

Fusion: Fuse the sequence of matmuls with preceding/following operations

```python
# torch.compile can fuse:
@torch.compile
def fused_mlp(x, W1, W2, W3):
    return gelu(x @ W1) @ W2 @ W3
```

Pre-multiply (if shapes allow): Use associativity to reduce runtime work

```python
# If W1, W2 are constant:
W_combined = W1 @ W2   # Precompute once
y = x @ W_combined     # Apply combined
```

Tensor cores: Ensure dimensions are multiples of 8/16 for efficient tensor core usage
The key insight: many small operations create kernel launch overhead and poor GPU utilization. Batching exploits parallelism; fusion exploits locality.
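The pre-multiply item is plain matrix associativity; a quick numerical check (float64 to keep rounding noise negligible):

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 64, dtype=torch.float64)
W1 = torch.randn(64, 64, dtype=torch.float64)
W2 = torch.randn(64, 64, dtype=torch.float64)

# Associativity: (x @ W1) @ W2 == x @ (W1 @ W2)
W_combined = W1 @ W2           # precompute once, offline
y_two_matmuls = (x @ W1) @ W2  # two matmuls at runtime
y_one_matmul = x @ W_combined  # one matmul at runtime
assert torch.allclose(y_two_matmuls, y_one_matmul)
```

The trade-off: pre-multiplying only wins when the combined shape isn't larger than the factors (it fails badly for low-rank chains, where the factored form is the whole point).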
Exercise 4: Profiling Your Codebase
This is an open-ended exercise. Here’s a framework for approaching it:
Step 1: Profile end-to-end
```python
# PyTorch profiler
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=True
) as prof:
    your_function()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```

Step 2: Identify the bottleneck type

- If the top operations are matmul/conv: compute-bound → check whether you're hitting the roofline
- If the top time is in memory ops or small kernels: memory-bound → consider fusion
- If there are many kernel launches: launch overhead → consider batching or fusion
Step 3: Match property to bottleneck
| Bottleneck | Property | Technique |
|---|---|---|
| Sequential loop | Associativity | Parallel scan |
| Large tensors | Separability | Low-rank factorization |
| Many small ops | Locality | Fusion |
| Dense computation | Sparsity | Pruning, skip connections |
| Redundant computation | Redundancy | Caching, quantization |
| Repeated structure | Symmetry | Weight sharing, FFT |
Step 4: Implement and measure

Always verify speedup with wall-clock time, not just FLOPS or theoretical analysis.
36.16 The Lesson
Three problems. Three domains. The same six-property framework identified the key optimization in each case.
This isn’t coincidence — it’s the book’s thesis in action. Mathematical properties are universal. Associativity enables parallel scan in ML recurrences and streaming aggregation in analytics. Separability enables LoRA in fine-tuning and multipole expansion in physics. The algebra doesn’t know what domain it’s in.
The framework isn’t just theoretical — it’s a practical methodology for deriving optimizations. The properties are your vocabulary; the methodology is your grammar.
Now go optimize something.