31 Fusion
When Combining Operations Beats Separating Them
Every memory round-trip has a cost. If you write to memory just to read it back, you’ve paid twice for nothing.
Fusion eliminates these round-trips by combining operations into a single pass.
31.2 What Is Fusion?
Fusion combines multiple operations into a single kernel. Consider a linear layer followed by a ReLU and dropout:
Unfused: Read x from HBM → Wx+b → Write to HBM → Read → relu → Write to HBM → Read → dropout → Write to HBM
Fused kernel: Read x from HBM → Wx+b → relu → dropout → Write to HBM
One read and one write instead of three of each. The intermediates stay in registers or shared memory, never touching HBM.
The benefit depends on arithmetic intensity: operations with low arithmetic intensity (element-wise ops, small reductions) benefit most, because memory traffic dominates their runtime.
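As a back-of-envelope illustration (a sketch assuming float32 and that every operand is read from and written to HBM), compare the arithmetic intensity of a ReLU with that of a square matmul:
# Arithmetic intensity = FLOPs per byte of HBM traffic (float32, 4 bytes/element)

n = 4096 * 4096              # elements in an activation tensor
relu_flops = n               # one op per element
relu_bytes = 2 * 4 * n       # read x, write y
print(f"ReLU:   {relu_flops / relu_bytes:.3f} FLOP/byte")   # 0.125 -> memory-bound

m = k = p = 4096             # square matmul C = A @ B
mm_flops = 2 * m * k * p
mm_bytes = 4 * (m * k + k * p + m * p)
print(f"Matmul: {mm_flops / mm_bytes:.0f} FLOP/byte")       # ~683 -> compute-bound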
31.3 When Fusion Helps
Not all operations benefit equally from fusion.
31.3.1 High Benefit: Element-wise Chains
y = torch.sigmoid(x) * torch.tanh(x + 1)

Without fusion: four kernel launches, each making its own round-trip to HBM (x alone is read twice). With fusion: 1 read, 1 write.
Speedup: often 3-4× or more.
31.3.2 High Benefit: Small Reductions Following Compute
loss = F.cross_entropy(logits, labels)

Cross-entropy computes a softmax, then a log, then a reduction. Fusing them prevents materializing the softmax probabilities.
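For contrast, here is a sketch of the unfused equivalent (illustrative shapes; loss_manual is just a name for the hand-rolled path):
import torch
import torch.nn.functional as F

logits = torch.randn(32, 50_000)               # (batch, vocab), illustrative sizes
labels = torch.randint(0, 50_000, (32,))

# Unfused sketch: the full (batch, vocab) log-probability tensor is materialized
log_probs = F.log_softmax(logits, dim=-1)
loss_manual = -log_probs[torch.arange(32), labels].mean()

# Library call: can dispatch to a fused kernel, so no probability tensor
# needs to live in HBM between the softmax and the reduction
loss_fused = F.cross_entropy(logits, labels)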
31.3.3 Moderate Benefit: LayerNorm and Softmax
x = F.layer_norm(x, normalized_shape)

LayerNorm computes a mean, a variance, the normalization, and an affine transform. Fusing keeps the statistics in registers.
31.3.4 Low Benefit: Large Matrix Multiplies
y = x @ W

Matrix multiply already has high arithmetic intensity. It's compute-bound, not memory-bound. Fusion helps the operations around the matmul, not the matmul itself.
31.4 The Fusion Hierarchy
Fusion can happen at different levels:
31.4.1 Level 1: Manual Fused Kernels
Hand-written CUDA kernels combining specific operations:
__global__ void fused_linear_relu(float* output, const float* input,
                                  const float* weight, const float* bias,
                                  int n, int d) {
    // One thread per output feature: dot product, bias, and ReLU in one pass
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = 0.0f;
        for (int j = 0; j < d; j++) {
            val += input[j] * weight[idx * d + j];
        }
        val += bias[idx];
        output[idx] = val > 0.0f ? val : 0.0f;  // ReLU fused in
    }
}
Pros: Maximum control, maximum performance. Cons: Combinatorial explosion of kernel variants. Can’t fuse arbitrary combinations.
31.4.2 Level 2: Pattern-Based Fusion
Frameworks recognize common patterns and use pre-fused kernels:
# PyTorch recognizes this pattern
x = F.linear(x, weight, bias)
x = F.relu(x)
# May be fused automatically

Pros: Transparent to the user, no code changes. Cons: Limited to recognized patterns.
31.4.3 Level 3: JIT Compilation (torch.compile)
Compilers analyze the computation graph and fuse automatically:
@torch.compile
def layer(x):
    x = F.linear(x, weight, bias)
    x = F.relu(x)
    x = F.dropout(x, p=0.1)
    return x

The compiler traces execution, builds a graph, and generates fused kernels.
Pros: Works with arbitrary code, finds fusion opportunities automatically. Cons: Compilation overhead, potential for unexpected behavior.
31.5 torch.compile Deep Dive
PyTorch 2.0 introduced torch.compile, a JIT compiler that automatically applies fusion and other optimizations.
31.5.1 How It Works
@torch.compile
def my_function(x, y):
    z = x + y
    return torch.relu(z)

- Tracing: TorchDynamo captures the Python bytecode and extracts a computation graph
- Graph capture: Operations are recorded into an FX graph
- Optimization: TorchInductor applies fusion and generates Triton kernels
- Code generation: Optimized kernels are compiled and cached
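To get a feel for what a captured graph looks like, you can trace the same function with torch.fx directly. This is a simplified stand-in for TorchDynamo's bytecode-level capture, used here only for illustration:
import torch
import torch.fx

def my_function(x, y):
    z = x + y
    return torch.relu(z)

# symbolic_trace builds an FX graph: placeholders for x and y,
# an add node, a relu node, and an output node
gm = torch.fx.symbolic_trace(my_function)
print(gm.graph)   # the node-level view of the captured computation
print(gm.code)    # the Python that the graph regenerates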
31.5.2 Fusion in Action
import torch
def unfused(x):
    x = x + 1
    x = x * 2
    x = torch.relu(x)
    return x
fused = torch.compile(unfused)
x = torch.randn(10000, device='cuda')
# First call: compilation happens
y = fused(x)
# Subsequent calls: use cached fused kernel
y = fused(x)  # Fast

31.5.3 What Gets Fused?
torch.compile typically fuses:
- Chains of element-wise operations
- Normalization layers (LayerNorm, BatchNorm, RMSNorm)
- Softmax and cross-entropy
- Pointwise operations following matmul (bias + activation)
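For instance, in a linear layer with a bias and a GELU, the matmul stays a library call while the pointwise tail is a fusion candidate. This is a sketch; linear_gelu is an illustrative name and the actual fusion decisions depend on the TorchInductor version and hardware:
import torch
import torch.nn.functional as F

weight = torch.randn(4096, 4096, device='cuda')
bias = torch.randn(4096, device='cuda')

@torch.compile
def linear_gelu(x):
    y = x @ weight.T      # stays a matmul kernel (cuBLAS or a Triton template)
    y = y + bias          # pointwise epilogue: candidate for fusion...
    return F.gelu(y)      # ...together with the activation

x = torch.randn(512, 4096, device='cuda')
out = linear_gelu(x)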
31.5.4 What Doesn’t Fuse?
- Large matrix multiplies (already efficient; they dispatch to cuBLAS/cuDNN)
- Operations with data-dependent control flow
- Custom CUDA extensions (unless wrapped appropriately)
31.6 The Limits of Fusion
Fusion isn’t always beneficial.
31.6.1 Memory Limits
Fused kernels need registers and shared memory for intermediates. Very long fusion chains may spill to local memory, negating benefits.
31.6.2 Register Pressure
Each additional operation in a fused kernel requires registers. At some point, the GPU runs out:
# This might not fuse well due to register pressure
def many_ops(x):
    for i in range(20):
        x = x * 2 + 1
        x = torch.sin(x) + torch.cos(x)
    return x

31.6.3 Compilation Overhead
torch.compile has upfront cost:
- First call: 100ms to several seconds
- Cached calls: near-zero overhead
For short-running programs or highly dynamic code, compilation may not pay off.
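A quick way to see the cost is to time the first call against a later one (a rough sketch; absolute numbers vary widely with the function and hardware):
import time
import torch

@torch.compile
def f(x):
    return torch.relu(x * 2 + 1)

x = torch.randn(4096, 4096, device='cuda')

t0 = time.perf_counter()
f(x)                      # first call: pays for compilation
torch.cuda.synchronize()
print(f"first call:  {time.perf_counter() - t0:.3f} s")

t0 = time.perf_counter()
f(x)                      # later calls: use the cached fused kernel
torch.cuda.synchronize()
print(f"second call: {time.perf_counter() - t0:.6f} s")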
31.6.4 Dynamic Shapes
Varying tensor shapes can trigger recompilation:
@torch.compile
def process(x):
    return x.sum(dim=-1)

# Each new shape may trigger recompilation
process(torch.randn(100, 50, device='cuda'))
process(torch.randn(200, 100, device='cuda'))  # Recompile?

Use dynamic=True or padding to mitigate.
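A minimal sketch of the dynamic=True route (process_dyn is an illustrative name; padding inputs to a fixed shape is the alternative):
import torch

# Declare shapes dynamic up front so one compiled artifact can cover many sizes,
# trading some per-shape specialization for fewer recompilations
@torch.compile(dynamic=True)
def process_dyn(x):
    return x.sum(dim=-1)

process_dyn(torch.randn(100, 50, device='cuda'))
process_dyn(torch.randn(200, 100, device='cuda'))   # intended to reuse the same kernel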
31.7 Fusion and the Memory Hierarchy
Fusion is fundamentally about the memory hierarchy. Let’s trace an example:
31.7.1 Without Fusion
def unfused_attention_output(attn_weights, v, residual):
    # Kernel 1: matmul
    attn_out = attn_weights @ v       # Read attn_weights, v; write attn_out to HBM
    # Kernel 2: residual add
    out = attn_out + residual         # Read attn_out, residual; write out to HBM
    # Kernel 3: layer norm
    out = F.layer_norm(out, ...)      # Read out; write the normalized result to HBM
    return out

Memory traffic: attn_out is written to HBM once and read back once (two extra passes), and the intermediate out before the layer norm costs the same again.
31.7.2 With Fusion
@torch.compile
def fused_attention_output(attn_weights, v, residual):
    attn_out = attn_weights @ v
    out = attn_out + residual
    out = F.layer_norm(out, ...)
    return out

The compiler may fuse the residual add and layer norm:
Kernel 1: matmul (uses optimized cuBLAS, writes to HBM)
Kernel 2 (fused): Read matmul output + residual → add → layernorm → Write final output
One fewer round-trip.
31.8 Practical Guidance
31.8.1 When to Use torch.compile
# Good candidates for torch.compile:
# 1. Inference on fixed shapes
@torch.compile(mode="reduce-overhead")
def inference(x):
    return model(x)
# 2. Training loops with stable shapes
model = torch.compile(model)
# 3. Custom loss functions with many small ops
@torch.compile
def custom_loss(pred, target):
    # Many element-wise operations that benefit from fusion
    ...

31.8.2 When to Avoid
# Bad candidates:
# 1. Highly dynamic control flow
def dynamic_model(x):
    if x.sum() > 0:  # Data-dependent branch
        return path_a(x)
    else:
        return path_b(x)
# 2. Very short functions (compilation cost > runtime benefit)
@torch.compile
def tiny(x):
    return x + 1  # Probably not worth it
# 3. Development/debugging (compilation makes errors harder to trace)

31.8.3 Measuring Fusion Impact
import torch
import time
def benchmark(fn, x, n_warmup=10, n_iter=100):
    for _ in range(n_warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iter * 1000  # ms
x = torch.randn(4096, 4096, device='cuda')
def chain(x):
    x = x + 1
    x = x * 2
    x = torch.relu(x)
    x = x ** 2
    return x
compiled = torch.compile(chain)
print(f"Eager: {benchmark(chain, x):.3f} ms")
print(f"Compiled: {benchmark(compiled, x):.3f} ms")
# Typical result: 2-4× speedup

31.9 Fusion in FlashAttention
FlashAttention is the ultimate fusion example. It fuses the entire attention computation:
Standard:
S = Q @ K.T → Write S to HBM (n² elements)
P = softmax(S) → Write P to HBM (n² elements)
O = P @ V → Write O to HBM (n×d elements)
FlashAttention:
[Fused kernel]: Read Q, K, V → Compute everything in SRAM → Write O only
The “fusion” here is extreme: all intermediates stay in SRAM, never touching HBM. This is why FlashAttention is faster despite doing more compute (the streaming softmax corrections).
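A back-of-envelope sense of scale, assuming fp16, a single head, n = 4096, and d = 64 (illustrative numbers, not a measurement):
n, d = 4096, 64              # sequence length, head dimension
bytes_per = 2                # fp16

s_and_p = 2 * n * n * bytes_per        # S and P: one n×n write each (plus read-backs)
qkv_o = 4 * n * d * bytes_per          # Q, K, V read once, O written once
print(f"n^2 intermediates: {s_and_p / 2**20:.0f} MiB")   # ~64 MiB
print(f"Q/K/V/O traffic:   {qkv_o / 2**20:.1f} MiB")     # ~2.0 MiB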
31.10 Connections
Chapter 2 (Bandwidth): Fusion reduces memory traffic, addressing the bandwidth bottleneck.
Chapter 7 (Locality): Fusion improves temporal locality—data is reused before eviction from registers/cache.
Chapter 10 (FlashAttention): The ultimate example of fusion principles applied to attention.
31.11 Key Takeaways
Fusion eliminates memory round-trips: Every intermediate written to HBM and read back is a cost that fusion can avoid.
Low arithmetic intensity operations benefit most: Element-wise ops, small reductions, normalization layers.
torch.compile automates fusion: For most code, @torch.compile is the easiest way to get fusion benefits.
Manual fusion still matters: For maximum performance (like FlashAttention), hand-written fused kernels remain important.
Fusion has limits: Register pressure, compilation overhead, and dynamic shapes all constrain what can be fused.