12  Reversibility

How Invertibility Enables Memory-Efficient Training

You’re training a 96-layer transformer. Each layer stores activations for the backward pass.

That’s 96 copies of intermediate state, each the size of your batch times sequence length times hidden dimension. On a 40GB A100, you run out of memory before you run out of model.

What if you could train with just one copy—and reconstruct the other 95 on demand?

Historical Note: Feynman’s Reversible Computation (1985)

The connection between reversibility and efficiency traces back to thermodynamics. Landauer’s principle states that irreversible computation, meaning operations that discard information, must dissipate energy; Feynman, building on this work, showed that reversible computation, where every step can be undone, can approach the thermodynamic limit.

The same principle applies to memory: if you can undo a computation, you don’t need to remember its inputs. RevNets and reversible transformers exploit this insight, trading compute for memory. The physics is the same; the resource has changed from energy to bytes.

12.1 The Property That Enables Forgetting

Some functions can be run backward:

\[y = f(x) \implies x = f^{-1}(y)\]

This is invertibility. Given the output, you can recover the input.

A special case is the involution—a function that is its own inverse:

\[f(f(x)) = x\]

XOR with a fixed key is an involution: (a ^ b) ^ b = a. Apply it twice and you are back where you started.
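
A two-line check in Python (the values are arbitrary):

key, x = 0b0110, 0b1011
y = x ^ key            # forward: we can discard x now...
assert y ^ key == x    # ...because XOR with the same key undoes itself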

Invertibility is the license to:

  • Forget: Don’t store inputs if you can reconstruct them from outputs
  • Recompute: Trade compute for memory during backpropagation
  • Stream: Process data without buffering intermediate states
  • Undo: Reverse a transformation without tracking history

Without invertibility, you must remember everything. With it, you can forget strategically.

12.2 From ResNets to RevNets

The problem with deep networks is activation storage.

Standard backpropagation requires storing activations from the forward pass:

def forward_with_storage(layers, x):
    activations = [x]
    for layer in layers:
        x = layer(x)
        activations.append(x)  # Must store for backward pass
    return x, activations

For L layers with activation size A, memory is O(L × A). A 96-layer transformer with hidden dimension 2048, batch size 32, and sequence length 1024 stores:

96 layers × 32 batch × 1024 tokens × 2048 hidden × 4 bytes ≈ 26 GB

Just for activations. Before counting weights, optimizer states, or gradients.
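
A quick back-of-envelope helper makes the scaling explicit (the layer, batch, sequence, and hidden figures are the ones assumed above; real layers keep several tensors per position, so treat this as a lower bound):

def activation_memory_gb(layers, batch, seq_len, hidden, bytes_per_elem=4):
    # One fp32 activation tensor per layer
    return layers * batch * seq_len * hidden * bytes_per_elem / 1e9

print(activation_memory_gb(96, 32, 1024, 2048))  # ~25.8 GB across all layers
print(activation_memory_gb(1, 32, 1024, 2048))   # ~0.27 GB if depth drops out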

12.2.1 The Reversible Block

RevNet (Gomez et al., 2017) introduces a simple architectural change that eliminates this scaling.

Split the input into two halves, then apply:

\[y_1 = x_1 + F(x_2)\] \[y_2 = x_2 + G(y_1)\]

where F and G are arbitrary differentiable functions (convolutions, MLPs, attention).

The key insight: these equations can be inverted:

\[x_2 = y_2 - G(y_1)\] \[x_1 = y_1 - F(x_2)\]

Given the output \((y_1, y_2)\), you can compute the input \((x_1, x_2)\).

def reversible_block_forward(x1, x2, F, G):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def reversible_block_inverse(y1, y2, F, G):
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2

Let’s verify this works:

import torch
import torch.nn as nn

# Simple F and G functions
F = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
G = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

# Forward
x1, x2 = torch.randn(32, 256), torch.randn(32, 256)
y1, y2 = reversible_block_forward(x1, x2, F, G)

# Inverse
x1_reconstructed, x2_reconstructed = reversible_block_inverse(y1, y2, F, G)

# Verify
print(f"Max error x1: {(x1 - x1_reconstructed).abs().max():.2e}")  # ~1e-7
print(f"Max error x2: {(x2 - x2_reconstructed).abs().max():.2e}")  # ~1e-7

The reconstruction error is floating-point precision—the values are mathematically identical.

12.2.2 Memory Reduction

With reversible blocks, you only need to store the final layer’s activations:

Standard:    O(L × A)   96 layers × 8 MB ≈ 768 MB per sample
Reversible:  O(A)       8 MB per sample (independent of depth!)

The memory cost becomes independent of depth. You can train a 96-layer network with the same memory as a 1-layer network.

The trade-off: During backpropagation, you must recompute the activations by running the inverse. This adds ~33-50% compute overhead. But for memory-bound training, this trade-off is often worthwhile.
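
To make the O(A) claim concrete, here is a minimal sketch reusing the reversible_block_forward and reversible_block_inverse helpers above (the block count and widths are arbitrary): run the stack forward keeping only the final pair, then reconstruct every earlier activation by sweeping the inverses in reverse order.

import torch
import torch.nn as nn

def make_fg(dim=256):
    return nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

blocks = [(make_fg(), make_fg()) for _ in range(8)]   # 8 reversible blocks

x1, x2 = torch.randn(32, 256), torch.randn(32, 256)
h1, h2 = x1, x2
with torch.no_grad():
    for F, G in blocks:                     # forward: overwrite, store nothing
        h1, h2 = reversible_block_forward(h1, h2, F, G)
    for F, G in reversed(blocks):           # reconstruct inputs block by block
        h1, h2 = reversible_block_inverse(h1, h2, F, G)

print((h1 - x1).abs().max().item())   # tiny: limited only by float32 round-off
print((h2 - x2).abs().max().item())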

12.3 Investigation: Why Does Inversion Work?

Let’s derive why the reversible block equations are invertible.

Starting from: \[y_1 = x_1 + F(x_2)\] \[y_2 = x_2 + G(y_1)\]

Recover \(x_2\): From the second equation, solve for \(x_2\): \[x_2 = y_2 - G(y_1)\]

This works because we know \(y_1\) and \(y_2\) from the forward pass output.

Recover \(x_1\): Substitute \(x_2\) into the first equation: \[x_1 = y_1 - F(x_2) = y_1 - F(y_2 - G(y_1))\]

The structure is crucial: each equation involves only one unknown at a time.

12.3.1 What Makes It Invertible?

The additive coupling is key. Consider if we used multiplication:

\[y_1 = x_1 \cdot F(x_2)\]

To invert: \(x_1 = y_1 / F(x_2)\). But if \(F(x_2) = 0\), we can’t recover \(x_1\). Addition has no such singularity.

More generally, the pattern is:

\[y = x + f(\text{other terms})\]

Rearranges to:

\[x = y - f(\text{other terms})\]

The function \(f\) can be arbitrarily complex (even a transformer layer). As long as the combination is additive, inversion is straightforward.

12.4 Reformer: Reversible Transformers

The Reformer (Kitaev et al., 2020) applies reversibility to transformers.

The standard transformer layer:

x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

Stores activations at each layer for backprop.

The reversible transformer layer:

y1 = x1 + Attention(LayerNorm(x2))
y2 = x2 + FFN(LayerNorm(y1))

Same computation, but now invertible.
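
A minimal PyTorch sketch of such a block (module sizes are illustrative, and this is not the Reformer codebase; note that dropout or any other stochasticity inside F and G would break exact reconstruction):

import torch
import torch.nn as nn

class ReversibleTransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(),
                                 nn.Linear(d_ff, d_model))

    def F(self, x):   # attention sub-block
        h = self.norm1(x)
        return self.attn(h, h, h, need_weights=False)[0]

    def G(self, x):   # feed-forward sub-block
        return self.ffn(self.norm2(x))

    def forward(self, x1, x2):
        y1 = x1 + self.F(x2)
        y2 = x2 + self.G(y1)
        return y1, y2

    def inverse(self, y1, y2):
        x2 = y2 - self.G(y1)
        x1 = y1 - self.F(x2)
        return x1, x2

blk = ReversibleTransformerBlock()
x1, x2 = torch.randn(4, 128, 512), torch.randn(4, 128, 512)
y1, y2 = blk(x1, x2)
r1, r2 = blk.inverse(y1, y2)
print((r1 - x1).abs().max().item())   # float round-off only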

Memory comparison for 12-layer transformer:

Standard:    12 × (attention activations + FFN activations)
Reversible:  1 × (attention activations + FFN activations)

For sequence length 4096, hidden 1024:
Standard:    ~4 GB activations
Reversible:  ~340 MB activations (12× reduction)

Combined with locality-sensitive hashing attention (reducing O(n²) to O(n log n)), Reformer handles sequences up to 64K tokens where standard transformers fail.

12.4.1 The Backward Pass

The reversible backward pass doesn’t use stored activations. Instead:

def reversible_backward(y1, y2, dy1, dy2, F, G):
    # Reconstruct the inputs from the outputs (no autograd graph needed)
    with torch.no_grad():
        x2 = y2 - G(y1)
        x1 = y1 - F(x2)

    # Recompute the forward pass with gradient tracking, then backprop
    # through it. Parameter gradients for F and G accumulate as usual.
    x1 = x1.detach().requires_grad_(True)
    x2 = x2.detach().requires_grad_(True)
    y1_recomputed = x1 + F(x2)
    y2_recomputed = x2 + G(y1_recomputed)
    torch.autograd.backward([y1_recomputed, y2_recomputed], [dy1, dy2])

    return x1, x2, x1.grad, x2.grad

The activations are reconstructed on-the-fly, layer by layer, as backprop proceeds from output to input.
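
A quick consistency check against ordinary autograd (the shapes and toy F/G are arbitrary):

import torch
import torch.nn as nn

F = nn.Sequential(nn.Linear(64, 64), nn.Tanh())
G = nn.Sequential(nn.Linear(64, 64), nn.Tanh())

# Reference: standard backprop with stored activations
x1_ref = torch.randn(8, 64, requires_grad=True)
x2_ref = torch.randn(8, 64, requires_grad=True)
y1 = x1_ref + F(x2_ref)
y2 = x2_ref + G(y1)
(y1.sum() + y2.sum()).backward()

# Reversible: only the outputs and the incoming gradients are needed
_, _, dx1, dx2 = reversible_backward(
    y1.detach(), y2.detach(), torch.ones_like(y1), torch.ones_like(y2), F, G)

print(torch.allclose(dx1, x1_ref.grad, atol=1e-5),
      torch.allclose(dx2, x2_ref.grad, atol=1e-5))   # expect: True True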

12.5 The Spectrum: From Checkpointing to Reversibility

Reversible networks are the extreme of a spectrum:

Strategy               Memory     Compute Overhead    When to Use
Store all              O(L)       0%                  Memory abundant
Checkpoint every k     O(L/k)     (k-1)/k × 100%      Moderate memory
Checkpoint every √L    O(√L)      ~100%               Memory constrained
Reversible             O(1)       ~33-50%             Memory critical

Gradient checkpointing (Chen et al., 2016) is the middle ground: store some activations, recompute others.

from torch.utils.checkpoint import checkpoint

# Standard: autograd stores every intermediate activation
y = layer3(layer2(layer1(x)))  # Stores x, layer1(x), layer2(x), plus each layer's internals

# Checkpointed: each checkpoint stores only its input (the segment boundary)
# and recomputes that segment's internal activations during the backward pass
y = checkpoint(layer3,
    checkpoint(layer2,
        checkpoint(layer1, x)))

The optimal checkpoint strategy depends on the compute-memory ratio of your hardware. For modern GPUs with high FLOPS but limited HBM, aggressive checkpointing often wins.
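
PyTorch also ships a helper for the “checkpoint every k layers” row; a minimal sketch (the layer sizes and segment count are arbitrary):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A toy 16-layer stack checkpointed in 4 segments: roughly 4 boundary
# activations are stored; each segment is recomputed during backward.
model = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.ReLU())
                        for _ in range(16)])
x = torch.randn(32, 256, requires_grad=True)

y = checkpoint_sequential(model, 4, x)   # 4 segments  ->  memory ~ O(L/k)
y.sum().backward()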

12.6 Beyond Training: Normalizing Flows

Invertibility appears in generative modeling too.

Normalizing flows are generative models built from invertible transformations:

\[z \sim p(z) \quad \text{(simple prior, e.g., Gaussian)}\] \[x = f(z) \quad \text{(invertible transformation)}\]

Because \(f\) is invertible, we can compute exact likelihoods:

\[p(x) = p(z) \cdot \left|\det \frac{\partial f^{-1}}{\partial x}\right|\]

Real NVP, Glow, and other flow models build on the same coupling idea as RevNets, generalized from additive to affine:

\[y_1 = x_1\] \[y_2 = x_2 \cdot \exp(s(x_1)) + t(x_1)\]

This is invertible (given \(y_1 = x_1\), solve for \(x_2\)), and the Jacobian determinant is tractable.
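
A sketch of this affine coupling in PyTorch (the widths and the s/t networks are placeholders):

import torch
import torch.nn as nn

s = nn.Sequential(nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128))
t = nn.Sequential(nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128))

def coupling_forward(x1, x2):
    y1 = x1
    y2 = x2 * torch.exp(s(x1)) + t(x1)
    return y1, y2

def coupling_inverse(y1, y2):
    x1 = y1
    x2 = (y2 - t(y1)) * torch.exp(-s(y1))   # exp(s) is never zero, so no singularity
    return x1, x2

x1, x2 = torch.randn(32, 128), torch.randn(32, 128)
y1, y2 = coupling_forward(x1, x2)
log_det = s(x1).sum(dim=-1)                 # log|det J| of the forward map: cheap and exact
r1, r2 = coupling_inverse(y1, y2)
print((r2 - x2).abs().max().item())         # float round-off only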

The connection: Both RevNets and normalizing flows exploit invertibility, but for different purposes:

  • RevNets: Memory efficiency (reconstruct activations)
  • Flows: Density estimation (compute likelihoods)

The algebra enables both.

12.7 When Invertibility Breaks

Not all operations are invertible. Recognize these cases:

ReLU: \(\text{ReLU}(x) = \max(0, x)\)

If the output is 0, the input could be any negative number. Information is lost.

Pooling: Max pooling discards which element was maximum. Average pooling loses individual values.

Stride > 1: Downsampling discards spatial information.

Attention with softmax: The normalized weights don’t preserve the raw logits.

For these operations, reversible networks use workarounds:

import torch

# Instead of ReLU, use an invertible activation such as LeakyReLU
def leaky_relu_inverse(y, negative_slope=0.01):
    return torch.where(y >= 0, y, y / negative_slope)

# Instead of max pooling, use invertible downsampling
# (space-to-depth "squeeze" reshaping, or max pooling with stored indices)

The general principle: some information loss is unavoidable (that’s what makes a network a compression), but structure the loss to occur at explicit points where you can store the discarded information cheaply.
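
For example, space-to-depth “squeezing” downsamples spatially without discarding anything, because it only rearranges values (a minimal sketch; the shapes are arbitrary):

import torch
import torch.nn.functional as nnf

x = torch.randn(8, 3, 32, 32)
down = nnf.pixel_unshuffle(x, downscale_factor=2)   # (8, 12, 16, 16): half the resolution, 4x channels
back = nnf.pixel_shuffle(down, upscale_factor=2)    # (8, 3, 32, 32)
assert torch.equal(back, x)                         # exact inverse: no information lost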

12.8 The Compute-Memory Trade-off

Reversibility trades compute for memory:

┌─────────────────────────────────────────────────────────────────┐
│                                                                 │
│   Memory                                                        │
│     ▲                                                          │
│     │                                                          │
│     │  ● Store All                                             │
│     │     (0% overhead)                                        │
│     │                                                          │
│     │        ● Checkpoint/k                                    │
│     │           (~50% overhead)                                │
│     │                                                          │
│     │              ● Checkpoint/√L                             │
│     │                 (~100% overhead)                         │
│     │                                                          │
│     │                    ● Reversible                          │
│     │                       (~50% overhead, O(1) memory)       │
│     │                                                          │
│     └──────────────────────────────────────────────────►       │
│                                             Compute             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

When is the trade-off worth it?

  1. Batch size limited by memory: Larger batches often train faster. If memory is the bottleneck, reversibility enables larger batches.

  2. Model depth limited by memory: Deeper networks may perform better. Reversibility enables depth scaling.

  3. Long sequences: For transformers, activation memory scales with sequence length. Reversibility enables longer contexts.

  4. Hardware with high FLOPS/byte ratio: Modern GPUs compute faster than they can move data. Extra compute for reconstruction may overlap with memory operations.

12.9 The Hardware Connection

Reversibility interacts with hardware constraints:

Memory Hierarchy (Chapter 1): Reversible blocks need to recompute activations. This recomputation should ideally hit cache, not main memory.

Bandwidth (Chapter 2): The backward pass of reversible networks streams through activations once (reconstructing), rather than reading stored activations. This can reduce memory traffic.

Parallelism (Chapter 3): Reconstruction is sequential through layers—you must reconstruct layer L-1 before layer L-2. This limits parallelism in the backward pass.

Fusion (Chapter 8): Reversible blocks benefit from operator fusion. Fusing F and G’s operations reduces memory traffic during reconstruction.

12.10 Key Takeaways

  1. Invertibility is a license to forget: If you can reconstruct inputs from outputs, you don’t need to store them

  2. Additive coupling is the key structure: \(y = x + f(\text{other})\) is trivially invertible regardless of f’s complexity

  3. The trade-off is compute for memory: ~33-50% more compute, but O(1) memory in depth

  4. It’s a spectrum: Full storage → checkpointing → reversibility. Choose based on your memory-compute ratio

  5. Same algebra, multiple applications: RevNets (training), normalizing flows (generation), and even incremental hashing (caching) all exploit invertibility

12.11 Exercises

Exercise 1: Verify Reversibility

Implement a reversible block and verify that inverse(forward(x)) == x to floating-point precision.

  1. What happens if you use float16 instead of float32? How does precision loss accumulate over many layers?

  2. Implement a 10-layer reversible network and measure reconstruction error at each layer.

Exercise 2: Memory Measurement

Compare memory usage between standard and reversible residual blocks:

  1. Implement both versions of a 20-layer network
  2. Measure peak memory usage during forward + backward
  3. How does the ratio change with batch size?

Exercise 3: Non-Additive Coupling

The multiplicative coupling \(y = x \cdot f(\text{other})\) is used in normalizing flows but not RevNets. Why?

  1. What condition on \(f\) is required for invertibility?
  2. Implement multiplicative coupling with a safeguard against division by zero
  3. When would multiplicative coupling be preferred?

Exercise 4: Checkpointing Strategy

For a 64-layer network with 1GB per layer:

  1. How much memory does full storage require?
  2. What’s the optimal checkpointing interval if you have 16GB?
  3. How does compute overhead compare to reversible layers?

Try It Yourself

The accompanying notebook lets you:

  • Implement and test reversible blocks
  • Measure memory savings vs. compute overhead
  • Compare checkpointing strategies
  • Build a mini reversible transformer


12.12 Further Reading

  • Gomez et al. (2017). “The Reversible Residual Network: Backpropagation Without Storing Activations” - The original RevNet paper
  • Kitaev et al. (2020). “Reformer: The Efficient Transformer” - Reversible attention for long sequences
  • Chen et al. (2016). “Training Deep Nets with Sublinear Memory Cost” - Gradient checkpointing
  • Dinh et al. (2017). “Density Estimation Using Real NVP” - Invertible generative models