31  Fusion

When Combining Operations Beats Separating Them

Every memory round-trip has a cost. If you write to memory just to read it back, you’ve paid twice for nothing.

Fusion eliminates these round-trips by combining operations into a single pass.

31.1 The Hidden Cost of Abstraction

Modern deep learning frameworks provide beautiful abstractions:

def layer(x):
    x = linear(x)
    x = relu(x)
    x = dropout(x)
    return x

Clean, readable, modular. Each operation is a separate function.

But at runtime, this beautiful code has an ugly secret:

Kernel 1 (linear):  Read x from HBM → Compute Wx+b → Write result to HBM
Kernel 2 (relu):    Read from HBM → max(0, x) → Write to HBM
Kernel 3 (dropout): Read from HBM → mask → Write to HBM

Three round-trips to HBM. On an H100, with 80 GB of HBM delivering roughly 3 TB/s, each round-trip for a sizable activation tensor costs tens of microseconds, and those microseconds add up across layers and training steps.

The computation is trivial—multiply, compare, mask. The memory traffic is expensive.
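
To get a feel for the numbers, here is a back-of-envelope sketch (not a profile) of the activation traffic for the three kernels above. The shape, dtype, and bandwidth are illustrative assumptions, and weight reads are ignored.

# Rough cost of the three unfused kernels, counting only activation traffic
tokens, hidden = 8192, 4096             # assumed activation shape
bytes_per_elem = 2                      # bf16
tensor_bytes = tokens * hidden * bytes_per_elem

hbm_bandwidth = 3e12                    # ~3 TB/s, as above

unfused_traffic = 3 * 2 * tensor_bytes  # 3 kernels x (1 read + 1 write)
fused_traffic = 2 * tensor_bytes        # 1 read + 1 write if fused into one pass

print(f"unfused: {unfused_traffic / hbm_bandwidth * 1e6:.0f} us of pure memory traffic")
print(f"fused:   {fused_traffic / hbm_bandwidth * 1e6:.0f} us")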

31.2 What Is Fusion?

Fusion combines multiple operations into a single kernel:

Fused kernel: Read x from HBM → Wx+b → relu → dropout → Write to HBM

One read, one write. The intermediates stay in registers or shared memory, never touching HBM.

The benefit depends on arithmetic intensity. Operations with low arithmetic intensity (element-wise ops, small reductions) benefit most, because memory traffic dominates their runtime.
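
One way to see why: compare an element-wise op's arithmetic intensity to the GPU's compute-to-bandwidth ratio (the roofline ridge point). The peak figures below are rough assumptions for an H100.

# Arithmetic intensity of a standalone fp32 ReLU kernel
flops_per_elem = 1
bytes_per_elem = 4 + 4                       # one read + one write, fp32
intensity = flops_per_elem / bytes_per_elem  # 0.125 FLOP/byte

peak_flops = 67e12                           # assumed fp32 peak, ~67 TFLOP/s
peak_bw = 3e12                               # ~3 TB/s HBM
ridge = peak_flops / peak_bw                 # ~22 FLOP/byte needed to be compute-bound

print(f"ReLU: {intensity} FLOP/byte vs ridge point {ridge:.0f} FLOP/byte")
# ReLU sits far below the ridge, so its runtime is almost all memory traffic;
# fused into a neighboring kernel, it becomes essentially free.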

31.3 When Fusion Helps

Not all operations benefit equally from fusion.

31.3.1 High Benefit: Element-wise Chains

y = torch.sigmoid(x) * torch.tanh(x + 1)

Without fusion: four separate kernels (add, tanh, sigmoid, multiply), each with its own HBM read and write; x is read twice and three intermediates are materialized.
With fusion: one read of x, one write of y.

Speedup: often 3-4× or more.

31.3.2 High Benefit: Small Reductions Following Compute

loss = F.cross_entropy(logits, labels)

Cross-entropy computes softmax then log then reduction. Fusing prevents materializing the softmax probabilities.
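
To make the contrast concrete, here is a naive version that materializes the full probability matrix, next to the single library op. The naive function is for illustration only; a fused implementation (under torch.compile or a hand-written kernel) can keep the probabilities out of HBM entirely.

import torch
import torch.nn.functional as F

def naive_cross_entropy(logits, labels):
    # Materializes a full (batch, vocab) probability tensor in HBM
    probs = torch.softmax(logits, dim=-1)
    picked = probs[torch.arange(logits.size(0), device=logits.device), labels]
    return -torch.log(picked).mean()

def library_cross_entropy(logits, labels):
    # Single op: softmax, log, and reduction handled internally,
    # and a fused implementation can avoid materializing the probabilities
    return F.cross_entropy(logits, labels)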

31.3.3 Moderate Benefit: LayerNorm and Softmax

x = F.layer_norm(x, normalized_shape)

LayerNorm computes mean, variance, normalization, and affine transform. Fusing keeps statistics in registers.
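
For intuition, here is LayerNorm written as separate eager-mode steps; each step is its own kernel with its own HBM round-trip, whereas the fused op computes the statistics and the normalization in one pass. This sketch omits the affine parameters.

import torch

def unfused_layer_norm(x, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)               # kernel 1: reduction
    centered = x - mean                               # kernel 2: element-wise
    var = centered.pow(2).mean(dim=-1, keepdim=True)  # kernels 3-4: square, reduce
    return centered * torch.rsqrt(var + eps)          # kernels 5-7: add, rsqrt, multiply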

31.3.4 Low Benefit: Large Matrix Multiplies

y = x @ W

Matrix multiply already has high arithmetic intensity. It’s compute-bound, not memory-bound. Fusion helps the operations around matmul, not matmul itself.
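
A rough count shows why: for an n-by-n matmul the FLOPs grow as n³ while the data moved grows as n², so arithmetic intensity grows with n. The sizes below are illustrative.

n = 4096                         # assume square fp16 matrices
flops = 2 * n**3                 # one multiply-add per inner-product term
bytes_moved = 3 * n * n * 2      # read A and B, write C, 2 bytes per element
print(flops / bytes_moved)       # ~1365 FLOP/byte: far above the ridge point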

31.4 The Fusion Hierarchy

Fusion can happen at different levels:

31.4.1 Level 1: Manual Fused Kernels

Hand-written CUDA kernels combining specific operations:

__global__ void fused_linear_relu(float* output, const float* input,
                                  const float* weight, const float* bias,
                                  int n, int d) {
    // One thread per output element: dot product, bias, and ReLU in one pass
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = 0.0f;
        for (int j = 0; j < d; j++) {
            val += input[j] * weight[idx * d + j];
        }
        val += bias[idx];
        output[idx] = val > 0 ? val : 0;  // ReLU fused in
    }
}

Pros: Maximum control, maximum performance. Cons: Combinatorial explosion of kernel variants. Can’t fuse arbitrary combinations.

31.4.2 Level 2: Pattern-Based Fusion

Frameworks recognize common patterns and use pre-fused kernels:

# PyTorch recognizes this pattern
x = F.linear(x, weight, bias)
x = F.relu(x)
# May be fused automatically

Pros: Transparent to user, no code changes. Cons: Limited to recognized patterns.

31.4.3 Level 3: JIT Compilation (torch.compile)

Compilers analyze the computation graph and fuse automatically:

@torch.compile
def layer(x):
    x = F.linear(x, weight, bias)
    x = F.relu(x)
    x = F.dropout(x, p=0.1)
    return x

The compiler traces execution, builds a graph, and generates fused kernels.

Pros: Works with arbitrary code, finds fusion opportunities automatically. Cons: Compilation overhead, potential for unexpected behavior.

31.5 torch.compile Deep Dive

PyTorch 2.0 introduced torch.compile, a JIT compiler that automatically applies fusion and other optimizations.

31.5.1 How It Works

@torch.compile
def my_function(x, y):
    z = x + y
    return torch.relu(z)

When the decorated function is first called, four stages run:

  1. Tracing: TorchDynamo captures the Python bytecode and extracts a computation graph
  2. Graph capture: Operations are recorded into an FX graph
  3. Optimization: TorchInductor applies fusion and generates Triton kernels
  4. Code generation: Optimized kernels are compiled and cached
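
If you want to see what the last two stages produce, recent PyTorch versions can dump the generated Triton code; the exact flags are version-dependent, so treat them as an assumption.

# Option 1: run your script with Inductor output-code logging, e.g.
#   TORCH_LOGS="output_code" python train.py
# Option 2: enable the same logging from Python
import torch
torch._logging.set_logs(output_code=True)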

31.5.2 Fusion in Action

import torch

def unfused(x):
    x = x + 1
    x = x * 2
    x = torch.relu(x)
    return x

fused = torch.compile(unfused)

x = torch.randn(10000, device='cuda')

# First call: compilation happens
y = fused(x)

# Subsequent calls: use cached fused kernel
y = fused(x)  # Fast

31.5.3 What Gets Fused?

torch.compile typically fuses:

  • Chains of element-wise operations
  • Normalization layers (LayerNorm, BatchNorm, RMSNorm)
  • Softmax and cross-entropy
  • Pointwise operations following matmul (bias + activation), as sketched below
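
As a sketch of that last pattern: under torch.compile, the bias add and activation after a matmul are typically fused into a single epilogue kernel, while the matmul itself still dispatches to a tuned GEMM.

import torch
import torch.nn.functional as F

@torch.compile
def linear_gelu(x, w, b):
    # Matmul goes to a tuned GEMM; the pointwise bias add and GELU
    # can be fused into one epilogue kernel
    return F.gelu(x @ w + b)

x = torch.randn(1024, 1024, device='cuda')
w = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, device='cuda')
y = linear_gelu(x, w, b)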

31.5.4 What Doesn’t Fuse?

  • Large matrix multiplies (already efficient, use cuBLAS/cuDNN)
  • Operations with data-dependent control flow
  • Custom CUDA extensions (unless wrapped appropriately)

31.6 The Limits of Fusion

Fusion isn’t always beneficial.

31.6.1 Memory Limits

Fused kernels need registers and shared memory for intermediates. Very long fusion chains may spill to local memory, negating benefits.

31.6.2 Register Pressure

Each additional operation in a fused kernel requires registers. At some point, the GPU runs out:

# This might not fuse well due to register pressure
def many_ops(x):
    for i in range(20):
        x = x * 2 + 1
        x = torch.sin(x) + torch.cos(x)
    return x

31.6.3 Compilation Overhead

torch.compile has upfront cost:

  • First call: 100ms to several seconds
  • Cached calls: near-zero overhead

For short-running programs or highly dynamic code, compilation may not pay off.
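
You can see the overhead directly by timing the first call against a warm call; the exact numbers depend on your hardware and PyTorch version.

import time
import torch

@torch.compile
def f(x):
    return torch.relu(x * 2 + 1)

x = torch.randn(4096, 4096, device='cuda')

start = time.perf_counter()
f(x)
torch.cuda.synchronize()
print(f"first call (includes compilation): {time.perf_counter() - start:.2f} s")

start = time.perf_counter()
f(x)
torch.cuda.synchronize()
print(f"warm call: {(time.perf_counter() - start) * 1000:.2f} ms")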

31.6.4 Dynamic Shapes

Varying tensor shapes can trigger recompilation:

@torch.compile
def process(x):
    return x.sum(dim=-1)

# Each new shape may trigger recompilation
process(torch.randn(100, 50, device='cuda'))
process(torch.randn(200, 100, device='cuda'))  # Recompile?

Use dynamic=True or padding to mitigate.
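
With the dynamic flag, the compiler is asked to generate shape-polymorphic kernels instead of specializing on each concrete size (behavior varies by PyTorch version):

@torch.compile(dynamic=True)
def process(x):
    return x.sum(dim=-1)

# One compilation can now be reused across varying shapes
process(torch.randn(100, 50, device='cuda'))
process(torch.randn(200, 100, device='cuda'))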

31.7 Fusion and the Memory Hierarchy

Fusion is fundamentally about the memory hierarchy. Let’s trace an example:

31.7.1 Without Fusion

def unfused_attention_output(attn_weights, v, residual):
    # Kernel 1: matmul
    attn_out = attn_weights @ v  # Read attn_weights, v; Write attn_out to HBM

    # Kernel 2: residual add
    out = attn_out + residual    # Read attn_out, residual; Write out to HBM

    # Kernel 3: layer norm
    out = F.layer_norm(out, ...)  # Read out; Write normalized to HBM

    return out

Memory traffic: attn_out is written to HBM once and read back once, and the intermediate out from the residual add makes the same round-trip before the layer norm.

31.7.2 With Fusion

@torch.compile
def fused_attention_output(attn_weights, v, residual):
    attn_out = attn_weights @ v
    out = attn_out + residual
    out = F.layer_norm(out, ...)
    return out

The compiler may fuse the residual add and layer norm:

Kernel 1: matmul (uses optimized cuBLAS, writes to HBM)
Kernel 2 (fused): Read matmul output + residual → add → layernorm → Write final output

One fewer round-trip.

31.8 Practical Guidance

31.8.1 When to Use torch.compile

# Good candidates for torch.compile:

# 1. Inference on fixed shapes
@torch.compile(mode="reduce-overhead")
def inference(x):
    return model(x)

# 2. Training loops with stable shapes
model = torch.compile(model)

# 3. Custom loss functions with many small ops
@torch.compile
def custom_loss(pred, target):
    # Many element-wise operations that benefit from fusion
    ...

31.8.2 When to Avoid

# Bad candidates:

# 1. Highly dynamic control flow
def dynamic_model(x):
    if x.sum() > 0:  # Data-dependent branch
        return path_a(x)
    else:
        return path_b(x)

# 2. Very short functions (compilation cost > runtime benefit)
@torch.compile
def tiny(x):
    return x + 1  # Probably not worth it

# 3. Development/debugging (compiled code is harder to step through, and stack traces point at generated code)

31.8.3 Measuring Fusion Impact

import torch
import time

def benchmark(fn, x, n_warmup=10, n_iter=100):
    # Warm up (triggers compilation for compiled functions)
    for _ in range(n_warmup):
        fn(x)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock

    return (time.perf_counter() - start) / n_iter * 1000  # ms per iteration

x = torch.randn(4096, 4096, device='cuda')

def chain(x):
    x = x + 1
    x = x * 2
    x = torch.relu(x)
    x = x ** 2
    return x

compiled = torch.compile(chain)

print(f"Eager:    {benchmark(chain, x):.3f} ms")
print(f"Compiled: {benchmark(compiled, x):.3f} ms")
# Typical result: 2-4× speedup

31.9 Fusion in FlashAttention

FlashAttention is the ultimate fusion example. It fuses the entire attention computation:

Standard:
  S = Q @ K.T    → Write S to HBM (n² elements)
  P = softmax(S) → Write P to HBM (n² elements)
  O = P @ V      → Write O to HBM (n×d elements)

FlashAttention:
  [Fused kernel]: Read Q, K, V → Compute everything in SRAM → Write O only

The “fusion” here is extreme: all intermediates stay in SRAM, never touching HBM. This is why FlashAttention is faster despite doing more compute (the streaming softmax corrections).
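
A rough byte count makes the savings concrete; the sequence length, head dimension, and dtype below are illustrative assumptions.

n, d = 8192, 64                           # sequence length, head dim (assumed)
elem = 2                                  # bf16 bytes

# Standard attention writes S and P (n x n each) to HBM and reads them back
standard_extra = 2 * (n * n) * elem * 2   # (write + read) for each of S and P

# FlashAttention keeps S and P in SRAM, writing only O (n x d)
flash_extra = 0

print(f"extra HBM traffic avoided: {standard_extra / 1e9:.2f} GB per attention head")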

31.10 Connections

Chapter 2 (Bandwidth): Fusion reduces memory traffic, addressing the bandwidth bottleneck.

Chapter 7 (Locality): Fusion improves temporal locality—data is reused before eviction from registers/cache.

Chapter 10 (FlashAttention): The ultimate example of fusion principles applied to attention.

31.11 Key Takeaways

  1. Fusion eliminates memory round-trips: Every intermediate written to HBM and read back is a cost that fusion can avoid.

  2. Low arithmetic intensity operations benefit most: Element-wise ops, small reductions, normalization layers.

  3. torch.compile automates fusion: For most code, @torch.compile is the easiest way to get fusion benefits.

  4. Manual fusion still matters: For maximum performance (like FlashAttention), hand-written fused kernels remain important.

  5. Fusion has limits: Register pressure, compilation overhead, and dynamic shapes all constrain what can be fused.

Note: Try It Yourself

The accompanying notebook lets you:

  • Measure fusion speedups on element-wise chains
  • Compare eager vs. compiled execution
  • Visualize what torch.compile fuses
  • Profile memory traffic with and without fusion
