32  torch.compile Deep Dive

Understanding PyTorch’s JIT Compiler


One line of code. Up to 2x speedup. No model changes.

torch.compile is PyTorch’s answer to framework performance. Understanding how it works helps you get the most from it.

32.1 The Compilation Stack

When you call torch.compile(model), PyTorch invokes a sophisticated compilation pipeline:

Your PyTorch Code
       ↓
  TorchDynamo (Graph Capture)
       ↓
  AOTAutograd (Backward Graph)
       ↓
  TorchInductor (Code Generation)
       ↓
  Triton/C++/CUDA Kernels

Each stage transforms your code into increasingly optimized representations.
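
If you want to watch these stages do their work on a real call, recent PyTorch releases expose logging hooks (the TORCH_LOGS environment variable, or torch._logging.set_logs from Python). A small sketch, assuming PyTorch 2.2 or newer; the lambda is just a stand-in workload:

import torch

# Print the FX graph Dynamo captured and the Triton/C++ code Inductor
# generated. Roughly equivalent to running with TORCH_LOGS="graph_code,output_code".
torch._logging.set_logs(graph_code=True, output_code=True)

compiled = torch.compile(lambda x: torch.relu(x * 2) + 1)
compiled(torch.randn(8))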

32.2 TorchDynamo: Graph Capture

32.2.1 The Problem with Eager Mode

PyTorch’s eager execution is flexible but slow:

def eager_forward(x, weight, bias, scale):
    # Each line is a separate kernel launch
    y = x @ weight        # Kernel 1
    y = y + bias          # Kernel 2
    y = torch.relu(y)     # Kernel 3
    y = y * scale         # Kernel 4
    return y              # 4 kernel launches, 4 memory round-trips

Every operation:

1. Launches a CUDA kernel
2. Writes its output to global memory
3. Returns control to Python

Overhead: Kernel launch latency + memory bandwidth + Python interpreter.
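
A minimal timing sketch for this overhead argument, assuming a CUDA GPU; torch.utils.benchmark handles warm-up and synchronization, and the exact speedup depends heavily on your GPU and tensor sizes:

import torch
import torch.utils.benchmark as benchmark

def eager_forward(x, weight, bias, scale):
    return torch.relu(x @ weight + bias) * scale

compiled_forward = torch.compile(eager_forward)

x = torch.randn(1024, 1024, device="cuda")
weight = torch.randn(1024, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")
scale = torch.randn(1024, device="cuda")

compiled_forward(x, weight, bias, scale)  # warm-up: triggers compilation

for label, fn in [("eager", eager_forward), ("compiled", compiled_forward)]:
    t = benchmark.Timer(stmt="fn(x, weight, bias, scale)",
                        globals={"fn": fn, "x": x, "weight": weight,
                                 "bias": bias, "scale": scale})
    print(label, t.timeit(100))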

32.2.2 How Dynamo Works

TorchDynamo traces Python bytecode to capture operations:

import torch._dynamo as dynamo

def my_function(x):
    y = x * 2
    if y.sum() > 0:  # Dynamic control flow!
        y = y + 1
    return y

# Dynamo traces through Python bytecode
explained = dynamo.explain(my_function, torch.randn(10))
print(explained)

# Output shows:
# - Which operations were captured as a graph
# - Where graph breaks occurred
# - Why breaks happened

32.2.3 Graph Breaks

Some Python operations can’t be captured:

def problematic(x):
    y = x * 2  # ✓ Captured

    # Graph break: Python print
    print(f"Shape: {y.shape}")

    z = y + 1  # ✓ New graph starts
    return z

Common graph break causes:

1. print() statements
2. Python data structures (lists, dicts with tensor values)
3. Unsupported operations
4. Dynamic control flow that can't be specialized

Debugging breaks:

import torch._dynamo as dynamo

# See all graph breaks
dynamo.config.verbose = True

# Or use explain()
explanation = dynamo.explain(model, sample_input)
print(explanation)

32.2.4 Specialization and Guards

Dynamo creates specialized graphs with “guards”:

def func(x, multiplier):
    return x * multiplier

compiled_func = torch.compile(func)

# First call: captures a graph assuming multiplier=2
y = compiled_func(torch.randn(10), 2)

# Guard: "multiplier == 2"
# If the guard fails, Dynamo recompiles

y = compiled_func(torch.randn(10), 3)  # Guard fails → recompile

Shape specialization:

# Default: specialize on exact shapes
model = torch.compile(model)
model(torch.randn(32, 128))  # Compiles for shape [32, 128]
model(torch.randn(64, 128))  # Recompiles for [64, 128]

# Dynamic shapes: more flexible
model = torch.compile(model, dynamic=True)
model(torch.randn(32, 128))  # Compiles with symbolic shapes
model(torch.randn(64, 128))  # Reuses same compiled code
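
To see which guard failed when a recompile happens, the "recompiles" logging artifact prints the reason. A small sketch, assuming PyTorch 2.2+ where torch._logging.set_logs is available:

import torch

torch._logging.set_logs(recompiles=True)   # log the reason for each recompile

@torch.compile
def scale(x, multiplier):
    return x * multiplier

scale(torch.randn(10), 2)   # first call: compile, guard on multiplier == 2
scale(torch.randn(10), 3)   # guard fails: the log explains the recompile
scale(torch.randn(20), 3)   # new shape: may recompile or switch to symbolic shapes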

32.3 AOTAutograd: Backward Graph Capture

32.3.1 Ahead-of-Time Autograd

Standard PyTorch builds the backward graph during forward. AOTAutograd captures both:

# Conceptually:
def aot_transform(forward_fn):
    # Trace forward
    forward_graph = trace(forward_fn)

    # Generate backward graph from forward
    backward_graph = generate_backward(forward_graph)

    # Compile both
    compiled_forward = compile(forward_graph)
    compiled_backward = compile(backward_graph)

    return compiled_forward, compiled_backward

Benefits:

1. The backward graph is known at compile time
2. Forward and backward can be optimized jointly
3. No Python overhead during the backward pass
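
To see AOTAutograd at work without Inductor in the picture, the built-in "aot_eager" backend runs graph capture and the ahead-of-time forward/backward transformation but executes the result eagerly. A debugging sketch (the nn.Sequential model is just a placeholder), useful for isolating where a problem comes from rather than for performance:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
compiled = torch.compile(model, backend="aot_eager")  # Dynamo + AOTAutograd, no codegen

out = compiled(torch.randn(8, 128))
out.sum().backward()   # backward runs through the AOT-generated graph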

32.3.2 Joint Graph Optimization

With both graphs, the compiler can:

# Forward saves intermediate for backward
def forward(x, weight):
    y = x @ weight
    save_for_backward(y)  # conceptual: autograd keeps y around for the backward pass
    return relu(y)

# AOTAutograd can optimize:
# - Recompute cheap operations instead of saving
# - Fuse forward and backward operations
# - Eliminate unnecessary saves

32.4 TorchInductor: Code Generation

32.4.1 From Graph to Kernels

Inductor transforms the captured graph into optimized code:

# Input graph:
# x → mul(2) → add(bias) → relu → output

# Inductor generates (roughly) a fused Triton kernel:
import triton
import triton.language as tl

@triton.jit
def fused_kernel(x_ptr, bias_ptr, out_ptr, numel, BLOCK: tl.constexpr):
    idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = idx < numel

    x = tl.load(x_ptr + idx, mask=mask)
    x = x * 2
    x = x + tl.load(bias_ptr + idx, mask=mask)
    x = tl.maximum(x, 0)  # ReLU
    tl.store(out_ptr + idx, x, mask=mask)

Fusion: Multiple operations → single kernel → one memory round-trip.
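
One way to check fusion empirically is to profile the element-wise chain before and after compiling and compare the kernels that actually ran: eager mode launches one kernel per op, while the compiled version should show a single fused Triton kernel. A sketch assuming a CUDA GPU; chain is a stand-in for the graph above:

import torch
from torch.profiler import profile, ProfilerActivity

def chain(x, bias, scale):
    return torch.relu(x * 2 + bias) * scale

compiled = torch.compile(chain)

x = torch.randn(1_000_000, device="cuda")
bias = torch.randn(1_000_000, device="cuda")
scale = torch.randn(1_000_000, device="cuda")
compiled(x, bias, scale)  # warm-up so compilation isn't profiled

for name, fn in [("eager", chain), ("compiled", compiled)]:
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        fn(x, bias, scale)
    print(f"=== {name} ===")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))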

32.4.2 Inductor Optimizations

1. Operator Fusion

# Before: 4 kernels
y = x * 2
y = y + bias
y = torch.relu(y)
y = y * scale

# After: 1 fused kernel
# All operations in registers, single memory write

2. Memory Format Optimization

# Inductor may reorder memory layout
# NCHW → NHWC for better memory access on convolutions

3. Automatic Tuning

# Inductor tries multiple implementations
# Chooses fastest based on autotuning

# Example: matrix multiply
# Option A: Triton kernel
# Option B: cuBLAS
# Option C: Custom tiled implementation
# → Benchmark and select best

32.4.3 Viewing Generated Code

import torch._inductor.config as inductor_config

# Save generated code
inductor_config.debug = True

model = torch.compile(model)
output = model(input)

# Generated code saved to /tmp/torchinductor_*/
# Includes Triton kernels, C++ code, etc.

32.5 Compilation Modes

32.5.1 Mode Comparison

# Default: balance between compile time and performance
model = torch.compile(model)

# Reduce-overhead: minimize Python overhead
model = torch.compile(model, mode="reduce-overhead")
# Uses CUDA graphs for even less overhead
# Best for: inference, fixed shapes

# Max-autotune: maximum optimization
model = torch.compile(model, mode="max-autotune")
# Spends more time tuning
# Best for: training workloads run many times

# Max-autotune-no-cudagraphs
model = torch.compile(model, mode="max-autotune-no-cudagraphs")
# Maximum optimization without CUDA graphs
# Best for: dynamic shapes or memory constraints
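
A rough harness for comparing modes on your own workload; the nn.Sequential model and input shape are placeholders, a CUDA GPU is assumed, and for a model this small the differences may be in the noise, so treat it as a sketch rather than a benchmark suite:

import time
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()
x = torch.randn(32, 512, device="cuda")

def avg_ms(fn, x, iters=50):
    for _ in range(3):              # warm-up; the first call compiles
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3

for mode in [None, "reduce-overhead", "max-autotune"]:
    torch._dynamo.reset()           # clear caches so each mode compiles fresh
    compiled = torch.compile(model, mode=mode)
    print(mode or "default", f"{avg_ms(compiled, x):.3f} ms")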

32.5.2 Backend Selection

# Default backend: Inductor
model = torch.compile(model, backend="inductor")

# For debugging: eager backend (no optimization)
model = torch.compile(model, backend="eager")

# TensorRT integration (requires the torch-tensorrt package, which registers this backend)
model = torch.compile(model, backend="tensorrt")

# Custom backend
def my_backend(gm, example_inputs):
    # gm is a torch.fx.GraphModule; example_inputs are sample tensors
    # Return any callable that accepts the same inputs as gm.forward
    return gm.forward  # simplest valid backend: run the captured graph as-is

model = torch.compile(model, backend=my_backend)
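
A common use of a custom backend is pure inspection: print the FX graph Dynamo captured, then run it unmodified. A small debugging sketch (inspect_backend and f are illustrative names):

import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()   # one row per captured op
    return gm.forward          # execute the graph without optimization

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.relu(x * 2) + 1

f(torch.randn(8))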

32.6 Practical Usage Patterns

32.6.1 Basic Compilation

import torch

model = MyModel()

# Compile the whole model
compiled_model = torch.compile(model)

# Use normally
output = compiled_model(input)

32.6.2 Selective Compilation

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Only compile the expensive part
        self.encoder = torch.compile(Encoder())
        self.decoder = Decoder()

    def forward(self, x):
        encoded = self.encoder(x)        # Compiled
        decoded = self.decoder(encoded)  # Not compiled
        return decoded
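
The reverse approach also works: compile the whole model and explicitly exclude the part you want left alone. torch.compiler.disable skips a function during tracing (available in recent PyTorch 2.x releases; treat the exact version as an assumption). A sketch with nn.Linear layers standing in for the Encoder/Decoder above:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(128, 128)   # stand-in for Encoder()
        self.decoder = nn.Linear(128, 128)   # stand-in for Decoder()

    @torch.compiler.disable
    def _decode(self, x):
        # Dynamo skips this method (it becomes a graph break at the call site)
        return self.decoder(x)

    def forward(self, x):
        return self._decode(self.encoder(x))

model = torch.compile(MyModel())
model(torch.randn(4, 128))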

32.6.3 Compiling Training Loops

@torch.compile
def train_step(model, optimizer, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return loss

# First call compiles, subsequent calls are fast
for batch in dataloader:
    loss = train_step(model, optimizer, batch)

32.6.4 Handling Dynamic Shapes

# Static shapes (default): fastest, least flexible
model = torch.compile(model)

# Dynamic shapes: slower compile, handles varying shapes
model = torch.compile(model, dynamic=True)

# Mark specific dimensions as dynamic
from torch._dynamo import mark_dynamic

def forward(self, x):
    # First dimension (batch) is dynamic
    mark_dynamic(x, 0)
    return self.layers(x)

32.7 Debugging Compilation

32.7.1 Common Issues

1. Graph Breaks

# Problem: Print causes graph break
def forward(self, x):
    y = self.linear(x)
    print(f"Output shape: {y.shape}")  # Graph break!
    return y

# Solution: remove prints from compiled code paths, or move logging outside
# the compiled function. The break itself is harmless, but it splits the
# graph and limits fusion.

2. Recompilation

# Problem: Different shapes cause recompilation
for batch_size in [16, 32, 64, 128]:
    x = torch.randn(batch_size, 512)
    model(x)  # Recompiles each time!

# Solution: Use dynamic=True or pad to fixed size
model = torch.compile(model, dynamic=True)

3. Slow Compilation

# Problem: First call takes minutes
model = torch.compile(model, mode="max-autotune")
model(x)  # Very slow first call

# Solutions:
# 1. Use default mode for faster compile
model = torch.compile(model)  # Faster compile, good performance

# 2. Rely on the compilation cache: Inductor caches compiled kernels on
#    disk, so later runs of the same program start up faster.
#    Note: torch.save(model.state_dict(), ...) stores weights only,
#    not the compiled code.

32.7.2 Diagnostic Tools

# Full explanation of what happened
import torch._dynamo as dynamo
explanation = dynamo.explain(model, sample_input)
print(explanation)

# Compilation time breakdown (per compiled function and phase)
print(torch._dynamo.utils.compile_times())

# Check for recompilations
torch._dynamo.config.cache_size_limit = 8  # Default
# Hitting the cache size limit means you have a recompilation problem
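
When experimenting with these knobs, it helps to clear Dynamo's caches between runs so you always measure a fresh compile; model here refers to whatever module you are currently testing:

import torch

torch._dynamo.reset()            # drop all compiled code and guards
model = torch.compile(model)     # the next call compiles from scratch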

32.8 Performance Tips

32.8.1 Maximizing Speedup

# 1. Avoid graph breaks
# Bad:
def forward(self, x):
    for i, layer in enumerate(self.layers):
        x = layer(x)
        if i == 5:
            print(x.mean())  # Graph break
    return x

# Good:
def forward(self, x):
    for layer in self.layers:
        x = layer(x)
    return x  # No breaks

# 2. Use consistent shapes
# Bad: Variable sequence lengths
for seq_len in [100, 200, 300]:
    x = torch.randn(batch, seq_len, hidden)
    model(x)  # Recompiles

# Good: Pad every sequence to max_seq_len (and mask out the padding)
x = torch.randn(batch, max_seq_len, hidden)
model(x)  # One shape → one compiled graph

# 3. Fullgraph mode for maximum fusion
model = torch.compile(model, fullgraph=True)
# Errors if any graph breaks—ensures full optimization

32.8.2 When torch.compile Helps Most

High speedup scenarios:
✓ Many small operations (element-wise fusion)
✓ Custom Python code (reduces interpreter overhead)
✓ Memory-bound operations (fusion reduces memory traffic)

Lower speedup scenarios:
✗ Mostly large matmuls (already optimized by cuBLAS)
✗ Heavy I/O or data loading (not compute-bound)
✗ Highly dynamic control flow

32.9 Integration with Other Tools

32.9.1 With Distributed Training

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel()
model = torch.compile(model)  # Compile first
model = DDP(model)  # Then wrap with DDP

32.9.2 With Mixed Precision

model = torch.compile(model)

# torch.autocast is the current AMP context manager
# (torch.cuda.amp.autocast is deprecated)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model(input)  # Works with AMP

32.9.3 With Activation Checkpointing

from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def forward(self, x):
        # Checkpointing works with compile
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = self.block2(x)
        return x

model = torch.compile(Model())

32.10 The Future: torch.export

32.10.1 Ahead-of-Time Export

# New in PyTorch 2.1+
import torch.export

# Export to a standalone artifact
exported = torch.export.export(model, (sample_input,))

# Can be serialized
torch.export.save(exported, "model.pt2")

# Load the artifact elsewhere; it no longer depends on the model's Python source
loaded = torch.export.load("model.pt2")
output = loaded.module()(input)   # .module() returns a runnable nn.Module

Use cases:

  • Mobile deployment
  • Serverless inference
  • Edge devices
  • Non-Python runtimes

32.11 Key Takeaways

  1. One line, real speedup: torch.compile(model) often gives 1.5-2x speedup.

  2. Understand the stack: Dynamo captures, AOTAutograd handles backward, Inductor generates code.

  3. Avoid graph breaks: Print statements, unsupported ops, and dynamic control flow cause breaks.

  4. Use dynamic shapes wisely: Static shapes are faster but less flexible.

  5. Mode matters: reduce-overhead for inference, max-autotune for training.

  6. Debug with explain(): Understand what’s happening under the hood.

  7. Compile first, distribute second: Apply torch.compile before DDP/FSDP.

Note: Try It Yourself

The accompanying notebook lets you:

  • Measure speedup on different model types
  • Debug graph breaks with explain()
  • Compare compilation modes
  • View generated Triton code


32.12 Further Reading