32 torch.compile Deep Dive
Understanding PyTorch’s JIT Compiler
One line of code. Up to 2x speedup. No model changes.
torch.compile is PyTorch's answer to the framework-overhead problem: a JIT compiler built into the library. Understanding how it works helps you get the most out of it.
32.1 The Compilation Stack
When you call torch.compile(model), PyTorch invokes a sophisticated compilation pipeline:
Your PyTorch Code
↓
TorchDynamo (Graph Capture)
↓
AOTAutograd (Backward Graph)
↓
TorchInductor (Code Generation)
↓
Triton/C++/CUDA Kernels
Each stage transforms your code into increasingly optimized representations.
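In practice the whole pipeline hides behind a single call. A minimal sketch: the first call to the compiled function runs every stage, and the TORCH_LOGS environment variable (an optional debugging aid whose artifact names vary slightly across releases) lets you watch them:
import torch

def f(x):
    return torch.relu(x * 2 + 1)

compiled_f = torch.compile(f)

x = torch.randn(1024)
compiled_f(x)  # 1st call: Dynamo capture → AOTAutograd → Inductor codegen
compiled_f(x)  # later calls reuse the cached compiled kernels

# To watch the stages, run the script with e.g. TORCH_LOGS="dynamo,aot,inductor"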
32.2 TorchDynamo: Graph Capture
32.2.1 The Problem with Eager Mode
PyTorch’s eager execution is flexible but slow:
def eager_forward(x, weight):
    # Each line is a separate kernel launch
    y = x @ weight       # Kernel 1
    y = y + bias         # Kernel 2
    y = torch.relu(y)    # Kernel 3
    y = y * scale        # Kernel 4
    return y             # 4 kernel launches, 4 memory round-trips
Every operation:
1. Launches a CUDA kernel
2. Writes its output to global memory
3. Returns to Python
Overhead: Kernel launch latency + memory bandwidth + Python interpreter.
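A rough way to see this overhead is to time the same chain of element-wise ops in eager mode and under torch.compile. This is only a sketch; the measured speedup depends on the GPU, dtype, and tensor size:
import time
import torch

def chain(x, bias, scale):
    y = x * 2
    y = y + bias
    y = torch.relu(y)
    return y * scale

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1 << 20, device=device)
bias = torch.randn(1 << 20, device=device)

compiled_chain = torch.compile(chain)
compiled_chain(x, bias, 0.5)  # warm-up call triggers compilation

def bench(fn):
    # Time 100 iterations, synchronizing around the loop on GPU
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        fn(x, bias, 0.5)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

print("eager:   ", bench(chain))
print("compiled:", bench(compiled_chain))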
32.2.2 How Dynamo Works
TorchDynamo traces Python bytecode to capture operations:
import torch
import torch._dynamo as dynamo

def my_function(x):
    y = x * 2
    if y.sum() > 0:  # Dynamic control flow!
        y = y + 1
    return y

# Dynamo traces through Python bytecode
explained = dynamo.explain(my_function, torch.randn(10))
print(explained)
# Output shows:
# - Which operations were captured as a graph
# - Where graph breaks occurred
# - Why the breaks happened
32.2.3 Graph Breaks
Some Python operations can’t be captured:
def problematic(x):
    y = x * 2                    # ✓ Captured
    # Graph break: Python print
    print(f"Shape: {y.shape}")
    z = y + 1                    # ✓ New graph starts
    return z
Common graph break causes:
1. print() statements
2. Python data structures (lists, dicts with tensor values)
3. Unsupported operations
4. Dynamic control flow that can't be specialized
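If you would rather have breaks fail loudly than silently split the function, fullgraph=True (revisited in the performance tips later in this chapter) turns any break into an error. A sketch using the problematic function above:
import torch

strict = torch.compile(problematic, fullgraph=True)
strict(torch.randn(8))  # raises an error pointing at the print() call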
Debugging breaks:
import torch._dynamo as dynamo
# See all graph breaks
dynamo.config.verbose = True
# Or use explain()
explanation = dynamo.explain(model, sample_input)
print(explanation)
32.2.4 Specialization and Guards
Dynamo creates specialized graphs with “guards”:
def func(x, multiplier):
    return x * multiplier

compiled_func = torch.compile(func)

# First call: captures a graph assuming multiplier == 2
y = compiled_func(torch.randn(10), 2)
# Guard: "multiplier == 2"

# If the guard fails, recompile
y = compiled_func(torch.randn(10), 3)  # Guard fails → recompile
Shape specialization:
# Default: specialize on exact shapes
model = torch.compile(model)
model(torch.randn(32, 128)) # Compiles for shape [32, 128]
model(torch.randn(64, 128)) # Recompiles for [64, 128]
# Dynamic shapes: more flexible
model = torch.compile(model, dynamic=True)
model(torch.randn(32, 128)) # Compiles with symbolic shapes
model(torch.randn(64, 128))  # Reuses the same compiled code
32.3 AOTAutograd: Backward Graph Capture
32.3.1 Ahead-of-Time Autograd
Standard PyTorch builds the backward graph dynamically while the forward pass runs. AOTAutograd captures both graphs ahead of time:
# Conceptually:
def aot_transform(forward_fn):
    # Trace forward
    forward_graph = trace(forward_fn)
    # Generate backward graph from forward
    backward_graph = generate_backward(forward_graph)
    # Compile both
    compiled_forward = compile(forward_graph)
    compiled_backward = compile(backward_graph)
    return compiled_forward, compiled_backward
Benefits:
1. Backward graph is known at compile time
2. Forward and backward can be optimized jointly
3. No Python overhead during the backward pass
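You can poke at this layer directly through the semi-internal functorch.compile.aot_function API. The sketch below just prints the forward and backward FX graphs that AOTAutograd produces; it assumes that API is available in your build:
import torch
from functorch.compile import aot_function

def fn(a, b):
    return (a * b + b).sum()

def print_graph(fx_module, example_inputs):
    # Called once for the forward graph and once for the backward graph;
    # returning the GraphModule means "run it unmodified"
    print(fx_module.code)
    return fx_module

aot_fn = aot_function(fn, fw_compiler=print_graph, bw_compiler=print_graph)
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
aot_fn(a, b).backward()  # triggers both compilers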
32.3.2 Joint Graph Optimization
With both graphs, the compiler can:
# Forward saves an intermediate for backward
def forward(x, weight):
    y = x @ weight
    save_for_backward(y)  # Usually explicit
    return relu(y)

# AOTAutograd can optimize:
# - Recompute cheap operations instead of saving them
# - Fuse forward and backward operations
# - Eliminate unnecessary saves
32.4 TorchInductor: Code Generation
32.4.1 From Graph to Kernels
Inductor transforms the captured graph into optimized code:
# Input graph:
#   x → mul(2) → add(bias) → relu → output

# Inductor generates a single fused Triton kernel (simplified):
import triton
import triton.language as tl

@triton.jit
def fused_kernel(x_ptr, bias_ptr, out_ptr, numel, BLOCK: tl.constexpr):
    idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = idx < numel
    x = tl.load(x_ptr + idx, mask=mask)
    x = x * 2
    x = x + tl.load(bias_ptr + idx, mask=mask)
    x = tl.maximum(x, 0)  # ReLU
    tl.store(out_ptr + idx, x, mask=mask)
Fusion: Multiple operations → single kernel → one memory round-trip.
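A simple way to confirm that fusion actually happened is to count kernel launches with the PyTorch profiler. The sketch below assumes a CUDA device; the exact kernel names and counts vary by version and hardware:
import torch
from torch.profiler import profile, ProfilerActivity

def chain(x, bias):
    return torch.relu(x * 2 + bias)

x = torch.randn(1 << 20, device="cuda")
bias = torch.randn(1 << 20, device="cuda")

compiled_chain = torch.compile(chain)
compiled_chain(x, bias)  # compile outside the profiled region

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    chain(x, bias)           # eager: several element-wise kernels
    compiled_chain(x, bias)  # compiled: typically one fused Triton kernel

print(prof.key_averages().table(sort_by="cuda_time_total"))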
32.4.2 Inductor Optimizations
1. Operator Fusion
# Before: 4 kernels
y = x * 2
y = y + bias
y = torch.relu(y)
y = y * scale
# After: 1 fused kernel
# All operations in registers, single memory write2. Memory Format Optimization
# Inductor may reorder memory layout
# NCHW → NHWC for better memory access on convolutions
3. Automatic Tuning
# Inductor tries multiple implementations
# Chooses fastest based on autotuning
# Example: matrix multiply
# Option A: Triton kernel
# Option B: cuBLAS
# Option C: Custom tiled implementation
# → Benchmark and select the best
32.4.3 Viewing Generated Code
import torch._inductor.config as inductor_config
# Save generated code
inductor_config.debug = True
model = torch.compile(model)
output = model(input)
# Generated code saved to /tmp/torchinductor_*/
# Includes Triton kernels, C++ code, etc.
32.5 Compilation Modes
32.5.1 Mode Comparison
# Default: balance between compile time and performance
model = torch.compile(model)
# Reduce-overhead: minimize Python overhead
model = torch.compile(model, mode="reduce-overhead")
# Uses CUDA graphs for even less overhead
# Best for: inference, fixed shapes
# Max-autotune: maximum optimization
model = torch.compile(model, mode="max-autotune")
# Spends more time tuning
# Best for: training workloads run many times
# Max-autotune-no-cudagraphs
model = torch.compile(model, mode="max-autotune-no-cudagraphs")
# Maximum optimization without CUDA graphs
# Best for: dynamic shapes or memory constraints
32.5.2 Backend Selection
# Default backend: Inductor
model = torch.compile(model, backend="inductor")
# For debugging: eager backend (no optimization)
model = torch.compile(model, backend="eager")
# TensorRT integration (requires the torch-tensorrt package)
model = torch.compile(model, backend="tensorrt")

# Custom backend
def my_backend(gm, example_inputs):
    # gm is a torch.fx.GraphModule; inspect or transform it here,
    # then return a callable that runs the captured graph
    gm.graph.print_tabular()
    return gm.forward

model = torch.compile(model, backend=my_backend)
32.6 Practical Usage Patterns
32.6.1 Basic Compilation
import torch
model = MyModel()
# Compile the whole model
compiled_model = torch.compile(model)
# Use normally
output = compiled_model(input)
32.6.2 Selective Compilation
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Only compile the expensive part, and compile it once here
        # rather than on every forward call
        self.encoder = torch.compile(Encoder())
        self.decoder = Decoder()

    def forward(self, x):
        encoded = self.encoder(x)        # Compiled
        decoded = self.decoder(encoded)  # Not compiled
        return decoded
32.6.3 Compiling Training Loops
@torch.compile
def train_step(model, optimizer, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return loss

# First call compiles, subsequent calls are fast
for batch in dataloader:
    loss = train_step(model, optimizer, batch)
32.6.4 Handling Dynamic Shapes
# Static shapes (default): fastest, least flexible
model = torch.compile(model)
# Dynamic shapes: slower compile, handles varying shapes
model = torch.compile(model, dynamic=True)
# Mark specific dimensions of an input as dynamic
from torch._dynamo import mark_dynamic

x = torch.randn(32, 128)
mark_dynamic(x, 0)   # the first (batch) dimension may vary
compiled_model = torch.compile(model)
compiled_model(x)    # compiles with a symbolic batch dimension
32.7 Debugging Compilation
32.7.1 Common Issues
1. Graph Breaks
# Problem: Print causes graph break
def forward(self, x):
    y = self.linear(x)
    print(f"Output shape: {y.shape}")  # Graph break!
    return y

# Solution: remove (or guard) prints in production code
# Note: torch._dynamo.config.suppress_errors silences hard compile errors,
# but it does not eliminate graph breaks
2. Recompilation
# Problem: different shapes cause recompilation
for batch_size in [16, 32, 64, 128]:
    x = torch.randn(batch_size, 512)
    model(x)  # Recompiles each time!

# Solution: use dynamic=True or pad to a fixed size
model = torch.compile(model, dynamic=True)
3. Slow Compilation
# Problem: the first call takes minutes
model = torch.compile(model, mode="max-autotune")
model(x)  # Very slow first call

# Solutions:
# 1. Use the default mode for a faster compile
model = torch.compile(model)  # Faster compile, good performance

# 2. Warm up once before serving; compiled code is cached and reused
#    for later calls in the same process
torch.save(model.state_dict(), "compiled_model.pt")
# Note: state_dict() saves only the weights, not the compiled kernels
32.7.2 Diagnostic Tools
# Full explanation of what happened
import torch._dynamo as dynamo
explanation = dynamo.explain(model, sample_input)
print(explanation)
# Instrument generated kernels to report achieved memory bandwidth
torch._inductor.config.profile_bandwidth = True

# Check for recompilations
torch._dynamo.config.cache_size_limit = 8  # Default
# If you see "cache size limit reached", you have a recompilation problem
32.8 Performance Tips
32.8.1 Maximizing Speedup
# 1. Avoid graph breaks
# Bad:
def forward(self, x):
    for i, layer in enumerate(self.layers):
        x = layer(x)
        if i == 5:
            print(x.mean())  # Graph break
    return x

# Good:
def forward(self, x):
    for layer in self.layers:
        x = layer(x)
    return x  # No breaks

# 2. Use consistent shapes
# Bad: variable sequence lengths
for seq_len in [100, 200, 300]:
    x = torch.randn(batch, seq_len, hidden)
    model(x)  # Recompiles for each length

# Good: pad every batch to a fixed max length (and mask the padding)
# so the compiled graph always sees the same shape
x = torch.randn(batch, max_seq_len, hidden)
model(x)  # One shape → no recompilation

# 3. Fullgraph mode for maximum fusion
model = torch.compile(model, fullgraph=True)
# Errors if any graph break occurs, ensuring full optimization
32.8.2 When torch.compile Helps Most
High speedup scenarios:
✓ Many small operations (element-wise fusion)
✓ Custom Python code (reduces interpreter overhead)
✓ Memory-bound operations (fusion reduces memory traffic)
Lower speedup scenarios:
✗ Mostly large matmuls (already optimized by cuBLAS)
✗ Heavy I/O or data loading (not compute-bound)
✗ Highly dynamic control flow
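Whichever category your model falls into, it is worth checking how much wall-clock time goes into compilation itself versus execution. A sketch using a semi-internal helper, assuming torch._dynamo.utils.compile_times is available in your build; model and sample_input stand in for your own module and data:
import torch
import torch._dynamo.utils as dynamo_utils

compiled_model = torch.compile(model)          # model: your own nn.Module
compiled_model(sample_input)                   # first call triggers compilation

# Per-phase breakdown of time spent compiling so far
print(dynamo_utils.compile_times())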
32.9 Integration with Other Tools
32.9.1 With Distributed Training
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
model = MyModel()
model = torch.compile(model) # Compile first
model = DDP(model)  # Then wrap with DDP
32.9.2 With Mixed Precision
from torch.cuda.amp import autocast
model = torch.compile(model)
with autocast():
output = model(input)  # Works with AMP
32.9.3 With Activation Checkpointing
from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def forward(self, x):
        # Checkpointing works with compile
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = self.block2(x)
        return x

model = torch.compile(Model())
32.10 The Future: torch.export
32.10.1 Ahead-of-Time Export
# New in PyTorch 2.1+
import torch.export
# Export to a standalone artifact
exported = torch.export.export(model, (sample_input,))
# Can be serialized
torch.export.save(exported, "model.pt2")
# Load and run without the original model code
loaded = torch.export.load("model.pt2")
output = loaded(input)
Use cases:
- Mobile deployment
- Serverless inference
- Edge devices
- Non-Python runtimes
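The exported program carries the captured FX graph with it, so printing it is a quick sanity check before deploying. A minimal sketch using the exported object from above (graph_module as the attribute name is an assumption about current releases):
print(exported)                    # the captured graph plus its input/output signature
print(exported.graph_module.code)  # generated Python code for the graph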
32.11 Key Takeaways
- One line, real speedup: torch.compile(model) often gives a 1.5-2x speedup.
- Understand the stack: Dynamo captures the graph, AOTAutograd handles backward, Inductor generates code.
- Avoid graph breaks: print statements, unsupported ops, and dynamic control flow cause breaks.
- Use dynamic shapes wisely: static shapes are faster but less flexible.
- Mode matters: reduce-overhead for inference, max-autotune for training.
- Debug with explain(): understand what's happening under the hood.
- Compile first, distribute second: apply torch.compile before DDP/FSDP.
32.12 Further Reading
- PyTorch 2.0 Introduction
- TorchDynamo Deep Dive
- TorchInductor Design
- Ansel et al. (2024). “PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation”