35  Capstone: Optimizing a Novel Problem

Applying the Framework from Scratch

You’ve learned the properties. You’ve studied the case studies.

Now we face a problem you haven’t seen before—and derive the optimization from first principles.

35.1 The Challenge

This chapter applies everything we’ve learned to a novel problem: optimizing Mamba-style state space models for inference.

We’ll walk through the complete process:

  1. Understand the computation
  2. Profile to find bottlenecks
  3. Identify applicable properties
  4. Derive optimizations
  5. Implement and measure
  6. Iterate

This mirrors how you’ll approach new optimization problems in practice.

35.2 Understanding Mamba

Mamba is a state space model (SSM) that offers an alternative to attention. Its key advantage: linear complexity in sequence length, compared to attention’s quadratic.

35.2.1 The Core Recurrence

At its heart, Mamba computes a recurrence:

import torch

def mamba_scan_naive(x, A, B, C, delta):
    """
    x: (batch, seq_len, d_model) - input sequence
    A: (d_model, d_state) - state transition (discretized)
    B: (batch, seq_len, d_state) - input projection
    C: (batch, seq_len, d_state) - output projection
    delta: (batch, seq_len, d_model) - time step

    Returns: (batch, seq_len, d_model) - output
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]

    # A is discretized per timestep inside the loop: A_bar = exp(delta_t * A)

    h = torch.zeros(batch, d_model, d_state, device=x.device, dtype=x.dtype)  # Hidden state
    outputs = []

    for t in range(seq_len):
        # Update hidden state: h = A_bar * h + B * x
        A_bar = torch.exp(delta[:, t, :, None] * A)  # (batch, d_model, d_state)
        h = A_bar * h + B[:, t, None, :] * x[:, t, :, None]

        # Output: y = C * h
        y = (C[:, t, None, :] * h).sum(dim=-1)  # (batch, d_model)
        outputs.append(y)

    return torch.stack(outputs, dim=1)

35.2.2 The Problem

This sequential scan is slow. For a sequence of length N:

- N sequential steps (no parallelism)
- Each step has O(d_model × d_state) operations
- Total: O(N × d_model × d_state)

The sequential nature is the bottleneck. Can we do better?
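
A quick back-of-the-envelope helps frame the problem. The sketch below is illustrative arithmetic only (using the sizes benchmarked in the next section); it shows that the scan's raw arithmetic is tiny, which already hints that serialization and per-step overhead, not FLOPs, will dominate.

seq_len, d_model, d_state = 2048, 2048, 16
macs_per_step = d_model * d_state          # one state update per timestep
total_macs = seq_len * macs_per_step
print(f"{total_macs / 1e6:.0f}M multiply-adds per scan")   # ~67M: negligible for a modern GPU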

35.3 Step 1: Profile the Baseline

Before optimizing, measure:

import torch
import time

def benchmark_mamba(seq_len, d_model=2048, d_state=16, batch=1):
    x = torch.randn(batch, seq_len, d_model, device='cuda')
    A = torch.randn(d_model, d_state, device='cuda')
    B = torch.randn(batch, seq_len, d_state, device='cuda')
    C = torch.randn(batch, seq_len, d_state, device='cuda')
    delta = torch.randn(batch, seq_len, d_model, device='cuda').abs()

    # Warmup
    for _ in range(3):
        _ = mamba_scan_naive(x, A, B, C, delta)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(10):
        _ = mamba_scan_naive(x, A, B, C, delta)
    torch.cuda.synchronize()

    ms = (time.perf_counter() - start) / 10 * 1000
    return ms

# Profile across sequence lengths
for seq_len in [256, 512, 1024, 2048, 4096]:
    ms = benchmark_mamba(seq_len)
    tokens_per_sec = seq_len / (ms / 1000)
    print(f"seq_len={seq_len:4d}: {ms:7.2f}ms ({tokens_per_sec:,.0f} tok/s)")

Typical output:

seq_len= 256:   12.34ms (20,746 tok/s)
seq_len= 512:   24.56ms (20,847 tok/s)
seq_len=1024:   49.12ms (20,847 tok/s)
seq_len=2048:   98.23ms (20,847 tok/s)
seq_len=4096:  196.45ms (20,847 tok/s)

Linear scaling with sequence length—as expected for sequential scan. But 20K tokens/sec is slow for inference.

35.3.1 Identifying the Bottleneck

Use a profiler (on the same inputs created in the benchmark) to see where the time goes:

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    mamba_scan_naive(x, A, B, C, delta)

print(prof.key_averages().table(sort_by="cuda_time_total"))

The profile reveals:

- Many small kernel launches (one per timestep)
- Low GPU utilization
- Memory-bound operations

Diagnosis: The sequential loop prevents parallelism. Each timestep launches new kernels, creating overhead.
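
To get a feel for how expensive per-timestep kernel launches are, here is a small micro-benchmark, separate from the Mamba code and purely illustrative: it issues the same arithmetic once as thousands of tiny kernels and once as a single large kernel.

import time
import torch

a = torch.randn(4096, 2048, device='cuda')

def timed(fn, warmup=3, iters=10):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000

many_small = lambda: [torch.exp(a[t]) for t in range(a.shape[0])]  # 4096 tiny launches
one_big = lambda: torch.exp(a)                                     # 1 launch, same arithmetic

print(f"4096 small kernels: {timed(many_small):.2f}ms")
print(f"1 large kernel:     {timed(one_big):.2f}ms")

The gap between the two numbers is almost entirely launch and dispatch overhead, the same tax the naive scan pays once per timestep.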

35.4 Step 2: Identify Applicable Properties

Let’s check each property from our framework:

35.4.1 Associativity ✓

The recurrence \(h_t = \bar{A}_t h_{t-1} + B_t x_t\) is associative when viewed as a composition of per-step affine (matrix) updates.

For linear recurrences, we can use the parallel scan algorithm:

Sequential:  h₁ → h₂ → h₃ → h₄ → h₅ → h₆ → h₇ → h₈
             (8 sequential steps)

Parallel scan:
  Step 1: (h₁,h₂) (h₃,h₄) (h₅,h₆) (h₇,h₈)  [4 parallel]
  Step 2: (h₁..h₄) (h₅..h₈)                  [2 parallel]
  Step 3: (h₁..h₈)                           [1]
             (3 = log₂(8) steps)

This is the same associativity that enables parallel prefix sum!
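
To make that concrete, here is a minimal Hillis-Steele prefix sum in PyTorch (an illustrative sketch; the function name is ours). Each pass adds a shifted copy of the array to itself, so n elements need only about log2(n) dependent steps.

import torch

def inclusive_prefix_sum(x):
    """Hillis-Steele inclusive scan: O(log n) dependent steps, O(n log n) total work."""
    x = x.clone()
    n = x.shape[0]
    shift = 1
    while shift < n:
        # Each position absorbs the partial sum ending `shift` elements earlier
        x[shift:] = x[shift:] + x[:-shift]
        shift *= 2
    return x

x = torch.arange(1, 9, dtype=torch.float32)
print(inclusive_prefix_sum(x))   # tensor([ 1.,  3.,  6., 10., 15., 21., 28., 36.])
print(torch.cumsum(x, dim=0))    # same result, computed by the library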

35.4.2 Locality ✓

The state \(h\) is small (d_model × d_state), but we’re streaming through a long sequence. We can tile the computation to keep working sets in fast memory.
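
As a plain-PyTorch sketch of what tiling means here (assumed structure, before any kernel work): process the sequence in chunks, stage each chunk's inputs once, and carry only the small state h between chunks.

import torch

def mamba_scan_chunked(x, A, B, C, delta, chunk=64):
    """Same math as mamba_scan_naive, restructured into chunks of timesteps."""
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]
    h = torch.zeros(batch, d_model, d_state, device=x.device, dtype=x.dtype)
    outputs = []
    for start in range(0, seq_len, chunk):
        end = min(start + chunk, seq_len)
        # Stage the chunk once; in a kernel this is what lands in shared memory.
        x_c, B_c, C_c, d_c = (t[:, start:end] for t in (x, B, C, delta))
        for t in range(end - start):
            A_bar = torch.exp(d_c[:, t, :, None] * A)
            h = A_bar * h + B_c[:, t, None, :] * x_c[:, t, :, None]
            outputs.append((C_c[:, t, None, :] * h).sum(dim=-1))
    return torch.stack(outputs, dim=1)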

35.4.3 Separability (Partial)

The state dimension d_state is chosen to be small (16-64). This is a design choice exploiting separability—the “effective” dynamics live in a low-dimensional space.

35.4.4 Sparsity ✗

The state transitions are dense. Sparsity doesn’t obviously apply here.

35.4.5 Redundancy (Partial)

During generation, we cache the hidden state h. This is KV-cache style redundancy exploitation.
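
A sketch of that decode-time step (our own helper for illustration, not the library's API): the cached state h is the only thing carried forward, so each new token costs O(d_model × d_state).

import torch

def mamba_decode_step(h, x_t, A, B_t, C_t, delta_t):
    """One generation step.
    h: (batch, d_model, d_state) cached state; x_t, delta_t: (batch, d_model);
    B_t, C_t: (batch, d_state). Returns (y_t, updated h)."""
    A_bar = torch.exp(delta_t[:, :, None] * A)             # (batch, d_model, d_state)
    h = A_bar * h + B_t[:, None, :] * x_t[:, :, None]      # state update
    y_t = (C_t[:, None, :] * h).sum(dim=-1)                # (batch, d_model)
    return y_t, h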

35.5 Step 3: Derive the Optimization

The key insight is associativity enables parallel scan.

35.5.1 The Parallel Scan Algorithm

Rewrite the recurrence in matrix form:

\[ \begin{pmatrix} h_t \\ 1 \end{pmatrix} = \begin{pmatrix} A_t & B_t x_t \\ 0 & 1 \end{pmatrix} \begin{pmatrix} h_{t-1} \\ 1 \end{pmatrix} \]

Let \(M_t = \begin{pmatrix} A_t & B_t x_t \\ 0 & 1 \end{pmatrix}\).

Then the full computation is:

\[\begin{pmatrix} h_N \\ 1 \end{pmatrix} = M_N \cdot M_{N-1} \cdot \ldots \cdot M_1 \cdot \begin{pmatrix} h_0 \\ 1 \end{pmatrix}\]

Matrix multiplication is associative! We can compute this product in any order.
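
A quick numerical sanity check of this reformulation (scalar state for readability, d_model = d_state = 1; illustrative only): multiplying the M_t matrices together and applying the product to (h_0, 1) reproduces the sequential recurrence.

import torch

torch.manual_seed(0)
N = 8
A_t = torch.rand(N) * 0.9     # per-step decay A_bar_t
b_t = torch.randn(N)          # per-step input term B_t * x_t

# Sequential recurrence
h = torch.tensor(0.0)
for t in range(N):
    h = A_t[t] * h + b_t[t]

# Matrix-product form: later steps multiply on the left
M = torch.eye(2)
for t in range(N):
    M_t = torch.tensor([[A_t[t].item(), b_t[t].item()], [0.0, 1.0]])
    M = M_t @ M
h_matrix = (M @ torch.tensor([0.0, 1.0]))[0]

print(torch.allclose(h, h_matrix))   # True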

35.5.2 Parallel Prefix Product

def parallel_scan_mamba(x, A, B, C, delta):
    """Parallel scan version of Mamba."""
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]

    # Discretize
    A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (batch, seq, d_model, d_state)
    B_x = B.unsqueeze(2) * x.unsqueeze(-1)      # (batch, seq, d_model, d_state)

    # Pack into "elements" for parallel scan
    # Each element: (A_bar_t, B_x_t)
    # Associative op: (A1, B1) ⊕ (A2, B2) = (A2 * A1, A2 * B1 + B2)

    # Use an associative scan primitive over dim=1 (the sequence dimension).
    # Assumed as a placeholder here; a reference implementation follows
    # after this listing.
    _, h = associative_scan(
        lambda a, b: (b[0] * a[0], b[0] * a[1] + b[1]),
        (A_bar, B_x)
    )
    # With h_0 = 0, the accumulated input term is exactly the hidden state h_t.

    # Output projection
    y = (C.unsqueeze(2) * h).sum(dim=-1)

    return y

The associative_scan operation computes all prefix products in O(log N) parallel steps instead of O(N) sequential steps.
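
The real Mamba implementation does this scan inside a fused CUDA kernel. As a concrete stand-in for the associative_scan placeholder above (our own reference sketch, not the production kernel), a Hillis-Steele-style doubling scan over the sequence dimension works with the same combine rule: O(N log N) total work, but only O(log N) dependent steps, each a large batched element-wise operation.

import torch

def associative_scan(combine, elems):
    """Inclusive prefix scan along dim=1.
    elems = (A_bar, B_x), each of shape (batch, seq, d_model, d_state)."""
    A, b = (t.clone() for t in elems)
    seq_len = A.shape[1]
    shift = 1
    while shift < seq_len:
        # Combine each position t with the prefix ending at t - shift
        new_A, new_b = combine(
            (A[:, :-shift], b[:, :-shift]),   # earlier prefix
            (A[:, shift:], b[:, shift:]),     # later prefix
        )
        A = torch.cat([A[:, :shift], new_A], dim=1)
        b = torch.cat([b[:, :shift], new_b], dim=1)
        shift *= 2
    return A, b

# Usage / correctness check against the naive scan (illustrative):
#   y_ref = mamba_scan_naive(x, A, B, C, delta)
#   y_par = parallel_scan_mamba(x, A, B, C, delta)
#   torch.testing.assert_close(y_par, y_ref, atol=1e-4, rtol=1e-4)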

35.5.3 GPU-Friendly Implementation

For GPUs, we need to:

  1. Tile the scan to fit in shared memory
  2. Fuse the discretization and scan
  3. Avoid materializing large intermediates

# Conceptual structure of optimized kernel
@triton.jit
def mamba_scan_kernel(
    x_ptr, A_ptr, B_ptr, C_ptr, delta_ptr, output_ptr,
    seq_len, d_model, d_state,
    BLOCK_SIZE: tl.constexpr
):
    # Each block handles a chunk of the sequence
    block_id = tl.program_id(0)

    # Load chunk into shared memory
    # ...

    # Parallel scan within block
    # ...

    # Combine with previous blocks (inter-block scan)
    # ...

    # Write output
    # ...

35.6 Step 4: Implementation

The real Mamba implementation uses a custom CUDA kernel. Here’s a simplified version in Triton:

import triton
import triton.language as tl

@triton.jit
def mamba_chunk_scan_kernel(
    # Input pointers
    x_ptr, A_ptr, B_ptr, C_ptr, delta_ptr,
    # Output pointer
    y_ptr,
    # Dimensions
    batch, seq_len, d_model, d_state,
    # Strides
    stride_x_batch, stride_x_seq, stride_x_d,
    # Block size
    BLOCK_SEQ: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Chunked parallel scan for Mamba."""
    # Get program IDs
    pid_batch = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Offset calculations
    d_offset = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_offset < d_model

    # Initialize hidden state
    h = tl.zeros((BLOCK_D, d_state), dtype=tl.float32)

    # Process in chunks
    for chunk_start in range(0, seq_len, BLOCK_SEQ):
        chunk_end = min(chunk_start + BLOCK_SEQ, seq_len)

        # Walk the chunk one timestep at a time; the state h carries across steps
        for t in range(chunk_start, chunk_end):
            # Load inputs for this timestep
            x_offset = pid_batch * stride_x_batch + t * stride_x_seq + d_offset
            x_t = tl.load(x_ptr + x_offset, mask=d_mask)

            # Load delta, B, C
            delta_t = tl.load(delta_ptr + ...)
            B_t = tl.load(B_ptr + ...)
            C_t = tl.load(C_ptr + ...)

            # Discretize A
            A_bar = tl.exp(delta_t[:, None] * tl.load(A_ptr + ...))

            # Update hidden state
            h = A_bar * h + B_t[None, :] * x_t[:, None]

            # Compute output
            y_t = tl.sum(C_t[None, :] * h, axis=1)

            # Store output
            tl.store(y_ptr + ..., y_t, mask=d_mask)

# Wrapper
def mamba_scan_optimized(x, A, B, C, delta):
    batch, seq_len, d_model = x.shape
    d_state = A.shape[1]

    y = torch.empty_like(x)

    BLOCK_SEQ = 64
    BLOCK_D = 64

    grid = (batch, triton.cdiv(d_model, BLOCK_D))

    mamba_chunk_scan_kernel[grid](
        x, A, B, C, delta, y,
        batch, seq_len, d_model, d_state,
        x.stride(0), x.stride(1), x.stride(2),
        BLOCK_SEQ, BLOCK_D
    )

    return y
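
Before timing anything, verify the kernel against the naive reference. The check below assumes the elided pointer arithmetic (the ... placeholders) has been filled in; otherwise it only shows the shape of the test.

x = torch.randn(1, 512, 2048, device='cuda')
A = torch.randn(2048, 16, device='cuda')
B = torch.randn(1, 512, 16, device='cuda')
C = torch.randn(1, 512, 16, device='cuda')
delta = torch.randn(1, 512, 2048, device='cuda').abs()

y_ref = mamba_scan_naive(x, A, B, C, delta)
y_opt = mamba_scan_optimized(x, A, B, C, delta)
print(torch.allclose(y_ref, y_opt, atol=1e-3, rtol=1e-3))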

35.7 Step 5: Measure Improvement
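
The comparison below uses a generic benchmark helper that the chapter has not defined; a minimal version, following the same warmup-and-synchronize pattern as benchmark_mamba, could look like this.

import time
import torch

def benchmark(fn, warmup=3, iters=10):
    """Average wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000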

def compare_implementations():
    seq_len = 2048
    d_model = 2048
    d_state = 16
    batch = 1

    x = torch.randn(batch, seq_len, d_model, device='cuda')
    A = torch.randn(d_model, d_state, device='cuda')
    B = torch.randn(batch, seq_len, d_state, device='cuda')
    C = torch.randn(batch, seq_len, d_state, device='cuda')
    delta = torch.randn(batch, seq_len, d_model, device='cuda').abs()

    # Benchmark naive
    naive_ms = benchmark(lambda: mamba_scan_naive(x, A, B, C, delta))

    # Benchmark optimized
    opt_ms = benchmark(lambda: mamba_scan_optimized(x, A, B, C, delta))

    print(f"Naive:     {naive_ms:.2f}ms")
    print(f"Optimized: {opt_ms:.2f}ms")
    print(f"Speedup:   {naive_ms/opt_ms:.1f}x")

compare_implementations()

Expected results:

Naive:     98.23ms
Optimized:  6.42ms
Speedup:   15.3x

35.8 Step 6: Iterate

The first optimization rarely captures all the performance. Let’s profile again:

with torch.profiler.profile(...) as prof:
    mamba_scan_optimized(x, A, B, C, delta)

New bottlenecks might appear:

- Memory bandwidth (can we fuse more?)
- Warp divergence (can we restructure conditionals?)
- Register pressure (can we reduce the working set?)
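
One way to decide whether memory bandwidth is the next wall is to compare the measured time against a rough data-movement floor. The arithmetic below is illustrative; the 2 TB/s figure is an assumed HBM bandwidth, not a measurement.

# Minimum bytes the scan must touch if x, delta, B, C are read once and y written once (fp32)
batch, seq_len, d_model, d_state = 1, 2048, 2048, 16
bytes_per_el = 4
reads = batch * seq_len * (2 * d_model + 2 * d_state) * bytes_per_el   # x, delta, B, C
writes = batch * seq_len * d_model * bytes_per_el                      # y
bandwidth = 2e12                                                        # bytes/s (assumed)
floor_ms = (reads + writes) / bandwidth * 1000
print(f"memory floor ~= {floor_ms:.3f} ms")   # ~0.025 ms at these sizes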

35.8.1 Further Optimizations

1. Fuse with preceding/following operations:

# Instead of: y = mamba_scan(x); z = linear(y)
# Fuse: z = mamba_scan_and_linear(x, W)

2. Use tensor cores for the state update: If d_state is large enough (e.g., 64), tensor cores can accelerate the matrix operations.

3. Optimize memory layout: Transpose for coalesced access patterns on the critical path.

35.9 The Methodology in Review

Let’s trace what we did:

Step  Action                                           Properties Used
1     Profile baseline                                 -
2     Identify bottleneck: sequential loop             -
3     Check associativity: recurrence is associative   Associativity
4     Apply parallel scan                              Associativity
5     Tile for memory                                  Locality
6     Fuse operations                                  Locality
7     Measure improvement                              -
8     Iterate                                          -

The framework guided us:

- We checked each property systematically
- Associativity was the key enabler
- Locality improved constant factors
- The other properties didn’t apply to this specific problem

35.10 Generalizing the Approach

This process works for any optimization problem:

┌─────────────────────────────────────────────────────────────┐
│                    THE OPTIMIZATION LOOP                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. UNDERSTAND                                              │
│     ├── What does the computation do?                       │
│     ├── What are the data dependencies?                     │
│     └── What are the dimensions/sizes?                      │
│                                                             │
│  2. MEASURE                                                 │
│     ├── Profile end-to-end                                  │
│     ├── Identify the bottleneck                             │
│     └── Is it compute-bound or memory-bound?                │
│                                                             │
│  3. ANALYZE PROPERTIES                                      │
│     ├── Associativity? → Chunking, parallelization          │
│     ├── Separability?  → Factorization, low-rank            │
│     ├── Locality?      → Tiling, fusion                     │
│     ├── Sparsity?      → Skip computation                   │
│     ├── Redundancy?    → Quantization, caching              │
│     └── Symmetry?      → Weight sharing, FFT                │
│                                                             │
│  4. DESIGN                                                  │
│     ├── Which property addresses the bottleneck?            │
│     ├── What transformation exploits it?                    │
│     └── How do we implement it efficiently?                 │
│                                                             │
│  5. IMPLEMENT                                               │
│     ├── Start simple, verify correctness                    │
│     ├── Optimize incrementally                              │
│     └── Use appropriate tools (Triton, CUDA, torch.compile) │
│                                                             │
│  6. MEASURE AGAIN                                           │
│     ├── Did we improve?                                     │
│     ├── What's the new bottleneck?                          │
│     └── Is there more to gain?                              │
│                                                             │
│  7. ITERATE                                                 │
│     └── Return to step 3 with new bottleneck                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

35.11 Key Takeaways

Tip: The Complete Method
  1. Understand before optimizing: Know what the computation does and why.

  2. Measure to find the bottleneck: Don’t guess—profile.

  3. Check all six properties systematically: Which ones apply to your problem?

  4. The property tells you the technique: Associativity → parallel scan. Locality → tiling. The framework guides the solution.

  5. Implement incrementally: Start simple, verify, then optimize.

  6. Iterate: The first optimization reveals the next bottleneck.

  7. Know when to stop: Diminishing returns are real. Ship when good enough.

35.12 Exercises

  1. Property Identification: Attention is O(n²). Which properties would you exploit to make it O(n)? (Hint: sparse attention patterns exploit which property?)

  2. Derivation: Given that addition is associative, derive the parallel prefix sum algorithm. What’s the depth (number of sequential steps) for n elements?

  3. Application: You have a model with many small matrix multiplies in sequence. Which property suggests an optimization? What would you try?

  4. Measurement: Profile a computation in your own codebase. What’s the bottleneck? Which property might help?


This capstone demonstrates that the framework isn’t just theoretical—it’s a practical methodology for deriving optimizations. The properties are your vocabulary; the methodology is your grammar.

Now go optimize something.