12 Reversibility
How Invertibility Enables Memory-Efficient Training
You’re training a 96-layer transformer. Each layer stores activations for the backward pass.
That’s 96 copies of intermediate state, each the size of your batch times hidden dimension. On a 40GB A100, you run out of memory before you run out of model.
What if you could train with just one copy—and reconstruct the other 95 on demand?
The connection between reversibility and efficiency traces back to thermodynamics. Landauer's principle states that irreversible computation (operations that discard information) must dissipate energy; reversible computation, where every step can be undone, can in principle approach the thermodynamic limit.
The same principle applies to memory: if you can undo a computation, you don’t need to remember its inputs. RevNets and reversible transformers exploit this insight, trading compute for memory. The physics is the same; the resource has changed from energy to bytes.
12.1 The Property That Enables Forgetting
Some functions can be run backward:
\[y = f(x) \implies x = f^{-1}(y)\]
This is invertibility. Given the output, you can recover the input.
A special case is the involution—a function that is its own inverse:
\[f(f(x)) = x\]
XOR with a fixed key is an involution: a ^ b ^ b = a. Apply it twice, you're back.
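A quick sanity check in Python (the values are arbitrary):

```python
a, b = 0b1011, 0b0110
assert a ^ b ^ b == a          # XOR-with-b undoes itself
assert (a ^ b) ^ (a ^ b) == 0  # anything XORed with itself vanishes
```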
Invertibility is the license to:
- Forget: Don’t store inputs if you can reconstruct them from outputs
- Recompute: Trade compute for memory during backpropagation
- Stream: Process data without buffering intermediate states
- Undo: Reverse a transformation without tracking history
Without invertibility, you must remember everything. With it, you can forget strategically.
12.2 From ResNets to RevNets
The problem with deep networks is activation storage.
Standard backpropagation requires storing activations from the forward pass:
```python
def forward_with_storage(layers, x):
    activations = [x]
    for layer in layers:
        x = layer(x)
        activations.append(x)  # Must store for the backward pass
    return x, activations
```

For L layers with activation size A, memory is O(L × A). A 96-layer transformer with hidden dimension 2048, batch size 32, and sequence length 1024 stores:

96 layers × 32 batch × 1024 tokens × 2048 hidden × 4 bytes ≈ 25 GB
Just for activations. Before counting weights, optimizer states, or gradients.
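A tiny helper (the function name is ours, purely illustrative) makes this arithmetic easy to replay for other configurations:

```python
def activation_memory_gb(layers, batch, seq_len, hidden, bytes_per_elem=4):
    """Rough lower bound: one hidden-state tensor per layer, fp32 by default."""
    return layers * batch * seq_len * hidden * bytes_per_elem / 1e9

print(activation_memory_gb(96, 32, 1024, 2048))  # ~25.8 GB, before any intermediate buffers
```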
12.2.1 The Reversible Block
RevNet (Gomez et al., 2017) introduces a simple architectural change that eliminates this scaling.
Split the input into two halves, then apply:
\[y_1 = x_1 + F(x_2)\] \[y_2 = x_2 + G(y_1)\]
where F and G are arbitrary differentiable functions (convolutions, MLPs, attention).
The key insight: these equations can be inverted:
\[x_2 = y_2 - G(y_1)\] \[x_1 = y_1 - F(x_2)\]
Given the output \((y_1, y_2)\), you can compute the input \((x_1, x_2)\).
```python
def reversible_block_forward(x1, x2, F, G):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def reversible_block_inverse(y1, y2, F, G):
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2
```

Let's verify this works:
```python
import torch
import torch.nn as nn

# Simple F and G functions
F = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
G = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

# Forward
x1, x2 = torch.randn(32, 256), torch.randn(32, 256)
y1, y2 = reversible_block_forward(x1, x2, F, G)

# Inverse
x1_reconstructed, x2_reconstructed = reversible_block_inverse(y1, y2, F, G)

# Verify
print(f"Max error x1: {(x1 - x1_reconstructed).abs().max():.2e}")  # ~1e-7
print(f"Max error x2: {(x2 - x2_reconstructed).abs().max():.2e}")  # ~1e-7
```

The reconstruction error is at floating-point precision; the values are mathematically identical.
12.2.2 Memory Reduction
With reversible blocks, you only need to store the final layer’s activations:
- Standard: O(L × A), so 96 × 25 MB = 2.4 GB per sample
- Reversible: O(A), so 25 MB per sample (independent of depth!)
The activation memory cost becomes independent of depth. You can train a 96-layer network with roughly the activation memory of a single layer.
The trade-off: During backpropagation, you must recompute the activations by running the inverse. This adds ~33-50% compute overhead. But for memory-bound training, this trade-off is often worthwhile.
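A minimal sketch of how the blocks compose (the stack function names are ours), reusing reversible_block_forward and reversible_block_inverse from above; only the current pair of tensors is ever held:

```python
def reversible_stack_forward(x1, x2, blocks):
    # blocks is a list of (F, G) pairs; no per-layer activation list is kept.
    for F, G in blocks:
        x1, x2 = reversible_block_forward(x1, x2, F, G)
    return x1, x2

def reversible_stack_inverse(y1, y2, blocks):
    # Walk the stack in reverse, reconstructing each layer's input on demand.
    for F, G in reversed(blocks):
        y1, y2 = reversible_block_inverse(y1, y2, F, G)
    return y1, y2
```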
12.3 Investigation: Why Does Inversion Work?
Let’s derive why the reversible block equations are invertible.
Starting from: \[y_1 = x_1 + F(x_2)\] \[y_2 = x_2 + G(y_1)\]
Recover \(x_2\): From the second equation, solve for \(x_2\): \[x_2 = y_2 - G(y_1)\]
This works because we know \(y_1\) and \(y_2\) from the forward pass output.
Recover \(x_1\): Substitute \(x_2\) into the first equation: \[x_1 = y_1 - F(x_2) = y_1 - F(y_2 - G(y_1))\]
The structure is crucial: each equation involves only one unknown at a time.
12.3.1 What Makes It Invertible?
The additive coupling is key. Consider if we used multiplication:
\[y_1 = x_1 \cdot F(x_2)\]
To invert: \(x_1 = y_1 / F(x_2)\). But if \(F(x_2) = 0\), we can’t recover \(x_1\). Addition has no such singularity.
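A toy demonstration of the singularity (the function and values are contrived):

```python
import torch

x1, x2 = torch.tensor(3.0), torch.tensor(2.0)
F = lambda v: torch.relu(v - 5.0)   # outputs exactly zero for v <= 5
y1 = x1 * F(x2)                     # y1 == 0, no matter what x1 was
print(y1 / F(x2))                   # 0/0 -> nan: x1 cannot be recovered
```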
More generally, the pattern is:
\[y = x + f(\text{other terms})\]
Rearranges to:
\[x = y - f(\text{other terms})\]
The function \(f\) can be arbitrarily complex (even a transformer layer). As long as the combination is additive, inversion is straightforward.
12.4 Reformer: Reversible Transformers
The Reformer (Kitaev et al., 2020) applies reversibility to transformers.
The standard transformer layer:
```python
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
```

This stores activations at each layer for backprop.
The reversible transformer layer:
```python
y1 = x1 + Attention(LayerNorm(x2))
y2 = x2 + FFN(LayerNorm(y1))
```

Same computation, but now invertible.
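Here is one way such a block might look in PyTorch; the dimensions and module choices are illustrative, not the Reformer's exact configuration:

```python
import torch
import torch.nn as nn

class ReversibleTransformerBlock(nn.Module):
    def __init__(self, dim=512, heads=8, ffn_mult=4):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_mult * dim), nn.GELU(), nn.Linear(ffn_mult * dim, dim)
        )

    def f(self, x):  # attention branch
        h = self.attn_norm(x)
        return self.attn(h, h, h, need_weights=False)[0]

    def g(self, x):  # feed-forward branch
        return self.ffn(self.ffn_norm(x))

    def forward(self, x1, x2):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def inverse(self, y1, y2):
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2
```

Calling block.inverse(*block.forward(x1, x2)) recovers (x1, x2) to floating-point precision, just like the MLP example earlier; any dropout inside F or G must be disabled (or made deterministic) for the reconstruction to match.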
Memory comparison for a 12-layer transformer:

- Standard: 12 × (attention activations + FFN activations)
- Reversible: 1 × (attention activations + FFN activations)

For sequence length 4096 and hidden size 1024:

- Standard: ~4 GB of activations
- Reversible: ~340 MB of activations (a 12× reduction)
Combined with locality-sensitive hashing attention (reducing O(n²) to O(n log n)), Reformer handles sequences up to 64K tokens where standard transformers fail.
12.4.1 The Backward Pass
The reversible backward pass doesn’t use stored activations. Instead:
```python
def reversible_backward(y1, y2, dy1, dy2, F, G):
    # Reconstruct inputs from outputs instead of reading stored activations
    with torch.no_grad():
        x2 = y2 - G(y1)
        x1 = y1 - F(x2)
    # Re-run the forward pass with autograd enabled and backprop through F and G.
    # Parameter gradients accumulate in F's and G's .grad attributes.
    x1 = x1.detach().requires_grad_()
    x2 = x2.detach().requires_grad_()
    with torch.enable_grad():
        y1_recomputed = x1 + F(x2)
        y2_recomputed = x2 + G(y1_recomputed)
        torch.autograd.backward((y1_recomputed, y2_recomputed), (dy1, dy2))
    # Return the reconstructed inputs and their gradients so the previous block can continue.
    return x1.detach(), x2.detach(), x1.grad, x2.grad
```

The activations are reconstructed on the fly, layer by layer, as backprop proceeds from output to input.
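Chaining the blocks together, a full backward sweep might look like the following sketch, using the reversible_backward above (parameter gradients accumulate in each F and G along the way):

```python
def reversible_stack_backward(y1, y2, dy1, dy2, blocks):
    # blocks is a list of (F, G) pairs, in forward order.
    for F, G in reversed(blocks):
        y1, y2, dy1, dy2 = reversible_backward(y1, y2, dy1, dy2, F, G)
    return dy1, dy2  # gradients with respect to the original (x1, x2)
```

Only one pair of activations and one pair of gradients is live at any time, which is where the O(1)-in-depth memory claim comes from.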
12.5 The Spectrum: From Checkpointing to Reversibility
Reversible networks are the extreme of a spectrum:
| Strategy | Memory | Compute Overhead | When to Use |
|---|---|---|---|
| Store all | O(L) | 0% | Memory abundant |
| Checkpoint every k | O(L/k) | (k-1)/k × 100% | Moderate memory |
| Checkpoint every √L | O(√L) | ~100% | Memory constrained |
| Reversible | O(1) | ~33-50% | Memory critical |
Gradient checkpointing (Chen et al., 2016) is the middle ground: store some activations, recompute others.
```python
from torch.utils.checkpoint import checkpoint

# Standard: store everything
y = layer3(layer2(layer1(x)))  # autograd stores x, layer1(x), layer2(x)

# Checkpointed: store only at the checkpoint boundary
y = checkpoint(lambda t: layer3(layer2(layer1(t))), x,
               use_reentrant=False)  # stores only x; recomputes the rest on backward
```

The optimal checkpoint strategy depends on the compute-memory ratio of your hardware. For modern GPUs with high FLOPS but limited HBM, aggressive checkpointing often wins.
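For the checkpoint-every-k strategy, PyTorch's checkpoint_sequential handles the segmenting; a minimal sketch, assuming a recent PyTorch and purely illustrative layer sizes:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A 64-layer stack; 8 segments ~ sqrt(64), the classic memory/compute compromise.
layers = nn.Sequential(*[nn.Sequential(nn.Linear(1024, 1024), nn.GELU()) for _ in range(64)])
x = torch.randn(8, 1024, requires_grad=True)

y = checkpoint_sequential(layers, 8, x, use_reentrant=False)
y.sum().backward()  # interior activations of each segment are recomputed here
```

Varying the segment count sweeps between the two extremes in the table above.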
12.6 Beyond Training: Normalizing Flows
Invertibility appears in generative modeling too.
Normalizing flows are generative models built from invertible transformations:
\[z \sim p(z) \quad \text{(simple prior, e.g., Gaussian)}\] \[x = f(z) \quad \text{(invertible transformation)}\]
Because \(f\) is invertible, we can compute exact likelihoods:
\[p(x) = p(z) \cdot \left|\det \frac{\partial f^{-1}}{\partial x}\right|\]
Real NVP, Glow, and other flow models use the same additive coupling idea as RevNets:
\[y_1 = x_1\] \[y_2 = x_2 \cdot \exp(s(x_1)) + t(x_1)\]
This is invertible (given \(y_1 = x_1\), solve for \(x_2\)), and the Jacobian determinant is tractable.
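A sketch of one such coupling layer (sizes are illustrative; a real flow stacks many of these and alternates which half is transformed):

```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    # Real NVP-style coupling: y1 = x1, y2 = x2 * exp(s(x1)) + t(x1)
    def __init__(self, dim=4, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * dim))

    def forward(self, x1, x2):
        s, t = self.net(x1).chunk(2, dim=-1)
        y2 = x2 * torch.exp(s) + t
        log_det = s.sum(dim=-1)  # Jacobian is triangular with diagonal exp(s)
        return x1, y2, log_det

    def inverse(self, y1, y2):
        s, t = self.net(y1).chunk(2, dim=-1)
        return y1, (y2 - t) * torch.exp(-s)
```

Summing the per-layer log-determinants across a stack of couplings gives the exact log-likelihood via the change-of-variables formula above.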
The connection: Both RevNets and normalizing flows exploit invertibility, but for different purposes:
- RevNets: Memory efficiency (reconstruct activations)
- Flows: Density estimation (compute likelihoods)
The algebra enables both.
12.7 When Invertibility Breaks
Not all operations are invertible. Recognize these cases:
ReLU: \(\text{ReLU}(x) = \max(0, x)\)
If the output is 0, the input could be any negative number. Information is lost.
Pooling: Max pooling discards which element was maximum. Average pooling loses individual values.
Stride > 1: Downsampling discards spatial information.
Attention with softmax: The normalized weights don’t preserve the raw logits.
For these operations, reversible networks use workarounds:
```python
import torch

# Instead of ReLU, use an invertible alternative such as leaky ReLU
def leaky_relu_inverse(y, negative_slope=0.01):
    return torch.where(y >= 0, y, y / negative_slope)

# Instead of max pooling, use invertible downsampling
# (space-to-depth "squeeze" operations, or pooling with stored indices)
```

The general principle: some information loss is unavoidable (that's what makes a network a compression), but structure the loss to occur at explicit points where you can store the discarded information cheaply.
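As a concrete example of lossless downsampling, PyTorch's pixel_unshuffle performs the space-to-depth squeeze, and pixel_shuffle inverts it exactly:

```python
import torch
import torch.nn.functional as nnf

x = torch.randn(1, 3, 8, 8)
# Space-to-depth: halve spatial resolution by moving pixels into channels.
down = nnf.pixel_unshuffle(x, downscale_factor=2)  # shape (1, 12, 4, 4), nothing discarded
up = nnf.pixel_shuffle(down, upscale_factor=2)     # exact inverse
assert torch.equal(x, up)
```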
12.8 The Compute-Memory Trade-off
Reversibility trades compute for memory:
┌─────────────────────────────────────────────────────────────────┐
│ │
│ Memory │
│ ▲ │
│ │ │
│ │ ● Store All │
│ │ (0% overhead) │
│ │ │
│ │ ● Checkpoint/k │
│ │ (~50% overhead) │
│ │ │
│ │ ● Checkpoint/√L │
│ │ (~100% overhead) │
│ │ │
│ │ ● Reversible │
│ │ (~50% overhead, O(1) memory) │
│ │ │
│ └──────────────────────────────────────────────────► │
│ Compute │
│ │
└─────────────────────────────────────────────────────────────────┘
When is the trade-off worth it?
Batch size limited by memory: Larger batches often train faster. If memory is the bottleneck, reversibility enables larger batches.
Model depth limited by memory: Deeper networks may perform better. Reversibility enables depth scaling.
Long sequences: For transformers, activation memory scales with sequence length. Reversibility enables longer contexts.
Hardware with high FLOPS/byte ratio: Modern GPUs compute faster than they can move data. Extra compute for reconstruction may overlap with memory operations.
12.9 The Hardware Connection
Reversibility interacts with hardware constraints:
Memory Hierarchy (Chapter 1): Reversible blocks need to recompute activations. This recomputation should ideally hit cache, not main memory.
Bandwidth (Chapter 2): The backward pass of reversible networks streams through activations once (reconstructing), rather than reading stored activations. This can reduce memory traffic.
Parallelism (Chapter 3): Reconstruction is sequential through layers—you must reconstruct layer L-1 before layer L-2. This limits parallelism in the backward pass.
Fusion (Chapter 8): Reversible blocks benefit from operator fusion. Fusing F and G’s operations reduces memory traffic during reconstruction.
12.10 Key Takeaways
- Invertibility is a license to forget: If you can reconstruct inputs from outputs, you don't need to store them
- Additive coupling is the key structure: \(y = x + f(\text{other})\) is trivially invertible regardless of f's complexity
- The trade-off is compute for memory: ~33-50% more compute, but O(1) memory in depth
- It's a spectrum: Full storage → checkpointing → reversibility. Choose based on your memory-compute ratio
- Same algebra, multiple applications: RevNets (training), normalizing flows (generation), and even incremental hashing (caching) all exploit invertibility
12.11 Exercises
Exercise 1: Verify Reversibility
Implement a reversible block and verify that inverse(forward(x)) == x to floating-point precision.
What happens if you use float16 instead of float32? How does precision loss accumulate over many layers?
Implement a 10-layer reversible network and measure reconstruction error at each layer.
Exercise 2: Memory Measurement
Compare memory usage between standard and reversible residual blocks:
- Implement both versions of a 20-layer network
- Measure peak memory usage during forward + backward
- How does the ratio change with batch size?
Exercise 3: Non-Additive Coupling
The multiplicative coupling \(y = x \cdot f(\text{other})\) is used in normalizing flows but not RevNets. Why?
- What condition on \(f\) is required for invertibility?
- Implement multiplicative coupling with a safeguard against division by zero
- When would multiplicative coupling be preferred?
Exercise 4: Checkpointing Strategy
For a 64-layer network with 1GB per layer:
- How much memory does full storage require?
- What’s the optimal checkpointing interval if you have 16GB?
- How does compute overhead compare to reversible layers?
12.12 Further Reading
- Gomez et al. (2017). “The Reversible Residual Network: Backpropagation Without Storing Activations” - The original RevNet paper
- Kitaev et al. (2020). “Reformer: The Efficient Transformer” - Reversible attention for long sequences
- Chen et al. (2016). “Training Deep Nets with Sublinear Memory Cost” - Gradient checkpointing
- Dinh et al. (2017). “Density Estimation Using Real NVP” - Invertible generative models