32  Fusion

When Combining Operations Beats Separating Them

Every memory round-trip has a cost. If you write to memory just to read it back, you’ve paid twice for nothing.

Fusion eliminates these round-trips by combining operations into a single pass.

Fusion is primarily a compiler/runtime optimization; later tooling chapters show how modern stacks implement it.

32.1 The Hidden Cost of Abstraction

Modern deep learning frameworks provide beautiful abstractions:

def layer(x):
    x = linear(x)
    x = relu(x)
    x = dropout(x)
    return x

Clean, readable, modular. Each operation is a separate function.

But at runtime, this beautiful code has an ugly secret:

Kernel 1 (linear):  Read x from HBM → Compute Wx+b → Write result to HBM
Kernel 2 (relu):    Read from HBM → max(0, x) → Write to HBM
Kernel 3 (dropout): Read from HBM → mask → Write to HBM

Three round-trips to HBM. The cost is bytes ÷ bandwidth: moving a 1 GB tensor at 3 TB/s is ~0.33 ms per pass. These add up.
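The arithmetic above can be sketched in a few lines. This is a rough model, not a measurement: it assumes every kernel makes one read pass and one write pass over a tensor of the same size and ignores the weight read in the linear layer. The helper name `pass_time_ms` is illustrative.

```python
def pass_time_ms(num_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Time to stream num_bytes through memory once, in milliseconds."""
    return num_bytes / bandwidth_bytes_per_s * 1e3

TENSOR_BYTES = 1e9    # 1 GB tensor, as in the text
HBM_BANDWIDTH = 3e12  # 3 TB/s, as in the text

one_pass = pass_time_ms(TENSOR_BYTES, HBM_BANDWIDTH)

# Simplifying assumption: each kernel does one read pass and one write pass.
unfused_passes = 3 * 2  # linear, relu, dropout
fused_passes = 1 * 2    # a single fused kernel

print(f"one pass: {one_pass:.2f} ms")
print(f"unfused:  {unfused_passes * one_pass:.2f} ms of memory traffic")
print(f"fused:    {fused_passes * one_pass:.2f} ms of memory traffic")
```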

For relu and dropout the computation is trivial: one compare and one mask per element. The memory traffic is what costs time.

32.2 What Is Fusion?

Fusion combines multiple operations into a single kernel:

Fused kernel: Read x from HBM → Wx+b → relu → dropout → Write to HBM

One read, one write. The intermediates stay in registers or shared memory, never touching HBM.

The benefit scales inversely with arithmetic intensity (FLOPs performed per byte moved). Operations with low arithmetic intensity (element-wise, small reductions) benefit most because memory traffic dominates their runtime.
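To make the intensity argument concrete, here is a small sketch for a chain of unary element-wise ops. It assumes fp32 data and one FLOP per op per element; the function name is hypothetical and the model ignores caches.

```python
BYTES_PER_ELEMENT = 4  # fp32 (assumption)

def elementwise_intensity(num_ops: int, fused: bool) -> float:
    """FLOPs per byte of HBM traffic for a chain of unary element-wise ops."""
    flops = num_ops  # per element, assuming one FLOP per op
    if fused:
        # One read of the input and one write of the final output.
        bytes_moved = 2 * BYTES_PER_ELEMENT
    else:
        # Every op reads its input from HBM and writes its output back.
        bytes_moved = num_ops * 2 * BYTES_PER_ELEMENT
    return flops / bytes_moved

for k in (1, 3, 8):
    print(k, elementwise_intensity(k, fused=False), elementwise_intensity(k, fused=True))
```

Under these assumptions the unfused chain stays at the same intensity no matter how long it gets, while the fused chain's intensity grows with the number of ops.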

32.3 When Fusion Helps

Not all operations benefit equally from fusion.

32.3.1 High Benefit: Element-wise Chains

y = torch.sigmoid(x) * torch.tanh(x + 1)

Without fusion: 4 kernels (add, tanh, sigmoid, multiply), which together make 5 tensor reads and 4 tensor writes. With fusion: 1 read, 1 write.

Speedup: often 3-4× or more.
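One way to count the traffic for this expression is to list the kernels an eager run would launch and tally tensor reads and writes. The kernel list and temporary names (`t1`, `t2`, `t3`) below are an illustrative model, not output from a profiler.

```python
# Each entry: (op name, input tensors, output tensor)
KERNELS = [
    ("add",     ["x"],        "t1"),  # x + 1
    ("tanh",    ["t1"],       "t2"),
    ("sigmoid", ["x"],        "t3"),
    ("mul",     ["t3", "t2"], "y"),
]

def unfused_traffic(kernels):
    """Every kernel reads all its inputs from HBM and writes its output."""
    reads = sum(len(inputs) for _, inputs, _ in kernels)
    writes = len(kernels)
    return reads, writes

def fused_traffic(kernels):
    """A single fused kernel touches only external inputs and final outputs."""
    produced = {out for _, _, out in kernels}
    consumed = {name for _, inputs, _ in kernels for name in inputs}
    external_inputs = consumed - produced  # tensors no kernel produced
    final_outputs = produced - consumed    # tensors no kernel consumed
    return len(external_inputs), len(final_outputs)

print("unfused (reads, writes):", unfused_traffic(KERNELS))
print("fused   (reads, writes):", fused_traffic(KERNELS))
```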

32.3.2 High Benefit: Small Reductions Following Compute

loss = F.cross_entropy(logits, labels)

Cross-entropy computes a log-softmax over the logits, picks out the log-probability of the target class, and reduces over the batch. Fusing these steps avoids materializing the softmax probabilities in HBM.
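The idea can be illustrated for a single row of logits in plain Python. This is a sketch of the math only: the naive version builds the full probability vector, while the fused version keeps just a running sum. Both function names are hypothetical and neither is the PyTorch implementation.

```python
import math

def cross_entropy_naive(logits, label):
    """Materializes the exponentials and the softmax probabilities."""
    exps = [math.exp(z) for z in logits]   # intermediate vector 1
    total = sum(exps)
    probs = [e / total for e in exps]      # intermediate vector 2 (softmax)
    return -math.log(probs[label])

def cross_entropy_fused(logits, label):
    """Single loop over the logits; only scalars are kept between steps."""
    m = max(logits)                        # subtract the max for stability
    sum_exp = 0.0
    for z in logits:
        sum_exp += math.exp(z - m)
    # -log softmax(logits)[label] = logsumexp(logits) - logits[label]
    return m + math.log(sum_exp) - logits[label]

logits = [2.0, 1.0, 0.1]
print(cross_entropy_naive(logits, 0), cross_entropy_fused(logits, 0))
```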

32.3.3 Moderate Benefit: LayerNorm and Softmax

x = F.layer_norm(x, normalized_shape)

LayerNorm computes mean, variance, normalization, and affine transform. Fusing keeps statistics in registers.
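As a rough picture of what a fused LayerNorm does per row, the sketch below keeps the mean and inverse standard deviation as local scalars (the analogue of registers) and applies the affine transform in the same function. It is plain Python for illustration; the function name and `eps` default are assumptions, and a real GPU kernel would parallelize the reductions.

```python
import math

def layer_norm_row(row, gamma, beta, eps=1e-5):
    """Normalize one row and apply the affine transform in a single function."""
    n = len(row)
    mean = sum(row) / n                           # statistic 1, kept as a scalar
    var = sum((v - mean) ** 2 for v in row) / n   # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)          # statistic 2, kept as a scalar
    # Normalization and affine transform fused into one pass over the row.
    return [(v - mean) * inv_std * g + b for v, g, b in zip(row, gamma, beta)]

out = layer_norm_row([1.0, 2.0, 3.0, 4.0], gamma=[1.0] * 4, beta=[0.0] * 4)
print(out)
```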

32.3.4 Low Benefit: Large Matrix Multiplies

y = x @ W

Matrix multiply already has high arithmetic intensity. It’s compute-bound, not memory-bound. Fusion helps the operations around matmul, not matmul itself.
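A quick estimate shows why. For a square fp32 matmul, intensity grows linearly with the matrix size, so large matmuls sit well above the point where a GPU becomes compute-bound. The peak throughput below is a hypothetical figure chosen for illustration; the 3 TB/s bandwidth is the figure used earlier in this chapter.

```python
def matmul_intensity(n: int, bytes_per_element: int = 4) -> float:
    """FLOPs per byte for an n x n by n x n matmul, ignoring cache effects."""
    flops = 2 * n**3                             # one multiply and one add per term
    bytes_moved = 3 * n**2 * bytes_per_element   # read A and B, write C
    return flops / bytes_moved

PEAK_FLOPS = 60e12     # hypothetical accelerator peak, FLOP/s
HBM_BANDWIDTH = 3e12   # bytes/s
ridge_point = PEAK_FLOPS / HBM_BANDWIDTH  # intensity where compute becomes the limit

for n in (64, 512, 4096):
    bound = "compute-bound" if matmul_intensity(n) > ridge_point else "memory-bound"
    print(f"n={n}: {matmul_intensity(n):.1f} FLOP/byte -> {bound}")
```

Under this model only the small case falls below the ridge point, which is consistent with fusion mattering for the ops around a large matmul rather than for the matmul itself.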

32.4 The Fusion Hierarchy

Fusion can happen at different levels:

32.4.1 Level 1: Manual Fused Kernels

Hand-written CUDA kernels combining specific operations:

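// One thread per output feature:
// output[idx] = relu(dot(weight[idx, :], input) + bias[idx])
__global__ void fused_linear_relu(float* output, const float* input,
                                   const float* weight, const float* bias,
                                   int n, int d) {
    // n = number of output features, d = number of input features
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = 0.0f;
        for (int j = 0; j < d; j++) {
            val += input[j] * weight[idx * d + j];
        }
        val += bias[idx];
        output[idx] = val > 0.0f ? val : 0.0f;  // ReLU fused in
    }
}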

Pros: Maximum control, maximum performance. Cons: Combinatorial explosion of kernel variants. Can’t fuse arbitrary combinations.

32.4.2 Level 2: Pattern-Based Fusion

Frameworks recognize common patterns and use pre-fused kernels:

# Some frameworks and backends recognize this pattern
x = F.linear(x, weight, bias)
x = F.relu(x)
# May be dispatched to a pre-fused kernel, depending on backend and version

Pros: Transparent to user, no code changes. Cons: Limited to recognized patterns.

32.4.3 Level 3: JIT Compilation (torch.compile)

Compilers analyze the computation graph and fuse automatically:

@torch.compile
def layer(x):
    x = F.linear(x, weight, bias)
    x = F.relu(x)
    x = F.dropout(x, p=0.1)
    return x

The compiler traces execution, builds a graph, and generates fused kernels.

Pros: Works with most PyTorch code without rewrites, finds fusion opportunities automatically. Cons: Compilation overhead; graph breaks and recompilation can limit what gets fused without any visible error.

32.5 Automatic Fusion with torch.compile

PyTorch 2.0 introduced torch.compile, which automatically identifies and applies fusion opportunities:

@torch.compile
def my_function(x, y):
    z = x + y          # These element-wise ops
    z = torch.relu(z)  # get fused into a single kernel
    return z

torch.compile uses TorchDynamo to trace Python bytecode into a graph, then TorchInductor generates fused kernels (Triton kernels on GPUs). It typically fuses element-wise chains, normalization layers, and pointwise ops following matmul, but by default leaves large matrix multiplies to cuBLAS.

For a comprehensive treatment of torch.compile—including graph breaks, compilation modes, dynamic shapes, and integration patterns—see the torch.compile Deep Dive.

32.6 The Limits of Fusion

Fusion isn’t always beneficial.

Register pressure: Each additional operation in a fused kernel requires registers. Very long fusion chains may spill to local memory, negating benefits.

Compilation overhead: torch.compile has an upfront cost on the first call, typically seconds and sometimes minutes for large models. For short-running or highly dynamic code, this may not pay off.
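Whether compilation pays off is a break-even question: the one-time cost has to be recovered from the per-call savings. The helper below is a minimal sketch with made-up timings; the function name and the numbers are illustrative, and real timings should come from your own benchmark.

```python
import math

def break_even_calls(compile_s: float, eager_ms: float, compiled_ms: float):
    """Number of calls needed before compilation has paid for itself.

    Returns None when the compiled version is not faster per call.
    """
    saved_ms = eager_ms - compiled_ms
    if saved_ms <= 0:
        return None
    return math.ceil(compile_s * 1e3 / saved_ms)

# Hypothetical timings: 10 s to compile, 1.5 ms eager, 0.5 ms compiled.
print(break_even_calls(10.0, 1.5, 0.5))
```

A training loop with millions of steps clears this bar easily; a script that calls the function a few hundred times may not.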

Dynamic shapes: Varying tensor shapes can trigger recompilation. Use torch.compile(dynamic=True) or padding to mitigate.
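The padding approach can be sketched as rounding each sequence length up to a fixed bucket size, so the compiled function only ever sees a handful of distinct shapes. The helper names and the bucket size of 64 are assumptions for illustration; padded positions still need to be masked in the model.

```python
def bucket_length(n: int, multiple: int = 64) -> int:
    """Round n up to the next multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple

def pad_sequence(tokens, multiple: int = 64, pad_value: int = 0):
    """Right-pad a list of token ids to its bucket length."""
    target = bucket_length(len(tokens), multiple)
    return tokens + [pad_value] * (target - len(tokens))

# Lengths 1..256 collapse to four distinct shapes instead of 256.
distinct_shapes = {bucket_length(n) for n in range(1, 257)}
print(sorted(distinct_shapes))
```

The trade-off is wasted compute on the padding, so the bucket size is a balance between recompilations and padded work.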

32.7 Fusion and the Memory Hierarchy

Fusion is fundamentally about the memory hierarchy. Let’s trace an example:

32.7.1 Without Fusion

def unfused_attention_output(attn_weights, v, residual):
    # Kernel 1: matmul
    attn_out = attn_weights @ v  # Read attn_weights, v; Write attn_out to HBM

    # Kernel 2: residual add
    out = attn_out + residual    # Read attn_out, residual; Write out to HBM

    # Kernel 3: layer norm
    out = F.layer_norm(out, ...)  # Read out; Write normalized to HBM

    return out

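Memory traffic: attn_out is written to HBM once and read back once (2 passes). The same holds for the pre-normalization sum out, so the two intermediates cost 4 passes that produce nothing the caller needs.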

32.7.2 With Fusion

@torch.compile
def fused_attention_output(attn_weights, v, residual):
    attn_out = attn_weights @ v
    out = attn_out + residual
    out = F.layer_norm(out, ...)
    return out

The compiler may fuse the residual add and layer norm:

Kernel 1: matmul (uses optimized cuBLAS, writes to HBM)
Kernel 2 (fused): Read matmul output + residual → add → layernorm → Write final output

One fewer round-trip.
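The saving can be estimated by tallying bytes per kernel. The sketch below assumes attn_weights has shape n×n, v and residual have shape n×d, and 2 bytes per element; it ignores the LayerNorm weight and bias. The function name and shapes are illustrative, not taken from a profiler.

```python
def attention_output_traffic(n: int, d: int, bytes_per_element: int = 2,
                             fused: bool = False) -> int:
    """Approximate HBM bytes moved by the matmul + residual add + layer norm."""
    nd = n * d * bytes_per_element   # one n x d tensor
    nn = n * n * bytes_per_element   # the n x n attention weights
    matmul = (nn + nd) + nd          # read attn_weights and v, write attn_out
    if fused:
        # Fused add + layer norm: read attn_out and residual, write the result.
        tail = (nd + nd) + nd
    else:
        add = (nd + nd) + nd         # read attn_out and residual, write the sum
        norm = nd + nd               # read the sum, write the normalized output
        tail = add + norm
    return matmul + tail

n, d = 4096, 128
saved = attention_output_traffic(n, d) - attention_output_traffic(n, d, fused=True)
print(f"bytes saved by fusing add + layer norm: {saved}")
```

The difference is exactly one write and one read of an n×d tensor, the round-trip of the pre-normalization sum.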

32.8 Practical Guidance

32.8.1 When to Use torch.compile

# Good candidates for torch.compile:

# 1. Inference on fixed shapes
@torch.compile(mode="reduce-overhead")
def inference(x):
    return model(x)

# 2. Training loops with stable shapes
model = torch.compile(model)

# 3. Custom loss functions with many small ops
@torch.compile
def custom_loss(pred, target):
    # Many element-wise operations that benefit from fusion
    ...

32.8.2 When to Avoid

# Bad candidates:

# 1. Highly dynamic control flow
def dynamic_model(x):
    if x.sum() > 0:  # Data-dependent branch
        return path_a(x)
    else:
        return path_b(x)

# 2. Very short functions (compilation cost > runtime benefit)
@torch.compile
def tiny(x):
    return x + 1  # Probably not worth it

# 3. Development/debugging (errors surface in traced or generated code and are harder to read)

32.8.3 Measuring Fusion Impact

import torch
import time

def benchmark(fn, x, n_warmup=10, n_iter=100):
    for _ in range(n_warmup):
        fn(x)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
    torch.cuda.synchronize()

    return (time.perf_counter() - start) / n_iter * 1000  # ms

x = torch.randn(4096, 4096, device='cuda')

def chain(x):
    x = x + 1
    x = x * 2
    x = torch.relu(x)
    x = x ** 2
    return x

compiled = torch.compile(chain)

print(f"Eager:    {benchmark(chain, x):.3f} ms")
print(f"Compiled: {benchmark(compiled, x):.3f} ms")
# Typical result: 2-4× speedup (varies with GPU and tensor size)

32.9 Fusion in FlashAttention

FlashAttention is the ultimate fusion example. It fuses the entire attention computation:

Standard:
  S = Q @ K.T    → Write S to HBM (n² elements)
  P = softmax(S) → Write P to HBM (n² elements)
  O = P @ V      → Write O to HBM (n×d elements)

FlashAttention:
  [Fused kernel]: Read Q, K, V → Compute everything in SRAM → Write O only

The “fusion” here is extreme: all intermediates stay in SRAM, never touching HBM. This is why FlashAttention is faster despite doing more compute (the streaming softmax corrections).
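The streaming softmax correction can be shown for a single query row in plain Python. This is a simplified sketch: values are scalars rather than d-dimensional vectors, the block size is arbitrary, and the function names are hypothetical. The real kernel applies the same rescaling to tiles held in SRAM.

```python
import math

def attention_row_naive(scores, values):
    """Softmax over all scores at once, then a weighted sum of the values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # full row of weights materialized
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

def attention_row_streaming(scores, values, block=2):
    """Process scores block by block, keeping only three running scalars."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Correction: rescale earlier partial results to the new maximum.
        # On the first block this is exp(-inf) = 0, and the sums are still 0.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc *= scale
        for s, v in zip(s_blk, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc += w * v
        running_max = new_max
    return acc / running_sum

scores = [0.5, 2.0, -1.0, 3.0, 0.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(attention_row_naive(scores, values), attention_row_streaming(scores, values))
```

The rescaling multiplications are the extra compute the text refers to; they are cheap compared with writing and re-reading an n² score matrix.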

32.10 Connections

Bandwidth: Fusion reduces memory traffic, addressing the bandwidth bottleneck.

Locality: Fusion improves temporal locality—data is reused before eviction from registers/cache.

FlashAttention: The ultimate example of fusion principles applied to attention.

32.11 Key Takeaways

  1. Fusion eliminates memory round-trips: Every intermediate written to HBM and read back is a cost that fusion can avoid.

  2. Low arithmetic intensity operations benefit most: Element-wise ops, small reductions, normalization layers.

  3. torch.compile automates fusion: For most code, @torch.compile is the easiest way to get fusion benefits.

  4. Manual fusion still matters: For maximum performance (like FlashAttention), hand-written fused kernels remain important.

  5. Fusion has limits: Register pressure, compilation overhead, and dynamic shapes all constrain what can be fused.

Note: Try It Yourself

The accompanying notebook lets you:

  • Measure fusion speedups on element-wise chains
  • Compare eager vs. compiled execution
  • Visualize what torch.compile fuses
  • Profile memory traffic with and without fusion
