32 Fusion
When Combining Operations Beats Separating Them
Every memory round-trip has a cost. If you write to memory just to read it back, you’ve paid twice for nothing.
Fusion eliminates these round-trips by combining operations into a single pass.
Fusion is primarily a compiler/runtime optimization; later tooling chapters show how modern stacks implement it.
32.2 What Is Fusion?
Fusion combines multiple operations into a single kernel. Consider a Linear → ReLU → Dropout sequence:
Unfused: Read x from HBM → Wx+b → Write to HBM → Read → relu → Write to HBM → Read → dropout → Write to HBM
Fused kernel: Read x from HBM → Wx+b → relu → dropout → Write to HBM
One read, one write. The intermediates stay in registers or shared memory, never touching HBM.
The benefit scales with arithmetic intensity. Operations with low arithmetic intensity (element-wise, small reductions) benefit most because memory traffic dominates their runtime.
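A rough roofline estimate makes this concrete. The hardware numbers below are illustrative round figures, not specs of any particular GPU:

```python
# Rough roofline estimate: is an op memory-bound or compute-bound?
# Hardware numbers are illustrative placeholders, not a real GPU's specs.
HBM_BANDWIDTH = 2e12   # bytes/s (hypothetical)
PEAK_FLOPS = 100e12    # FLOP/s (hypothetical)

def bound_by(flops, bytes_moved):
    """Return which resource dominates this op's runtime."""
    t_mem = bytes_moved / HBM_BANDWIDTH
    t_compute = flops / PEAK_FLOPS
    return "memory" if t_mem > t_compute else "compute"

n = 4096 * 4096  # elements in a 4096x4096 fp32 tensor

# Element-wise relu: ~1 FLOP per element; read + write 4 bytes each
print(bound_by(flops=n, bytes_moved=2 * 4 * n))             # memory

# 4096x4096 matmul: 2*m^3 FLOPs; three m x m fp32 tensors moved
m = 4096
print(bound_by(flops=2 * m**3, bytes_moved=3 * 4 * m * m))  # compute
```

Memory-bound ops like the relu above are exactly the ones where fusion pays off: their runtime is set by bytes moved, not FLOPs performed.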
32.3 When Fusion Helps
Not all operations benefit equally from fusion.
32.3.1 High Benefit: Element-wise Chains
y = torch.sigmoid(x) * torch.tanh(x + 1)
Without fusion, this launches four kernels (add, tanh, sigmoid, multiply), each reading its input from HBM and writing its result back. With fusion: one read of x, one write of y.
Speedup: often 3-4× or more.
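For a purely bandwidth-bound chain, the ratio of bytes moved bounds the speedup. A sketch for the sigmoid·tanh example, counting tensor-sized HBM passes:

```python
# HBM traffic for y = sigmoid(x) * tanh(x + 1), in tensor-sized passes.
# Unfused (one kernel per op, every intermediate hits HBM):
#   t1 = x + 1      : read x,       write t1
#   t2 = tanh(t1)   : read t1,      write t2
#   t3 = sigmoid(x) : read x,       write t3
#   y  = t3 * t2    : read t3, t2,  write y
unfused_passes = 2 + 2 + 2 + 3   # 9 tensor-sized reads/writes
fused_passes = 1 + 1             # read x once, write y once

# If runtime were purely bandwidth-limited, the traffic ratio
# would be the ceiling on the achievable speedup.
speedup_bound = unfused_passes / fused_passes
print(f"ideal speedup: {speedup_bound:.1f}x")  # ideal speedup: 4.5x
```

Real speedups land below this ceiling because kernel-launch overhead and compute are not zero, which is consistent with the 3-4× typically observed.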
32.3.2 High Benefit: Small Reductions Following Compute
loss = F.cross_entropy(logits, labels)
Cross-entropy computes a softmax, then a log, then a reduction. Fusing prevents materializing the softmax probabilities.
32.3.3 Moderate Benefit: LayerNorm and Softmax
x = F.layer_norm(x, normalized_shape)
LayerNorm computes the mean, variance, normalization, and affine transform. Fusing keeps the statistics in registers.
32.3.4 Low Benefit: Large Matrix Multiplies
y = x @ W
Matrix multiply already has high arithmetic intensity: it's compute-bound, not memory-bound. Fusion helps the operations around a matmul, not the matmul itself.
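To see why, compare arithmetic intensity (FLOPs per byte of HBM traffic) for the two cases. A sketch, counting fp32 bytes:

```python
def intensity_matmul(m):
    """FLOPs per byte for an m x m fp32 matmul (C = A @ B)."""
    flops = 2 * m**3             # one multiply-add per inner-loop step
    bytes_moved = 3 * 4 * m * m  # read A and B, write C
    return flops / bytes_moved

def intensity_elementwise():
    """FLOPs per byte for y = relu(x) on fp32."""
    return 1 / (2 * 4)           # 1 FLOP per 4-byte read + 4-byte write

print(intensity_matmul(4096))   # ~683 FLOPs/byte: compute-bound
print(intensity_elementwise())  # 0.125 FLOPs/byte: memory-bound
```

Matmul's intensity grows linearly with the matrix dimension (2m³ FLOPs over O(m²) bytes), so large matmuls saturate compute long before they saturate bandwidth; element-wise ops never do.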
32.4 The Fusion Hierarchy
Fusion can happen at different levels:
32.4.1 Level 1: Manual Fused Kernels
Hand-written CUDA kernels combining specific operations:
__global__ void fused_linear_relu(float* output, const float* input,
                                  const float* weight, const float* bias,
                                  int n, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = 0.0f;
        for (int j = 0; j < d; j++) {
            val += input[j] * weight[idx * d + j];
        }
        val += bias[idx];
        output[idx] = val > 0.0f ? val : 0.0f;  // ReLU fused in
    }
}
Pros: Maximum control, maximum performance. Cons: Combinatorial explosion of kernel variants. Can’t fuse arbitrary combinations.
32.4.2 Level 2: Pattern-Based Fusion
Frameworks recognize common patterns and use pre-fused kernels:
# PyTorch recognizes this pattern
x = F.linear(x, weight, bias)
x = F.relu(x)
# May be fused automatically
Pros: Transparent to the user, no code changes. Cons: Limited to recognized patterns.
32.4.3 Level 3: JIT Compilation (torch.compile)
Compilers analyze the computation graph and fuse automatically:
@torch.compile
def layer(x):
    x = F.linear(x, weight, bias)
    x = F.relu(x)
    x = F.dropout(x, p=0.1)
    return x

The compiler traces execution, builds a graph, and generates fused kernels.
Pros: Works with arbitrary code, finds fusion opportunities automatically. Cons: Compilation overhead, potential for unexpected behavior.
32.5 Automatic Fusion with torch.compile
PyTorch 2.0 introduced torch.compile, which automatically identifies and applies fusion opportunities:
@torch.compile
def my_function(x, y):
    z = x + y          # These element-wise ops
    z = torch.relu(z)  # get fused into a single kernel
    return z

torch.compile uses TorchDynamo to trace Python bytecode into a graph, then TorchInductor generates fused Triton kernels. It typically fuses element-wise chains, normalization layers, and pointwise ops following matmul—but leaves large matrix multiplies to cuBLAS.
For a comprehensive treatment of torch.compile—including graph breaks, compilation modes, dynamic shapes, and integration patterns—see the torch.compile Deep Dive.
32.6 The Limits of Fusion
Fusion isn’t always beneficial.
Register pressure: Each additional operation in a fused kernel requires registers. Very long fusion chains may spill to local memory, negating benefits.
Compilation overhead: torch.compile has upfront cost (100ms to several seconds for first call). For short-running or highly dynamic code, this may not pay off.
Dynamic shapes: Varying tensor shapes can trigger recompilation. Use torch.compile(dynamic=True) or padding to mitigate.
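A common padding strategy is to round variable-length inputs up to a small set of bucket sizes, so the compiler only ever sees a handful of distinct shapes. A minimal sketch (the bucket sizes here are arbitrary examples):

```python
def pad_to_bucket(length, buckets=(128, 256, 512, 1024)):
    """Round a sequence length up to the nearest bucket size.

    With bucketing, a shape-specializing compiler sees at most
    len(buckets) distinct shapes instead of one per input length,
    capping the number of recompilations.
    """
    for b in buckets:
        if length <= b:
            return b
    raise ValueError(f"length {length} exceeds largest bucket")

print(pad_to_bucket(100))  # 128
print(pad_to_bucket(300))  # 512
```

The trade-off is wasted compute on padding tokens, which is usually cheaper than repeated compilation.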
32.7 Fusion and the Memory Hierarchy
Fusion is fundamentally about the memory hierarchy. Let’s trace an example:
32.7.1 Without Fusion
def unfused_attention_output(attn_weights, v, residual):
    # Kernel 1: matmul
    attn_out = attn_weights @ v   # Read attn_weights, v; write attn_out to HBM
    # Kernel 2: residual add
    out = attn_out + residual     # Read attn_out, residual; write out to HBM
    # Kernel 3: layer norm
    out = F.layer_norm(out, ...)  # Read out; write normalized result to HBM
    return out

Memory traffic: attn_out is written once and read back once (two passes), and the same goes for the post-add intermediate out.
32.7.2 With Fusion
@torch.compile
def fused_attention_output(attn_weights, v, residual):
    attn_out = attn_weights @ v
    out = attn_out + residual
    out = F.layer_norm(out, ...)
    return out

The compiler may fuse the residual add and layer norm:
Kernel 1: matmul (uses optimized cuBLAS, writes to HBM)
Kernel 2 (fused): Read matmul output + residual → add → layernorm → Write final output
One fewer round-trip.
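The savings can be quantified. For a (tokens, hidden) fp32 activation, a sketch of the HBM traffic for the add + layer-norm portion (ignoring the small affine parameters; the tensor size is an arbitrary example):

```python
def traffic_bytes(n_elements, tensor_passes, dtype_bytes=4):
    """Total HBM traffic for a given number of tensor-sized passes."""
    return tensor_passes * n_elements * dtype_bytes

n = 8192 * 4096  # tokens x hidden, fp32 activation (example size)

# Unfused: add reads attn_out and residual, writes out (3 passes);
#          layer_norm reads out, writes the normalized result (2 passes).
unfused = traffic_bytes(n, 3 + 2)
# Fused: read attn_out and residual, write the normalized result (3 passes).
fused = traffic_bytes(n, 3)

print(f"traffic reduced by {1 - fused / unfused:.0%}")  # traffic reduced by 40%
```

A 40% traffic cut on a memory-bound pair of kernels translates almost directly into runtime, which is why compilers prioritize exactly this kind of epilogue fusion.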
32.8 Practical Guidance
32.8.1 When to Use torch.compile
# Good candidates for torch.compile:

# 1. Inference on fixed shapes
@torch.compile(mode="reduce-overhead")
def inference(x):
    return model(x)

# 2. Training loops with stable shapes
model = torch.compile(model)

# 3. Custom loss functions with many small ops
@torch.compile
def custom_loss(pred, target):
    # Many element-wise operations that benefit from fusion
    ...

32.8.2 When to Avoid
# Bad candidates:

# 1. Highly dynamic control flow
def dynamic_model(x):
    if x.sum() > 0:  # Data-dependent branch
        return path_a(x)
    else:
        return path_b(x)

# 2. Very short functions (compilation cost > runtime benefit)
@torch.compile
def tiny(x):
    return x + 1  # Probably not worth it

# 3. Development/debugging (compiled code can obscure errors and stack traces)

32.8.3 Measuring Fusion Impact
import torch
import time
def benchmark(fn, x, n_warmup=10, n_iter=100):
    for _ in range(n_warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iter * 1000  # ms

x = torch.randn(4096, 4096, device='cuda')

def chain(x):
    x = x + 1
    x = x * 2
    x = torch.relu(x)
    x = x ** 2
    return x

compiled = torch.compile(chain)
print(f"Eager: {benchmark(chain, x):.3f} ms")
print(f"Compiled: {benchmark(compiled, x):.3f} ms")
# Typical result: 2-4× speedup

32.9 Fusion in FlashAttention
FlashAttention is the ultimate fusion example. It fuses the entire attention computation:
Standard:
S = Q @ K.T → Write S to HBM (n² elements)
P = softmax(S) → Write P to HBM (n² elements)
O = P @ V → Write O to HBM (n×d elements)
FlashAttention:
[Fused kernel]: Read Q, K, V → Compute everything in SRAM → Write O only
The “fusion” here is extreme: all intermediates stay in SRAM, never touching HBM. This is why FlashAttention is faster despite doing more compute (the streaming softmax corrections).
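The scale of the savings follows from the shapes involved. For sequence length n and head dimension d, an idealized per-head traffic sketch in fp16 (2 bytes per element, ignoring FlashAttention's tiled re-reads of K and V blocks):

```python
def standard_attention_bytes(n, d, elem_bytes=2):
    """HBM traffic for unfused attention (per head), in bytes.

    S = Q @ K.T   : read Q, K (2*n*d), write S (n*n)
    P = softmax(S): read S, write P (2*n*n)
    O = P @ V     : read P (n*n), read V (n*d), write O (n*d)
    """
    return elem_bytes * (4 * n * d + 4 * n * n)

def flash_attention_bytes(n, d, elem_bytes=2):
    """Fused kernel: read Q, K, V; write O. No n x n intermediates."""
    return elem_bytes * 4 * n * d

n, d = 8192, 128  # example sizes
ratio = standard_attention_bytes(n, d) / flash_attention_bytes(n, d)
print(f"standard moves {ratio:.0f}x more HBM traffic")  # 65x
```

The ratio simplifies to 1 + n/d, so the longer the sequence relative to the head dimension, the more the fused kernel wins.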
32.10 Connections
Bandwidth: Fusion reduces memory traffic, addressing the bandwidth bottleneck.
Locality: Fusion improves temporal locality—data is reused before eviction from registers/cache.
FlashAttention: The ultimate example of fusion principles applied to attention.
32.11 Key Takeaways
Fusion eliminates memory round-trips: Every intermediate written to HBM and read back is a cost that fusion can avoid.
Low arithmetic intensity operations benefit most: Element-wise ops, small reductions, normalization layers.
torch.compile automates fusion: For most code, @torch.compile is the easiest way to get fusion benefits.
Manual fusion still matters: For maximum performance (like FlashAttention), hand-written fused kernels remain important.
Fusion has limits: Register pressure, compilation overhead, and dynamic shapes all constrain what can be fused.