36  Capstone: The Framework in Your Hands

Three Problems, Six Properties, One Method

You’ve learned the properties. You’ve studied the case studies.

Now we face three problems you haven’t seen optimized before — spanning ML, systems, and scientific computing. For each, we apply the framework from first principles and see which properties unlock the optimization.

The goal isn’t to memorize these solutions. It’s to internalize the process so you can do this yourself.

36.1 The Method

Every investigation follows the same loop:

  1. Understand the computation
  2. Profile to find bottlenecks
  3. Audit properties — systematically check all six
  4. Derive the optimization from the matching property
  5. Implement and measure
  6. Iterate until satisfied

We’ll apply this to three problems of increasing breadth:

Problem           | Domain               | Primary Properties         | Key Insight
------------------|----------------------|----------------------------|--------------------------------------
Mamba scan        | ML inference         | Associativity + Locality   | Sequential recurrence → parallel scan
Streaming top-K   | Data analytics       | Associativity + Redundancy | Exact ranking from partial heaps
N-body simulation | Scientific computing | Separability + Locality    | Far-field approximation via multipole

36.2 Problem 1: Mamba Scan (ML Inference)

This problem applies everything we’ve learned to optimizing Mamba-style state space models for inference.

36.3 Understanding Mamba

Mamba [1] is a state space model (SSM) that offers an alternative to attention. Its key advantage: linear complexity in sequence length, compared to attention’s quadratic.

36.3.1 The Core Recurrence

At its heart, Mamba computes a recurrence:

import torch

def mamba_scan_naive(x, A, B, C, delta):
    """
    x: (batch, seq_len, d_model) - input sequence
    A: (d_model, d_state) - state transition (discretized)
    B: (batch, seq_len, d_state) - input projection
    C: (batch, seq_len, d_state) - output projection
    delta: (batch, seq_len, d_model) - time step

    Returns: (batch, seq_len, d_model) - output
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]

    # Discretize A using delta
    # A_bar = exp(delta * A)

    h = torch.zeros(batch, d_model, d_state, device=x.device)  # Hidden state
    outputs = []

    for t in range(seq_len):
        # Update hidden state: h = A_bar * h + B * x
        A_bar = torch.exp(delta[:, t, :, None] * A)  # (batch, d_model, d_state)
        h = A_bar * h + B[:, t, None, :] * x[:, t, :, None]

        # Output: y = C * h
        y = (C[:, t, None, :] * h).sum(dim=-1)  # (batch, d_model)
        outputs.append(y)

    return torch.stack(outputs, dim=1)

36.3.2 The Problem

This sequential scan is slow. For a sequence of length N:

  • N sequential steps (no parallelism)
  • Each step has O(d_model × d_state) operations
  • Total: O(N × d_model × d_state)

The sequential nature is the bottleneck. Can we do better?

36.4 Step 1: Profile the Baseline

Before optimizing, measure:

import torch
import time

def benchmark_mamba(seq_len, d_model=2048, d_state=16, batch=1):
    x = torch.randn(batch, seq_len, d_model, device='cuda')
    A = torch.randn(d_model, d_state, device='cuda')
    B = torch.randn(batch, seq_len, d_state, device='cuda')
    C = torch.randn(batch, seq_len, d_state, device='cuda')
    delta = torch.randn(batch, seq_len, d_model, device='cuda').abs()

    # Warmup
    for _ in range(3):
        _ = mamba_scan_naive(x, A, B, C, delta)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(10):
        _ = mamba_scan_naive(x, A, B, C, delta)
    torch.cuda.synchronize()

    ms = (time.perf_counter() - start) / 10 * 1000
    return ms

# Profile across sequence lengths
for seq_len in [256, 512, 1024, 2048, 4096]:
    ms = benchmark_mamba(seq_len)
    tokens_per_sec = seq_len / (ms / 1000)
    print(f"seq_len={seq_len:4d}: {ms:7.2f}ms ({tokens_per_sec:,.0f} tok/s)")

Typical output:

seq_len= 256:   12.34ms (20,746 tok/s)
seq_len= 512:   24.56ms (20,847 tok/s)
seq_len=1024:   49.12ms (20,847 tok/s)
seq_len=2048:   98.23ms (20,847 tok/s)
seq_len=4096:  196.45ms (20,847 tok/s)

Linear scaling with sequence length—as expected for sequential scan. But 20K tokens/sec is slow for inference.

36.4.1 Identifying the Bottleneck

Use a profiler to see where time goes:

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    mamba_scan_naive(x, A, B, C, delta)

print(prof.key_averages().table(sort_by="cuda_time_total"))

The profile reveals:

  • Many small kernel launches (one per timestep)
  • Low GPU utilization
  • Memory-bound operations

Diagnosis: The sequential loop prevents parallelism. Each timestep launches new kernels, creating overhead.

36.5 Step 2: Identify Applicable Properties

Let’s check each property from our framework:

36.5.1 Associativity ✓

The recurrence \(h_t = A h_{t-1} + B x_t\) is associative when viewed as matrix operations.

For linear recurrences, we can use the parallel scan algorithm:

Sequential:  h₁ → h₂ → h₃ → h₄ → h₅ → h₆ → h₇ → h₈
             (8 sequential steps)

Parallel scan:
  Step 1: (h₁,h₂) (h₃,h₄) (h₅,h₆) (h₇,h₈)  [4 parallel]
  Step 2: (h₁..h₄) (h₅..h₈)                  [2 parallel]
  Step 3: (h₁..h₈)                           [1]
             (3 = log₂(8) steps)

This is the same associativity that enables parallel prefix sum!

36.5.2 Locality ✓

The state \(h\) is small (d_model × d_state), but we’re streaming through a long sequence. We can tile the computation to keep working sets in fast memory.

36.5.3 Separability (Partial)

The state dimension d_state is chosen to be small (16-64). This is a design choice exploiting separability—the “effective” dynamics live in a low-dimensional space.

36.5.4 Sparsity ✗

The state transitions are dense. Sparsity doesn’t obviously apply here.

36.5.5 Redundancy (Partial)

During generation, we cache the hidden state h. This is KV-cache style redundancy exploitation.
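This caching can be sketched as a per-token step function; a NumPy sketch with illustrative shapes (mamba_step is a hypothetical helper, and the update simply mirrors the recurrence above):

```python
import numpy as np

def mamba_step(h, x_t, A, B_t, C_t, delta_t):
    """One generation step: update the cached state h, emit one output.

    Illustrative shapes: h (d_model, d_state), x_t (d_model,),
    A (d_model, d_state), B_t and C_t (d_state,), delta_t (d_model,).
    """
    A_bar = np.exp(delta_t[:, None] * A)           # discretize for this step
    h = A_bar * h + B_t[None, :] * x_t[:, None]    # reuse the cached state
    y_t = (C_t[None, :] * h).sum(axis=-1)          # output projection
    return h, y_t

# Autoregressive generation: O(1) work per token, nothing recomputed
d_model, d_state = 4, 3
rng = np.random.default_rng(0)
A = -rng.random((d_model, d_state))                # negative entries for stability
h = np.zeros((d_model, d_state))                   # the cached state
for _ in range(5):
    x_t = rng.standard_normal(d_model)
    B_t = rng.standard_normal(d_state)
    C_t = rng.standard_normal(d_state)
    delta_t = rng.random(d_model)
    h, y_t = mamba_step(h, x_t, A, B_t, C_t, delta_t)
```

Each generated token touches only the O(d_model × d_state) state, never the history.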

36.6 Step 3: Derive the Optimization

The key insight is associativity enables parallel scan.

36.6.1 The Parallel Scan Algorithm

Rewrite the recurrence in matrix form:

\[ \begin{pmatrix} h_t \\ 1 \end{pmatrix} = \begin{pmatrix} A_t & B_t x_t \\ 0 & 1 \end{pmatrix} \begin{pmatrix} h_{t-1} \\ 1 \end{pmatrix} \]

Let \(M_t = \begin{pmatrix} A_t & B_t x_t \\ 0 & 1 \end{pmatrix}\).

Then the full computation is a single matrix product applied to the augmented initial state:

\[\begin{pmatrix} h_N \\ 1 \end{pmatrix} = M_N \cdot M_{N-1} \cdot ... \cdot M_1 \cdot \begin{pmatrix} h_0 \\ 1 \end{pmatrix}\]

Matrix multiplication is associative! We can compute this product in any order.
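We can check this numerically for the elementwise (diagonal) case Mamba uses: represent each step as a pair (a_t, b_t) for the affine map h ↦ a_t·h + b_t, and verify that composing pairs is associative and reproduces the sequential recurrence. A small NumPy sketch (the helper name is illustrative):

```python
import numpy as np

def combine(e1, e2):
    """Compose two elementwise affine steps h -> a*h + b.
    Applying (a1, b1) then (a2, b2) equals (a2*a1, a2*b1 + b2)."""
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

rng = np.random.default_rng(1)
elems = [(rng.random(4), rng.standard_normal(4)) for _ in range(3)]
e1, e2, e3 = elems

# Associativity: (e1 ⊕ e2) ⊕ e3 == e1 ⊕ (e2 ⊕ e3)
left = combine(combine(e1, e2), e3)
right = combine(e1, combine(e2, e3))
assert np.allclose(left[0], right[0]) and np.allclose(left[1], right[1])

# The composed map matches applying the three steps sequentially
h0 = rng.standard_normal(4)
h_seq = h0
for a, b in elems:
    h_seq = a * h_seq + b
a_all, b_all = left
assert np.allclose(a_all * h0 + b_all, h_seq)
```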

36.6.2 Parallel Prefix Product

def parallel_scan_mamba(x, A, B, C, delta):
    """Parallel scan version of Mamba."""
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]

    # Discretize
    A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (batch, seq, d_model, d_state)
    B_x = B.unsqueeze(2) * x.unsqueeze(-1)      # (batch, seq, d_model, d_state)

    # Pack into "elements" for parallel scan
    # Each element: (A_bar_t, B_x_t)
    # Associative op: (A1, B1) ⊕ (A2, B2) = (A2 * A1, A2 * B1 + B2)

    # Use an associative scan primitive (conceptual here; frameworks such as
    # JAX expose one as jax.lax.associative_scan). It returns the
    # prefix-combined (A, h) pairs; we want the h component.
    _, h = associative_scan(
        lambda a, b: (b[0] * a[0], b[0] * a[1] + b[1]),
        (A_bar, B_x)
    )

    # Output projection
    y = (C.unsqueeze(2) * h).sum(dim=-1)

    return y

The associative_scan operation computes all prefix products in O(log N) parallel steps instead of O(N) sequential steps.
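The associative_scan primitive itself is assumed above. A minimal NumPy reference via recursive doubling (Hillis–Steele) shows the idea; on a GPU each while-iteration is one parallel step, so there are O(log N) of them. The name associative_scan_ref and the pair layout are illustrative:

```python
import numpy as np

def associative_scan_ref(combine, elems):
    """Inclusive scan via recursive doubling (Hillis–Steele).
    elems: tuple of arrays whose leading axis is the sequence length.
    Each iteration combines every position with the one `shift` back;
    all updates within an iteration are independent (parallel on hardware)."""
    a, b = (e.copy() for e in elems)
    n, shift = a.shape[0], 1
    while shift < n:
        na, nb = combine((a[:-shift], b[:-shift]), (a[shift:], b[shift:]))
        a[shift:], b[shift:] = na, nb
        shift *= 2
    return a, b

# Same associative operator as in the text: earlier element first
combine = lambda x, y: (y[0] * x[0], y[0] * x[1] + y[1])

rng = np.random.default_rng(2)
A_bar = rng.random((8, 4))
B_x = rng.standard_normal((8, 4))

# With h0 = 0, the scanned "b" component is exactly h_t at every position
_, h = associative_scan_ref(combine, (A_bar, B_x))

# Check against the sequential recurrence h_t = A_bar_t * h_{t-1} + B_x_t
h_ref = np.zeros(4)
for t in range(8):
    h_ref = A_bar[t] * h_ref + B_x[t]
assert np.allclose(h[-1], h_ref)
```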

36.6.3 GPU-Friendly Implementation

For GPUs, we need to:

  1. Tile the scan to fit in shared memory
  2. Fuse the discretization and scan
  3. Avoid materializing large intermediates

# Conceptual structure of optimized kernel
@triton.jit
def mamba_scan_kernel(
    x_ptr, A_ptr, B_ptr, C_ptr, delta_ptr, output_ptr,
    seq_len, d_model, d_state,
    BLOCK_SIZE: tl.constexpr
):
    # Each block handles a chunk of the sequence
    block_id = tl.program_id(0)

    # Load chunk into shared memory
    # ...

    # Parallel scan within block
    # ...

    # Combine with previous blocks (inter-block scan)
    # ...

    # Write output
    # ...

36.7 Step 4: Implementation

The real Mamba implementation uses a custom CUDA kernel. Here’s a simplified version in Triton:

import triton
import triton.language as tl

@triton.jit
def mamba_chunk_scan_kernel(
    # Input pointers
    x_ptr, A_ptr, B_ptr, C_ptr, delta_ptr,
    # Output pointer
    y_ptr,
    # Dimensions
    batch, seq_len, d_model, d_state,
    # Strides
    stride_x_batch, stride_x_seq, stride_x_d,
    # Block size
    BLOCK_SEQ: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Chunked scan for Mamba (schematic; elided loads shown as ...).

    Note: the time loop below is sequential within each (batch, d-block)
    program; parallelism comes from the grid over batch and d_model.
    A full implementation also combines chunks with an inter-block scan.
    """
    # Get program IDs
    pid_batch = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Offset calculations
    d_offset = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_offset < d_model

    # Initialize hidden state
    h = tl.zeros((BLOCK_D, d_state), dtype=tl.float32)

    # Process in chunks
    for chunk_start in range(0, seq_len, BLOCK_SEQ):
        chunk_end = min(chunk_start + BLOCK_SEQ, seq_len)

        # Load chunk (vectorized)
        for t in range(chunk_start, chunk_end):
            # Load inputs for this timestep
            x_offset = pid_batch * stride_x_batch + t * stride_x_seq + d_offset
            x_t = tl.load(x_ptr + x_offset, mask=d_mask)

            # Load delta, B, C
            delta_t = tl.load(delta_ptr + ...)
            B_t = tl.load(B_ptr + ...)
            C_t = tl.load(C_ptr + ...)

            # Discretize A
            A_bar = tl.exp(delta_t[:, None] * tl.load(A_ptr + ...))

            # Update hidden state
            h = A_bar * h + B_t[None, :] * x_t[:, None]

            # Compute output
            y_t = tl.sum(C_t[None, :] * h, axis=1)

            # Store output
            tl.store(y_ptr + ..., y_t, mask=d_mask)

# Wrapper
def mamba_scan_optimized(x, A, B, C, delta):
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]

    y = torch.empty_like(x)

    BLOCK_SEQ = 64
    BLOCK_D = 64

    grid = (batch, triton.cdiv(d_model, BLOCK_D))

    mamba_chunk_scan_kernel[grid](
        x, A, B, C, delta, y,
        batch, seq_len, d_model, d_state,
        x.stride(0), x.stride(1), x.stride(2),
        BLOCK_SEQ, BLOCK_D
    )

    return y

36.8 Step 5: Measure Improvement

def benchmark(fn, iters=10):
    """Wall-clock helper mirroring benchmark_mamba: warmup, then average."""
    for _ in range(3):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000

def compare_implementations():
    seq_len = 2048
    d_model = 2048
    d_state = 16
    batch = 1

    x = torch.randn(batch, seq_len, d_model, device='cuda')
    A = torch.randn(d_model, d_state, device='cuda')
    B = torch.randn(batch, seq_len, d_state, device='cuda')
    C = torch.randn(batch, seq_len, d_state, device='cuda')
    delta = torch.randn(batch, seq_len, d_model, device='cuda').abs()

    # Benchmark naive
    naive_ms = benchmark(lambda: mamba_scan_naive(x, A, B, C, delta))

    # Benchmark optimized
    opt_ms = benchmark(lambda: mamba_scan_optimized(x, A, B, C, delta))

    print(f"Naive:     {naive_ms:.2f}ms")
    print(f"Optimized: {opt_ms:.2f}ms")
    print(f"Speedup:   {naive_ms/opt_ms:.1f}x")

compare_implementations()

Expected results:

Naive:     98.23ms
Optimized:  6.42ms
Speedup:   15.3x

36.9 Step 6: Iterate

The first optimization rarely captures all the performance. Let’s profile again:

with torch.profiler.profile(...) as prof:
    mamba_scan_optimized(x, A, B, C, delta)

New bottlenecks might appear:

  • Memory bandwidth (can we fuse more?)
  • Warp divergence (can we restructure conditionals?)
  • Register pressure (can we reduce the working set?)

36.9.1 Further Optimizations

1. Fuse with preceding/following operations:

# Instead of: y = mamba_scan(x); z = linear(y)
# Fuse: z = mamba_scan_and_linear(x, W)

2. Use tensor cores for the state update: If d_state is large enough (e.g., 64), tensor cores can accelerate the matrix operations.

3. Optimize memory layout: Transpose for coalesced access patterns on the critical path.


36.10 Problem 2: Streaming Top-K (Data Analytics)

Domain: You’re building a real-time analytics system. Billions of events per hour, and you need the top-1000 events by score — continuously updated.

36.10.1 The Naive Approach

def top_k_naive(events, k=1000):
    """Sort everything, take top K."""
    return sorted(events, key=lambda e: e.score, reverse=True)[:k]

For 1 billion events: sorting is O(n log n) ≈ 30 billion comparisons. At 1 GHz, that’s 30 seconds. You need answers in milliseconds.
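For comparison, Python's standard library already ships a bounded-heap selection, heapq.nlargest, which avoids the full sort entirely; a quick check on synthetic scores:

```python
import heapq
import random

random.seed(0)
scores = [random.random() for _ in range(100_000)]

# heapq.nlargest maintains a bounded heap internally:
# O(n log K) time and O(K) memory, versus O(n log n) for a full sort
top_100 = heapq.nlargest(100, scores)

assert top_100 == sorted(scores, reverse=True)[:100]
```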

36.10.2 Property Audit

Property      | Applies? | How?
--------------|----------|--------------------------------------------
Associativity | Yes      | Top-K of chunk-level top-Ks = global top-K
Separability  | No       | No factorization structure
Sparsity      | Partial  | Most elements are irrelevant (not top-K)
Locality      | Partial  | Streaming access pattern
Redundancy    | Yes      | We don't need exact order, just the top K
Symmetry      | No       | No invariance structure

36.10.3 The Key Insight: Associativity of Top-K

If you split the data into chunks and take the top-K of each chunk, then the top-K of the combined chunk-level top-Ks is the global top-K. This works because merging two top-K sets is itself an associative operation:

def merge_top_k(top_k_a, top_k_b, k):
    """Merge two top-K lists of events. Associative!"""
    combined = top_k_a + top_k_b                        # 2K elements
    combined.sort(key=lambda e: e.score, reverse=True)  # sort 2K elements
    return combined[:k]                                 # take top K

# merge(merge(A, B), C) == merge(A, merge(B, C))  ✓ Associative!
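The chunk-level claim is easy to verify on plain numbers; top_k and merge_top_k below are local illustrative helpers, not library functions:

```python
import random
from functools import reduce

def top_k(xs, k):
    return sorted(xs, reverse=True)[:k]

def merge_top_k(a, b, k):
    return sorted(a + b, reverse=True)[:k]

random.seed(1)
data = [random.random() for _ in range(10_000)]
k = 50

# Split into chunks, take each chunk's top-K, then merge the partial results
chunks = [data[i:i + 1000] for i in range(0, len(data), 1000)]
partials = [top_k(c, k) for c in chunks]
merged = reduce(lambda a, b: merge_top_k(a, b, k), partials)

assert merged == top_k(data, k)  # chunk-level merge recovers the global top-K
```

Because the merge is associative, the reduce can be reordered into a parallel tree without changing the result.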

36.10.4 The Optimization

import heapq
import itertools

def streaming_top_k(event_stream, k=1000):
    """Process a billion events with O(K) memory."""
    # Min-heap of size K holding (score, tiebreak, event) triples.
    # The counter breaks score ties so events themselves are never compared.
    heap = []
    tiebreak = itertools.count()

    for event in event_stream:
        if len(heap) < k:
            heapq.heappush(heap, (event.score, next(tiebreak), event))
        elif event.score > heap[0][0]:  # bigger than the smallest retained score?
            heapq.heapreplace(heap, (event.score, next(tiebreak), event))

    # Highest scores first
    return [e for _, _, e in sorted(heap, reverse=True)]

Performance: O(n log K) time, O(K) memory. For n=1B, K=1000: ~10 billion ops vs 30 billion for sorting. But the real win is streaming: we never materialize the full dataset.

Parallelization (from associativity): Each core maintains its own top-K heap. Merge heaps pairwise in O(K log K) — the merge is associative.

Redundancy exploitation: When exact ranking isn’t needed, approximate top-K (e.g., sketch-based heavy-hitter structures such as Count-Min Sketch) trades accuracy for 10-100× speed.

36.10.5 What the Framework Revealed

Without the framework, you might reach for “use a database” or “sort in MapReduce.” The property audit reveals that associativity (of the merge operation) enables streaming and parallelization, while redundancy (most events are irrelevant) means we only need O(K) memory. This is the same associativity pattern as FlashAttention’s online softmax — different domain, same algebra.


36.11 Problem 3: N-Body Simulation (Scientific Computing)

Domain: Simulate gravitational interactions between N = 1,000,000 particles. Each particle exerts force on every other. The naive approach: O(N²) = 10¹² force calculations per timestep.

36.11.1 The Naive Approach

import numpy as np

def nbody_naive(positions, masses):
    """O(N²) all-pairs force calculation."""
    N = len(positions)
    forces = np.zeros_like(positions)

    for i in range(N):
        for j in range(N):
            if i != j:
                r = positions[j] - positions[i]
                dist = np.linalg.norm(r)
                forces[i] += masses[j] * r / (dist**3 + 1e-10)

    return forces

At 1M particles: 10¹² pairwise interactions per step, each costing roughly 20 flops. Even a GPU sustaining 10 TFLOPS needs seconds per timestep, and simulations need thousands of steps. Intractable.
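Before changing the algorithm, the nested Python loops can be vectorized to separate interpreter overhead from algorithmic cost. A sketch with NumPy broadcasting, same O(N²) math; nbody_vectorized is an illustrative name:

```python
import numpy as np

def nbody_vectorized(positions, masses, eps=1e-10):
    """Same O(N^2) physics as the nested loops, via broadcasting."""
    r = positions[None, :, :] - positions[:, None, :]   # (N, N, 3) pairwise vectors
    dist = np.linalg.norm(r, axis=-1)                   # (N, N) pairwise distances
    np.fill_diagonal(dist, np.inf)                      # exclude self-interaction
    inv = masses[None, :] / (dist**3 + eps)             # (N, N) scalar weights
    return (inv[:, :, None] * r).sum(axis=1)            # (N, 3) net forces

rng = np.random.default_rng(3)
pos = rng.standard_normal((50, 3))
m = rng.random(50)

F = nbody_vectorized(pos, m)
```

This removes the Python-loop constant factor; the O(N²) scaling remains, which is what the tree methods below attack.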

36.11.2 Property Audit

Property      | Applies? | How?
--------------|----------|---------------------------------------------------------
Associativity | Yes      | Force is a sum; it can be tree-reduced
Separability  | Yes      | Far-field interactions are low-rank
Sparsity      | Partial  | Near-field forces dominate; far-field is "smooth"
Locality      | Yes      | Nearby particles interact strongly; distant ones weakly
Redundancy    | Yes      | Far-field forces can be approximated
Symmetry      | Yes      | Newton's third law: F(i→j) = -F(j→i)

Five out of six properties apply! This is a property-rich problem.

36.11.3 The Key Insight: Separability + Locality

The gravitational potential from a distant cluster of particles can be approximated by a multipole expansion — a low-rank approximation (separability!) that’s accurate when the cluster is far away (locality!).

Close particles (distance < threshold):
  Compute exact O(1) interaction per pair

Distant cluster of M particles:
  Approximate entire cluster as one "super-particle"
  (multipole expansion: monopole + dipole + quadrupole + ...)
  Cost: O(p) per interaction instead of O(M), where p = expansion order

This is the Barnes-Hut algorithm (using a tree) or the Fast Multipole Method (FMM), which uses a more sophisticated hierarchy:

def nbody_fast(positions, masses, theta=0.5):
    """O(N log N) via Barnes-Hut tree."""
    # build_octree is an assumed helper: it returns nodes exposing
    # .center_of_mass, .total_mass, .size, .is_leaf, and .children
    tree = build_octree(positions, masses)  # O(N log N)

    # 2. For each particle, walk the tree
    forces = np.zeros_like(positions)
    for i in range(len(positions)):
        forces[i] = tree_force(tree, positions[i], theta)

    return forces

def tree_force(node, pos, theta):
    """Compute force from tree node on particle at pos."""
    r = node.center_of_mass - pos
    dist = np.linalg.norm(r)

    if node.is_leaf or (node.size / dist < theta):
        # FAR ENOUGH: use multipole approximation (separability!)
        return node.total_mass * r / (dist**3 + 1e-10)
    else:
        # TOO CLOSE: recurse into children (locality!)
        return sum(tree_force(child, pos, theta) for child in node.children)

Complexity: O(N log N) — from O(N²). For N=1M, that’s 20M operations instead of 10¹² — a 50,000× speedup.
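The core of the approximation can be demonstrated without any tree: replace a distant cluster by a single particle at its center of mass (the monopole term) and compare against the exact sum. A sketch; the cluster geometry and tolerance are illustrative:

```python
import numpy as np

rng = np.random.default_rng(4)

# A tight cluster of 1000 particles, far from the test point
cluster = rng.standard_normal((1000, 3)) * 0.1 + np.array([100.0, 0.0, 0.0])
masses = rng.random(1000)
pos = np.zeros(3)  # test particle at the origin

# Exact: sum over every cluster member, O(M)
r = cluster - pos
d = np.linalg.norm(r, axis=1)
exact = (masses[:, None] * r / d[:, None]**3).sum(axis=0)

# Monopole approximation: one "super-particle" at the center of mass, O(1)
com = (masses[:, None] * cluster).sum(axis=0) / masses.sum()
rc = com - pos
approx = masses.sum() * rc / np.linalg.norm(rc)**3

rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
assert rel_err < 1e-3  # accurate because cluster size / distance is tiny
```

Higher-order terms (dipole, quadrupole) shrink this error further; the tree decides, per node, when the ratio is small enough to approximate.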

36.11.4 What the Framework Revealed

This problem uses four properties simultaneously:

  • Separability (Tier 1): Far-field interactions are low-rank → multipole expansion
  • Locality (Tier 2): Near-field needs exact computation, far-field allows approximation → tree structure
  • Symmetry (Tier 2): Newton’s third law halves the work → F(i→j) = -F(j→i)
  • Associativity (Tier 1): Force summation can be parallelized across tree branches

The FMM is regularly cited as one of the most important algorithms of the 20th century. It wasn’t discovered by trying tricks — it was discovered by recognizing that gravitational interaction has separable structure at long range. The same property that enables LoRA (low-rank weight updates) enables the Fast Multipole Method (low-rank force approximation). The algebra is the same; the domain is different.

Tip: The Framework’s Power: Cross-Domain Transfer

This is what the book’s thesis predicts: if you understand the property, you can transfer the technique across domains. A researcher who understands separability in the context of LoRA has the conceptual tools to understand FMM — and vice versa.


36.12 The Methodology in Review

Let’s trace what we did across all three problems:

Problem         | Bottleneck      | Key Property               | Technique                       | Speedup
----------------|-----------------|----------------------------|---------------------------------|---------
Mamba scan      | Sequential loop | Associativity              | Parallel scan                   | ~15×
Streaming top-K | O(n log n) sort | Associativity + Redundancy | Streaming heap + parallel merge | ~3-10×
N-body          | O(N²) all-pairs | Separability + Locality    | Multipole expansion / tree      | ~50,000×

The framework guided us differently for each problem:

  • Mamba: Associativity was the single key enabler; other properties improved constants
  • Top-K: Associativity enabled parallelism; redundancy enabled streaming (don’t store everything)
  • N-body: The richest case — separability, locality, symmetry, and associativity all contributed, with separability providing the algorithmic breakthrough

36.13 Generalizing the Approach

This process works for any optimization problem:

┌─────────────────────────────────────────────────────────────┐
│                    THE OPTIMIZATION LOOP                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. UNDERSTAND                                              │
│     ├── What does the computation do?                       │
│     ├── What are the data dependencies?                     │
│     └── What are the dimensions/sizes?                      │
│                                                             │
│  2. MEASURE                                                 │
│     ├── Profile end-to-end                                  │
│     ├── Identify the bottleneck                             │
│     └── Is it compute-bound or memory-bound?                │
│                                                             │
│  3. ANALYZE PROPERTIES                                      │
│     ├── Associativity? → Chunking, parallelization          │
│     ├── Separability?  → Factorization, low-rank            │
│     ├── Locality?      → Tiling, fusion                     │
│     ├── Sparsity?      → Skip computation                   │
│     ├── Redundancy?    → Quantization, caching              │
│     └── Symmetry?      → Weight sharing, FFT                │
│                                                             │
│  4. DESIGN                                                  │
│     ├── Which property addresses the bottleneck?            │
│     ├── What transformation exploits it?                    │
│     └── How do we implement it efficiently?                 │
│                                                             │
│  5. IMPLEMENT                                               │
│     ├── Start simple, verify correctness                    │
│     ├── Optimize incrementally                              │
│     └── Use appropriate tools (Triton, CUDA, torch.compile) │
│                                                             │
│  6. MEASURE AGAIN                                           │
│     ├── Did we improve?                                     │
│     ├── What's the new bottleneck?                          │
│     └── Is there more to gain?                              │
│                                                             │
│  7. ITERATE                                                 │
│     └── Return to step 3 with new bottleneck                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

36.14 Key Takeaways

Tip: The Complete Method
  1. Understand before optimizing: Know what the computation does and why.

  2. Measure to find the bottleneck: Don’t guess—profile.

  3. Check all six properties systematically: Which ones apply to your problem?

  4. The property tells you the technique: Associativity → parallel scan. Locality → tiling. The framework guides the solution.

  5. Implement incrementally: Start simple, verify, then optimize.

  6. Iterate: The first optimization reveals the next bottleneck.

  7. Know when to stop: Diminishing returns are real. Ship when good enough.

36.15 Exercises

  1. Property Identification: Attention is O(n²). Which properties would you exploit to make it O(n)? (Hint: sparse attention patterns exploit which property?)

  2. Derivation: Given that addition is associative, derive the parallel prefix sum algorithm. What’s the depth (number of sequential steps) for n elements?

  3. Application: You have a model with many small matrix multiplies in sequence. Which property suggests an optimization? What would you try?

  4. Measurement: Profile a computation in your own codebase. What’s the bottleneck? Which property might help?

Exercise 1: Property Identification

To reduce attention from O(n²) to O(n), you can exploit:

  1. Sparsity - Sparse attention patterns (e.g., local windows, strided patterns, Big Bird):
    • Instead of attending to all n tokens, each token attends to only O(1) or O(log n) others
    • Examples: Longformer (local + global), Sparse Transformer (strided patterns)
  2. Separability (Low-rank approximation):
    • Linformer projects keys/values to lower dimension: O(n × d × k) where k << n
    • Performer uses random features to approximate softmax: O(n × d)
  3. Locality:
    • FlashAttention exploits locality by tiling Q, K, V into blocks that fit in SRAM
    • Doesn’t change asymptotic complexity but dramatically improves constants
  4. Associativity (for state-space alternatives):
    • Mamba replaces attention with a linear recurrence that can be parallelized
    • The recurrence is associative → parallel scan → O(n log n) with high parallelism

The most successful approaches combine multiple properties: FlashAttention uses locality + associativity of softmax normalization; Mamba uses associativity + separability (low-rank state).
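As a concrete sketch of the sparsity route, here is a minimal causal sliding-window attention in NumPy; local_attention and the window size are illustrative, not any library's API:

```python
import numpy as np

def local_attention(Q, K, V, w=4):
    """Each query attends only to the w most recent positions (inclusive).
    Cost is O(n * w * d) instead of O(n^2 * d)."""
    n, d = Q.shape
    out = np.zeros_like(V)
    for i in range(n):
        lo = max(0, i - w + 1)
        scores = Q[i] @ K[lo:i + 1].T / np.sqrt(d)  # at most w scores
        p = np.exp(scores - scores.max())           # stable softmax
        p /= p.sum()
        out[i] = p @ V[lo:i + 1]
    return out

rng = np.random.default_rng(5)
n, d = 32, 8
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
Y = local_attention(Q, K, V, w=4)
```

With w fixed, total work grows linearly in n; real implementations vectorize the window loop and add global tokens for long-range information.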


Exercise 2: Parallel Prefix Sum Derivation

Given n elements \([a_0, a_1, ..., a_{n-1}]\), compute prefix sums \([s_0, s_1, ..., s_{n-1}]\) where \(s_i = \sum_{j=0}^{i} a_j\).

Sequential: \(s_0 = a_0\), \(s_i = s_{i-1} + a_i\) — O(n) sequential steps.

Parallel (Blelloch scan):

Up-sweep (reduce): Build a tree of partial sums

Level 0: a₀  a₁  a₂  a₃  a₄  a₅  a₆  a₇
Level 1: a₀ (a₀+a₁) a₂ (a₂+a₃) a₄ (a₄+a₅) a₆ (a₆+a₇)
Level 2: a₀ (a₀+a₁) a₂ (a₀..a₃) a₄ (a₄+a₅) a₆ (a₄..a₇)
Level 3: a₀ (a₀+a₁) a₂ (a₀..a₃) a₄ (a₄+a₅) a₆ (a₀..a₇)

Down-sweep: Distribute results back down

Replace root with 0, propagate: left child gets parent, right gets parent + old left

Depth: \(2 \log_2(n)\) parallel steps (log n for up-sweep, log n for down-sweep)

Work: O(n) total operations, but only O(log n) sequential steps
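The two sweeps can be written out in plain Python; blelloch_scan is an illustrative array-based sketch assuming n is a power of two, with the inner loops marking the updates that run in parallel on real hardware:

```python
def blelloch_scan(a):
    """Exclusive prefix sum in 2*log2(n) levels (n a power of two).
    Within each level, every update is independent -> parallelizable."""
    x = list(a)
    n = len(x)
    levels = 0

    # Up-sweep: build partial sums at stride d
    d = 1
    while d < n:
        for i in range(d * 2 - 1, n, d * 2):  # independent updates
            x[i] += x[i - d]
        d *= 2
        levels += 1

    # Down-sweep: distribute prefixes back down
    x[n - 1] = 0
    d = n // 2
    while d >= 1:
        for i in range(d * 2 - 1, n, d * 2):  # independent updates
            x[i - d], x[i] = x[i], x[i] + x[i - d]
        d //= 2
        levels += 1

    return x, levels

a = [3, 1, 7, 0, 4, 1, 6, 3]
scan, levels = blelloch_scan(a)
# scan == [0, 3, 4, 11, 11, 15, 16, 22]; levels == 6 == 2 * log2(8)
```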


Exercise 3: Many Small Matrix Multiplies

Property: Associativity (and locality)

Optimizations to try:

  1. Batched GEMM: When the small matmuls are independent of one another, combine them into a single batched operation

    # Instead of a Python loop over independent small matmuls:
    for i in range(n):
        y[i] = x[i] @ W[i]
    
    # Use one batched call:
    y = torch.bmm(x, W)  # x: (n, m, k), W: (n, k, p)
  2. Fusion: Fuse the sequence of matmuls with preceding/following operations

    # torch.compile can fuse:
    import torch.nn.functional as F

    @torch.compile
    def fused_mlp(x, W1, W2, W3):
        return F.gelu(x @ W1) @ W2 @ W3
  3. Pre-multiply (if shapes allow): Use associativity to reduce runtime work

    # If W1, W2 are constant:
    W_combined = W1 @ W2  # Precompute once
    y = x @ W_combined    # Apply combined
  4. Tensor cores: Ensure dimensions are multiples of 8/16 for efficient tensor core usage

The key insight: many small operations create kernel launch overhead and poor GPU utilization. Batching exploits parallelism; fusion exploits locality.


Exercise 4: Profiling Your Codebase

This is an open-ended exercise. Here’s a framework for approaching it:

Step 1: Profile end-to-end

# PyTorch profiler
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=True
) as prof:
    your_function()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

Step 2: Identify bottleneck type

  • If top operations are matmul/conv: compute-bound → check whether you're hitting the roofline
  • If top time is in memory ops or small kernels: memory-bound → consider fusion
  • If there are many kernel launches: launch overhead → consider batching or fusion

Step 3: Match property to bottleneck

Bottleneck            | Property      | Technique
----------------------|---------------|------------------------
Sequential loop       | Associativity | Parallel scan
Large tensors         | Separability  | Low-rank factorization
Many small ops        | Locality      | Fusion
Dense computation     | Sparsity      | Pruning, skipping zeros
Redundant computation | Redundancy    | Caching, quantization
Repeated structure    | Symmetry      | Weight sharing, FFT

Step 4: Implement and measure Always verify speedup with wall-clock time, not just FLOPS or theoretical analysis.


36.16 The Lesson

Three problems. Three domains. The same six-property framework identified the key optimization in each case.

This isn’t coincidence — it’s the book’s thesis in action. Mathematical properties are universal. Associativity enables parallel scan in ML recurrences and streaming aggregation in analytics. Separability enables LoRA in fine-tuning and multipole expansion in physics. The algebra doesn’t know what domain it’s in.

The framework isn’t just theoretical — it’s a practical methodology for deriving optimizations. The properties are your vocabulary; the methodology is your grammar.

Now go optimize something.

[1]
A. Gu and T. Dao, “Mamba: Linear-time sequence modeling with selective state spaces,” arXiv preprint arXiv:2312.00752, 2023.