Activation Recomputation

Activations dominate memory in deep networks. A 7B model's parameters need ~14 GB, but its activations can consume 100+ GB. Activation recomputation—also called gradient checkpointing—trades compute for memory by recomputing activations during the backward pass instead of storing them.

The Question: The backward pass needs activations from the forward pass. Storing them all requires O(L) memory for L layers. Can we reduce this to O(√L) or even O(1)? What's the compute cost of this trade-off?

Chapter Map

Prerequisites: Chapter 19 (memory equation breakdown)

Key insight: Activations dominate memory for large batch sizes. By checkpointing only at layer boundaries and recomputing intermediates during backward, you trade ~33% extra compute for dramatic memory savings—enabling larger batches and longer sequences.

The Activation Memory Problem

During backpropagation, computing gradients requires activations from the forward pass:

\[\frac{\partial L}{\partial W_l} = \frac{\partial L}{\partial y_l} \cdot \frac{\partial y_l}{\partial W_l} = \delta_l \cdot a_{l-1}^T\]

Where:

  • \(\delta_l = \partial L / \partial y_l\): the error signal
  • \(a_{l-1}\): the input activation (from forward pass)

Without stored activations, we cannot compute gradients.
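To see the dependence concretely, here is a minimal scalar sketch (all names illustrative) of a two-layer network \(y = w_2 \cdot \mathrm{relu}(w_1 x)\): the weight gradient \(\partial L/\partial w_2 = \delta_2 \cdot a_1\) cannot be formed unless the forward activation \(a_1\) is available at backward time:

```python
# Toy two-layer scalar "network": a1 = relu(w1*x), y = w2*a1, L = 0.5*(y - t)^2
x, t = 2.0, 1.0          # input and target
w1, w2 = 0.5, 3.0        # weights

a1 = max(0.0, w1 * x)    # forward activation: must be stored (or recomputed)
y = w2 * a1
delta2 = y - t           # error signal dL/dy
grad_w2 = delta2 * a1    # dL/dw2 = delta2 * a1 -- needs the saved a1
print(grad_w2)           # 2.0
```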

Memory Growth

For a transformer with \(L\) layers, the stored activations include:

| Component | Shape | Per Layer |
|---|---|---|
| Layer input | \((B, S, H)\) | \(2BSH\) bytes |
| Attention Q, K, V | \((B, S, H) \times 3\) | \(6BSH\) bytes |
| Attention scores | \((B, \text{heads}, S, S)\) | \(2BnS^2\) bytes |
| Softmax output | \((B, \text{heads}, S, S)\) | \(2BnS^2\) bytes |
| Attention output | \((B, S, H)\) | \(2BSH\) bytes |
| FFN intermediate | \((B, S, 4H)\) | \(8BSH\) bytes |
| Various intermediates | — | ~\(2BSH\) bytes |

Total per layer: approximately \(20BSH + 4BnS^2\) bytes (FP16).

For a 7B model (\(L=32\), \(H=4096\), \(n=32\)) with \(B=4\), \(S=2048\):

  • Per layer: \(20 \times 4 \times 2048 \times 4096 + 4 \times 4 \times 32 \times 2048^2\)
  • Per layer: \(0.67\) GB + \(2.15\) GB = \(2.82\) GB
  • Total: \(32 \times 2.82\) GB \(\approx 90\) GB

This far exceeds GPU memory.
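The estimate is easy to reproduce in a few lines (a sketch using the per-layer total \(20BSH + 4BnS^2\) bytes from the table; the FP16 factor of 2 is already folded into the constants):

```python
B, S, H, n, L = 4, 2048, 4096, 32, 32    # batch, seq, hidden, heads, layers

linear_bytes = 20 * B * S * H            # input, QKV, attn output, FFN, misc
attn_bytes = 4 * B * n * S * S           # attention scores + softmax output
per_layer = linear_bytes + attn_bytes

print(f"per layer: {per_layer / 1e9:.2f} GB")       # ~2.8 GB
print(f"total:     {L * per_layer / 1e9:.0f} GB")   # ~90 GB
```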

The Recomputation Trade-off

Key insight: We don't need to store activations if we can recompute them.

Basic Idea

Standard forward/backward:
Forward:  Layer 0 → save a0 → Layer 1 → save a1 → ... → Layer L → Loss
Backward: Load aL-1 → grad L → Load aL-2 → grad L-1 → ... → grad 0

With recomputation:
Forward:  Layer 0 → Layer 1 → ... → Layer L → Loss (save only checkpoints)
Backward: Recompute aL-1 → grad L → Recompute aL-2 → grad L-1 → ...

The Fundamental Trade-off

Let:

  • \(M\): memory for activations
  • \(C\): compute for forward passes
  • \(K\): number of checkpoints
| Strategy | Memory | Forward Passes |
|---|---|---|
| Store all | \(O(L)\) | \(1\) |
| Store none | \(O(1)\) | \(O(L)\) (recompute from the input for every layer) |
| Checkpoint every \(\sqrt{L}\) | \(O(\sqrt{L})\) | \(\approx 2\) (one extra forward) |

The optimal checkpoint interval minimizes total cost.

Checkpointing Strategies

Strategy 1: Full Recomputation (No Storage)

Store only the input and recompute everything during backward.

import torch
import torch.nn as nn

class FullRecomputeFunction(torch.autograd.Function):
    """Recompute entire forward pass during backward."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, layers: nn.ModuleList):
        ctx.layers = layers

        # Forward through all layers (no graph is built here:
        # autograd.Function's forward runs with grad disabled)
        output = input
        for layer in layers:
            output = layer(output)

        # Save only the input for recomputation
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, = ctx.saved_tensors
        layers = ctx.layers

        # Recompute the forward pass (note: this re-materializes every
        # activation at once, so backward peak memory is still O(L))
        activations = [input]
        hidden = input
        for layer in layers:
            hidden = layer(hidden)
            activations.append(hidden)

        # Backward layer by layer
        grad = grad_output
        for i in range(len(layers) - 1, -1, -1):
            act = activations[i].detach().requires_grad_(True)
            with torch.enable_grad():
                out = layers[i](act)
            # Accumulates parameter .grad and populates act.grad
            torch.autograd.backward(out, grad)
            grad = act.grad

        return grad, None

Memory: \(O(1)\) during the forward pass—only the input is saved. Note that this sketch re-materializes every activation during backward, so its backward peak is still \(O(L)\); a true \(O(1)\) scheme must recompute from the input separately for each layer, at the cost of \(O(L)\) extra forward passes. Compute: ~1 extra forward pass (≈33% overhead) for the version shown. Problem: either memory or compute scales poorly with large \(L\).

Strategy 2: Uniform Checkpointing

Save activation every \(k\) layers. Recompute only within segments.

from typing import List, Tuple

import torch
import torch.nn as nn

def uniform_checkpoint_forward(input: torch.Tensor,
                               layers: nn.ModuleList,
                               checkpoint_interval: int
                               ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Forward with uniform checkpointing."""
    checkpoints = [input]
    hidden = input

    for i, layer in enumerate(layers):
        hidden = layer(hidden)
        # Checkpoint segment boundaries (but not the final output)
        if (i + 1) % checkpoint_interval == 0 and (i + 1) < len(layers):
            checkpoints.append(hidden)

    return hidden, checkpoints

def uniform_checkpoint_backward(grad_output: torch.Tensor,
                                layers: nn.ModuleList,
                                checkpoints: List[torch.Tensor],
                                checkpoint_interval: int):
    """Backward with segment recomputation."""
    num_segments = len(checkpoints)
    grad = grad_output

    for seg_idx in range(num_segments - 1, -1, -1):
        # Determine segment boundaries
        start_layer = seg_idx * checkpoint_interval
        end_layer = min((seg_idx + 1) * checkpoint_interval, len(layers))

        # Recompute segment from checkpoint, rebuilding a local graph
        checkpoint = checkpoints[seg_idx].detach().requires_grad_(True)
        activations = [checkpoint]
        hidden = checkpoint
        with torch.enable_grad():
            for i in range(start_layer, end_layer):
                hidden = layers[i](hidden)
                activations.append(hidden)

        # Backward through segment, layer by layer
        for i in range(end_layer - 1, start_layer - 1, -1):
            layer_input = activations[i - start_layer]
            layer_output = activations[i - start_layer + 1]

            # Gradient w.r.t. this layer's input; after the segment, grad
            # holds the gradient w.r.t. the previous checkpoint
            grad = torch.autograd.grad(
                layer_output, layer_input, grad,
                retain_graph=True
            )[0]

    return grad

Analysis for interval \(k\):

  • Number of checkpoints: \(L/k\)
  • Memory per checkpoint: \(M_{\text{layer}}\)
  • Peak memory within segment: \(k \cdot M_{\text{layer}}\)

Total memory: \(\frac{L}{k} \cdot M_{\text{layer}} + k \cdot M_{\text{layer}}\)

Minimized when \(L/k = k\), i.e., \(k = \sqrt{L}\).

Optimal checkpoint interval: \(k^* = \sqrt{L}\)

Optimal memory: \(2\sqrt{L} \cdot M_{\text{layer}}\)

For \(L = 32\) layers: \(k^* \approx 6\), memory reduced from \(32M\) to \(12M\) (2.7× savings).
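The \(k^* = \sqrt{L}\) claim can be confirmed numerically; a quick sketch that scans every interval \(k\) and minimizes the cost model \(L/k + k\) (in units of \(M_{\text{layer}}\)):

```python
L = 32

def ckpt_memory(k: int) -> float:
    """Checkpoints stored (L/k) plus peak in-segment activations (k)."""
    return L / k + k

best_k = min(range(1, L + 1), key=ckpt_memory)
print(best_k, round(ckpt_memory(best_k), 2))   # 6 11.33  (sqrt(32) ~ 5.66)
```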

Strategy 3: Selective Checkpointing

Not all layers consume equal memory. Checkpoint strategically.

from typing import List, Tuple

import torch
import torch.nn as nn

def analyze_layer_memory(layer: nn.Module,
                         input_shape: Tuple[int, ...]) -> int:
    """Estimate activation memory for a layer."""
    # Hook to capture activation sizes
    activation_sizes = []

    def hook(module, input, output):
        if isinstance(output, torch.Tensor):
            activation_sizes.append(output.numel() * output.element_size())
        elif isinstance(output, tuple):
            for o in output:
                if isinstance(o, torch.Tensor):
                    activation_sizes.append(o.numel() * o.element_size())

    handles = []
    for module in layer.modules():
        handles.append(module.register_forward_hook(hook))

    # Dummy forward
    dummy_input = torch.randn(input_shape)
    with torch.no_grad():
        layer(dummy_input)

    # Clean up hooks
    for h in handles:
        h.remove()

    return sum(activation_sizes)

def select_checkpoints(layers: nn.ModuleList,
                       memory_budget: int,
                       input_shape: Tuple[int, ...]) -> List[int]:
    """Select optimal checkpoint locations given memory budget."""
    layer_memories = [
        analyze_layer_memory(layer, input_shape)
        for layer in layers
    ]

    # Greedy selection: checkpoint before high-memory layers
    total_memory = sum(layer_memories)
    num_checkpoints = total_memory // memory_budget

    # Sort layers by memory, checkpoint before largest
    sorted_indices = sorted(range(len(layers)),
                           key=lambda i: layer_memories[i],
                           reverse=True)

    checkpoint_indices = sorted(sorted_indices[:num_checkpoints])
    return checkpoint_indices

PyTorch Checkpoint API

PyTorch provides built-in checkpointing support.

Basic Usage

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

class TransformerWithCheckpointing(nn.Module):
    """Transformer with activation checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(config.num_layers)
        ])
        self.head = nn.Linear(config.hidden_dim, config.vocab_size)

        self.use_checkpointing = config.use_checkpointing
        self.checkpoint_ratio = config.checkpoint_ratio  # e.g., 0.5

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.embedding(input_ids)

        num_checkpointed = int(len(self.layers) * self.checkpoint_ratio)

        for i, layer in enumerate(self.layers):
            if self.use_checkpointing and i < num_checkpointed:
                # Checkpoint this layer
                hidden = checkpoint(
                    layer,
                    hidden,
                    use_reentrant=False,
                    preserve_rng_state=True
                )
            else:
                # Normal forward
                hidden = layer(hidden)

        return self.head(hidden)

Checkpoint Sequential

For sequential models, use checkpoint_sequential:

def forward_with_sequential_checkpoint(self, x: torch.Tensor) -> torch.Tensor:
    """Forward using checkpoint_sequential for uniform segments."""
    # Divide layers into segments
    num_segments = int(math.sqrt(len(self.layers)))

    # checkpoint_sequential handles segment boundaries automatically
    hidden = checkpoint_sequential(
        self.layers,
        num_segments,
        x,
        use_reentrant=False
    )

    return hidden

Non-Reentrant Checkpointing

PyTorch offers two checkpointing modes:

Reentrant (legacy):

  • Uses torch.autograd.grad internally
  • Can have subtle bugs with certain operations
  • Being deprecated

Non-Reentrant (recommended):

  • Uses saved tensor hooks
  • More robust with complex graphs
  • Preserves RNG state correctly
# Always prefer non-reentrant
hidden = checkpoint(
    layer,
    hidden,
    use_reentrant=False,  # Recommended
    preserve_rng_state=True  # Important for dropout
)

Strategy 4: Attention-Specific Checkpointing

Attention has unique memory patterns. The \(O(S^2)\) attention scores dominate.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CheckpointedAttention(nn.Module):
    """
    Attention with selective recomputation.

    Stores Q, K, V but recomputes attention scores.
    Saves O(S²) memory per head.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, H = x.shape

        # Compute Q, K, V (stored)
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head
        q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        # Use checkpointing for attention computation
        attn_output = torch.utils.checkpoint.checkpoint(
            self._attention_forward,
            q, k, v,
            use_reentrant=False
        )

        # Output projection
        output = attn_output.transpose(1, 2).reshape(B, S, H)
        return self.out_proj(output)

    def _attention_forward(self,
                          q: torch.Tensor,
                          k: torch.Tensor,
                          v: torch.Tensor) -> torch.Tensor:
        """Core attention computation (recomputed in backward)."""
        # Attention scores: O(S²) memory NOT stored
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, v)

Memory savings:

  • Without checkpointing: \(4BnS^2\) bytes for scores + softmax (two \((B, n, S, S)\) FP16 tensors)
  • With checkpointing: Only Q, K, V stored (\(6BSH\) bytes)

For \(B=4\), \(S=2048\), \(n=32\): saves \(4 \times 4 \times 32 \times 2048^2 = 2.15\) GB per layer.

The Compute Cost Analysis

Single Checkpoint

Checkpointing layer \(l\) means:

  • Forward: compute layer \(l\) once
  • Backward: recompute layer \(l\) once before computing gradients

Overhead: One extra forward pass per checkpointed layer.

Full Model Analysis

Let:

  • \(F\): FLOPs for one forward pass
  • \(B\): FLOPs for one backward pass (\(B \approx 2F\) typically)
  • \(c\): fraction of layers checkpointed

Without checkpointing:

\[C_{\text{total}} = F + B = F + 2F = 3F\]

With checkpointing (fraction \(c\)):

\[C_{\text{total}} = F + cF + B = F(1 + c + 2) = F(3 + c)\]

Compute overhead: \(c \cdot F\), or \(\frac{c}{3}\) relative increase.

| Checkpoint Fraction | Relative Overhead |
|---|---|
| 0% | 0% |
| 50% | 16.7% |
| 100% | 33.3% |

Maximum overhead is 33% even with full checkpointing.
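The table follows directly from \(C_{\text{total}} = F(3 + c)\); a one-line sketch:

```python
def relative_overhead(c: float) -> float:
    """Relative extra compute for checkpointed fraction c: F(3+c)/3F - 1 = c/3."""
    return c / 3

for c in (0.0, 0.5, 1.0):
    print(f"{c:.0%} checkpointed -> {relative_overhead(c):.1%} overhead")
```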

Memory-Compute Pareto Frontier

Memory
  │●  No checkpointing (3F compute, L·M memory)
  │    ●  50% checkpointing (3.5F, ~0.5L·M)
  │        ●  √L checkpointing (≈4F, 2√L·M)
  │            ●  Full recomputation (≥4F, O(1) stored)
  └─────────────────────────────────────────→ Compute

Trade-off: 33% more compute for ~√L memory reduction

Advanced Techniques

Activation Compression

Instead of discarding activations, compress them.

import torch
import torch.nn as nn

class CompressedCheckpoint(torch.autograd.Function):
    """Checkpoint with lossy activation compression."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, layer: nn.Module, compression: str):
        ctx.layer = layer
        ctx.compression = compression

        # Compress input for storage
        if compression == 'fp8':
            compressed = input.to(torch.float8_e4m3fn)
        elif compression == 'quantize':
            # 8-bit quantization
            scale = input.abs().max() / 127
            compressed = (input / scale).round().to(torch.int8)
            ctx.scale = scale
        else:
            compressed = input

        ctx.save_for_backward(compressed)

        # Forward with full precision
        return layer(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        compressed, = ctx.saved_tensors
        layer = ctx.layer

        # Decompress
        if ctx.compression == 'fp8':
            input_approx = compressed.to(torch.float16)
        elif ctx.compression == 'quantize':
            input_approx = compressed.float() * ctx.scale
        else:
            input_approx = compressed

        # Compute gradients with approximate activations
        input_approx.requires_grad_(True)
        with torch.enable_grad():
            output = layer(input_approx)
            grad_input = torch.autograd.grad(output, input_approx, grad_output)[0]

        return grad_input, None, None

Trade-off: Some gradient accuracy for memory savings—2× versus FP16 checkpoints for both FP8 and INT8 (4× versus FP32).

Offloaded Checkpointing

Combine checkpointing with CPU offloading.

class OffloadedCheckpoint(torch.autograd.Function):
    """Checkpoint with CPU offloading."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, layer: nn.Module):
        ctx.layer = layer
        ctx.input_shape = input.shape
        ctx.input_dtype = input.dtype
        ctx.input_device = input.device

        # Offload input to CPU (non_blocking only overlaps when the
        # CPU destination uses pinned memory)
        ctx.input_cpu = input.to('cpu', non_blocking=True)

        return layer(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Prefetch input from CPU
        input_gpu = ctx.input_cpu.to(ctx.input_device, non_blocking=True)
        torch.cuda.synchronize()

        input_gpu.requires_grad_(True)
        with torch.enable_grad():
            output = ctx.layer(input_gpu)
            grad_input = torch.autograd.grad(output, input_gpu, grad_output)[0]

        return grad_input, None

Selective Recomputation

Some operations are cheap to store, expensive to recompute. Be selective:

class SelectiveRecompute(nn.Module):
    """Selectively recompute expensive operations."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

        # Operations to always store (cheap to store, expensive to compute)
        self.store_ops = {'embedding', 'layernorm', 'linear'}

        # Operations to recompute (expensive to store, cheap to compute)
        self.recompute_ops = {'attention_scores', 'softmax', 'dropout'}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Custom forward with selective storage; selective_checkpoint is a
        # placeholder for a dispatcher that recomputes only recompute_ops
        return selective_checkpoint(self.layer, x, self.recompute_ops)

Interaction with Other Memory Optimizations

Checkpointing + ZeRO

These are complementary:

  • ZeRO reduces model state memory (parameters, gradients, optimizer)
  • Checkpointing reduces activation memory
class ZeROWithCheckpointing:
    """Combined ZeRO and checkpointing."""

    def __init__(self, model: nn.Module, zero_stage: int):
        self.model = model
        self.zero_optimizer = ZeROOptimizer(model, stage=zero_stage)

        # Enable checkpointing for all transformer layers
        for layer in model.layers:
            layer.use_checkpoint = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # ZeRO-3: parameters gathered on-demand
        # Checkpointing: activations recomputed on-demand
        return self.model(input)

Combined memory:

  • Model state: \(16\Psi/P\) (ZeRO-3)
  • Activations: \(O(\sqrt{L})\) (checkpointing)

Checkpointing + Tensor Parallelism

When using TP, checkpoint carefully to avoid redundant communication:

class TPCheckpointedLayer(nn.Module):
    """Tensor-parallel layer with smart checkpointing."""

    def __init__(self, hidden_dim: int, tp_degree: int):
        super().__init__()
        self.tp_degree = tp_degree
        self.local_hidden = hidden_dim // tp_degree

        self.linear = ColumnParallelLinear(hidden_dim, 4 * hidden_dim)
        self.output = RowParallelLinear(4 * hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpoint the compute-heavy middle section
        # Don't checkpoint the AllReduce (would serialize communication)
        intermediate = checkpoint(
            self.linear,
            x,
            use_reentrant=False
        )

        # AllReduce happens here (in RowParallel backward)
        return self.output(F.gelu(intermediate))

Checkpointing + Pipeline Parallelism

Pipeline stages naturally create checkpoints at stage boundaries:

class PipelineStage(nn.Module):
    """Pipeline stage with internal checkpointing."""

    def __init__(self, layers: nn.ModuleList, checkpoint_internal: bool = True):
        super().__init__()
        self.layers = layers
        self.checkpoint_internal = checkpoint_internal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stage input is always stored (for pipeline backward)

        if self.checkpoint_internal and len(self.layers) > 1:
            # Checkpoint within stage
            num_segments = max(1, int(math.sqrt(len(self.layers))))
            x = checkpoint_sequential(
                self.layers,
                num_segments,
                x,
                use_reentrant=False
            )
        else:
            for layer in self.layers:
                x = layer(x)

        return x

Memory Estimation with Checkpointing

Analytical Model

def estimate_activation_memory(
    num_layers: int,
    hidden_dim: int,
    num_heads: int,
    batch_size: int,
    seq_length: int,
    checkpoint_strategy: str = 'sqrt',
    checkpoint_attention: bool = True,
    dtype_bytes: int = 2
) -> dict:
    """
    Estimate activation memory with various checkpointing strategies.

    Args:
        num_layers: Number of transformer layers
        hidden_dim: Model hidden dimension
        num_heads: Number of attention heads
        batch_size: Batch size
        seq_length: Sequence length
        checkpoint_strategy: 'none', 'full', 'sqrt', or 'selective'
        checkpoint_attention: Whether to checkpoint attention scores (recompute instead of storing)
        dtype_bytes: Bytes per element (2 for FP16)

    Returns:
        Dictionary with memory estimates
    """
    B, S, H, n, L = batch_size, seq_length, hidden_dim, num_heads, num_layers

    # Per-layer activation memory without attention scores:
    # input, Q, K, V, attention output, FFN intermediate, misc ≈ 10·BSH elements
    linear_per_layer = 10 * B * S * H * dtype_bytes  # 20·BSH bytes in FP16

    # Attention scores and softmax output: 2·BnS² elements
    attention_per_layer = 2 * B * n * S * S * dtype_bytes  # 4·BnS² bytes in FP16

    if checkpoint_strategy == 'none':
        # Store everything (no recomputation)
        stored_per_layer = linear_per_layer + attention_per_layer
        recompute_per_layer = 0
    elif checkpoint_attention:
        # Store Q, K, V; recompute scores
        stored_per_layer = linear_per_layer
        recompute_per_layer = attention_per_layer
    else:
        stored_per_layer = linear_per_layer + attention_per_layer
        recompute_per_layer = 0

    # Apply checkpointing strategy
    if checkpoint_strategy == 'none':
        # Store all layers
        total_stored = L * stored_per_layer
        peak_recompute = 0

    elif checkpoint_strategy == 'full':
        # Store only input, recompute all
        total_stored = B * S * H * dtype_bytes  # Just input
        peak_recompute = stored_per_layer + recompute_per_layer  # One layer at a time

    elif checkpoint_strategy == 'sqrt':
        # Optimal sqrt(L) checkpointing
        num_checkpoints = int(math.ceil(math.sqrt(L)))
        segment_size = (L + num_checkpoints - 1) // num_checkpoints

        # Store checkpoints
        checkpoint_memory = num_checkpoints * B * S * H * dtype_bytes

        # Peak during segment recomputation
        peak_segment = segment_size * stored_per_layer

        total_stored = checkpoint_memory + peak_segment
        peak_recompute = segment_size * recompute_per_layer

    elif checkpoint_strategy == 'selective':
        # Checkpoint every other layer
        num_stored = L // 2

        total_stored = num_stored * stored_per_layer
        # One layer recomputed at a time: its stored-size intermediates
        # plus any recomputed attention scores
        peak_recompute = stored_per_layer + recompute_per_layer

    else:
        raise ValueError(f"Unknown strategy: {checkpoint_strategy}")

    return {
        'total_stored_gb': total_stored / (1024**3),
        'peak_recompute_gb': peak_recompute / (1024**3),
        'peak_total_gb': (total_stored + peak_recompute) / (1024**3),
        'strategy': checkpoint_strategy,
        'checkpoint_attention': checkpoint_attention,
        'layers': L,
        'per_layer_mb': stored_per_layer / (1024**2)
    }

Example Calculations

# 7B model with batch 4, sequence 2048
config = {
    'num_layers': 32,
    'hidden_dim': 4096,
    'num_heads': 32,
    'batch_size': 4,
    'seq_length': 2048
}

strategies = ['none', 'sqrt', 'selective', 'full']
for strategy in strategies:
    result = estimate_activation_memory(**config, checkpoint_strategy=strategy)
    print(f"{strategy:10}: {result['peak_total_gb']:.1f} GB")

Output:

none      : 84.0 GB
sqrt      : 16.1 GB  (~5× reduction)
selective : 12.6 GB  (~7× reduction)
full      : 2.7 GB   (~31× reduction)

Practical Implementation Guide

DeepSpeed Activation Checkpointing

DeepSpeed provides optimized checkpointing:

import deepspeed

# Configure in DeepSpeed config
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": False,
        "contiguous_memory_optimization": True,
        "number_checkpoints": None,  # Auto-detect
        "synchronize_checkpoint_boundary": False,
        "profile": False
    }
}

# Wrap model
model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config
)

Megatron-LM Checkpointing

from megatron.core.tensor_parallel import checkpoint

class MegatronTransformerLayer(nn.Module):
    def forward(self, hidden_states, attention_mask):
        if self.checkpoint_activations:
            hidden_states = checkpoint(
                self.attention,
                hidden_states,
                attention_mask
            )
            hidden_states = checkpoint(
                self.mlp,
                hidden_states
            )
        else:
            hidden_states = self.attention(hidden_states, attention_mask)
            hidden_states = self.mlp(hidden_states)

        return hidden_states

HuggingFace Gradient Checkpointing

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Now training uses ~40% less activation memory

Exercises

  1. Optimal checkpointing: For a 48-layer transformer, derive the optimal checkpoint interval using the \(\sqrt{L}\) rule. How many checkpoints are stored? What's the memory compared to no checkpointing?
Solution

Optimal checkpoint interval:

Using the \(\sqrt{L}\) rule:

\[k^* = \sqrt{L} = \sqrt{48} \approx 6.93\]

Round to practical value: \(k = 7\) layers between checkpoints.

Number of checkpoints: \[N_{\text{ckpt}} = \left\lceil \frac{L}{k} \right\rceil = \left\lceil \frac{48}{7} \right\rceil = \boxed{7 \text{ checkpoints}}\]

Memory comparison:

Let \(M_{\text{layer}}\) = activation memory per layer.

| Strategy | Memory | Formula |
|---|---|---|
| No checkpointing | \(48 \cdot M_{\text{layer}}\) | \(L \cdot M_{\text{layer}}\) |
| \(\sqrt{L}\) checkpointing | \(7 \cdot M_{\text{layer}} + 7 \cdot M_{\text{layer}}\) | \(2\sqrt{L} \cdot M_{\text{layer}}\) |

Explanation of \(\sqrt{L}\) memory:

  • Store \(\sqrt{L}\) checkpoints: \(\sqrt{L} \cdot M_{\text{layer}}\)
  • During backward, recompute up to \(\sqrt{L}\) layers: \(\sqrt{L} \cdot M_{\text{layer}}\)
  • Total: \(2\sqrt{L} \cdot M_{\text{layer}}\)

Memory reduction: \[\text{Reduction} = \frac{L}{2\sqrt{L}} = \frac{\sqrt{L}}{2} = \frac{\sqrt{48}}{2} = \boxed{3.46\times}\]

Numerical example (assuming \(M_{\text{layer}} = 1\) GB):

| Strategy | Total Memory | Relative |
|---|---|---|
| No checkpointing | 48 GB | 1.00× |
| \(\sqrt{L}\) (\(k=7\)) | 14 GB | 0.29× |
| Full checkpointing (\(k=1\)) | 2 GB | 0.04× |

(The \(k=1\) row stores only per-layer inputs, which are far smaller than a full \(M_{\text{layer}}\).)

Trade-off summary:

| Checkpoint Interval | Checkpoints | Memory | Recompute Overhead |
|---|---|---|---|
| \(k = 1\) (every layer) | 48 | \(2M\) | ≈33% (all layers recomputed) |
| \(k = 7\) (\(\sqrt{L}\)) | 7 | \(14M\) | 33% |
| \(k = 12\) | 4 | \(16M\) | 25% |
| No checkpointing | 0 | \(48M\) | 0% |
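For the \(k=7\) and \(k=12\) rows, the memory column follows from the \(\lceil L/k \rceil + k\) model (in units of \(M\); the \(k=1\) row uses a different accounting, since per-layer checkpoints store only small layer inputs):

```python
import math

L = 48

def mem_units(k: int) -> int:
    """Checkpoints stored plus peak segment activations, in units of M."""
    return math.ceil(L / k) + k

print(mem_units(7), mem_units(12))   # 14 16
```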
  2. Compute overhead: A training step takes 100ms without checkpointing. With full checkpointing, what's the expected step time? (Assume the forward pass accounts for ⅓ of total step compute, i.e., backward ≈ 2× forward.)
Solution

Training step breakdown (without checkpointing):

Total = Forward + Backward = 100 ms

Let \(F\) = forward time. The backward pass has two components:

  • Recompute forward (only if checkpointing): same cost as the forward, \(F\)
  • Compute gradients: typically \(2F\)

So without checkpointing:

  • Forward: \(F\)
  • Backward (gradients only): \(2F\)
  • Total: \(3F = 100\) ms, so \(F = 33.3\) ms

With full checkpointing:

  • Forward pass: \(F = 33.3\) ms (same, just don't save activations)
  • Backward pass: recompute + gradients = \(F + 2F = 3F = 100\) ms

Total with checkpointing: \[T_{\text{ckpt}} = F + 3F = 4F = \frac{4}{3} \times 100 = \boxed{133.3\text{ ms}}\]

Overhead: \[\text{Overhead} = \frac{133.3 - 100}{100} = \boxed{33.3\%}\]

Verification using standard formula:

The 33% overhead matches the well-known rule: checkpointing adds ~33% compute overhead.

Breakdown comparison:

| Phase | No Checkpointing | Full Checkpointing |
|---|---|---|
| Forward | 33.3 ms | 33.3 ms |
| Backward (recompute) | 0 ms | 33.3 ms |
| Backward (gradients) | 66.7 ms | 66.7 ms |
| Total | 100 ms | 133.3 ms |
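The same arithmetic as a quick sketch (times in ms):

```python
F = 100 / 3                      # forward time: baseline total is 3F = 100 ms
baseline = 3 * F                 # forward + gradient computation
with_ckpt = 4 * F                # adds one full recompute forward
overhead = (with_ckpt - baseline) / baseline

print(f"{with_ckpt:.1f} ms, {overhead:.1%} overhead")   # 133.3 ms, 33.3% overhead
```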
  3. Mixed strategy: Design a checkpointing strategy that checkpoints attention scores but not FFN activations. Calculate memory savings vs. full checkpointing.
Solution

Activation memory breakdown per layer:

| Component | Memory | Recompute Cost |
|---|---|---|
| Attention Q, K, V | \(3BSH\) elements | Low (linear projections) |
| Attention scores | \(B \cdot A \cdot S^2\) elements | High (matmul + softmax) |
| Attention output | \(BSH\) elements | Medium |
| FFN intermediate | \(4BSH\) elements | Low (linear + activation) |
| FFN output | \(BSH\) elements | Low |

For a typical config (\(B=4\), \(S=4096\), \(H=4096\), \(A=32\), FP16):

| Component | Size | % of Total |
|---|---|---|
| Attention Q,K,V | 384 MB | 7% |
| Attention scores | 4096 MB | 78% |
| Attention output | 128 MB | 2% |
| FFN intermediate | 512 MB | 10% |
| FFN output | 128 MB | 2% |
| Total | 5248 MB | 100% |

Mixed strategy: store Q/K/V, recompute attention scores + FFN.

Store:

  • Q, K, V: 384 MB
  • Layer input: 128 MB (for recomputing the FFN)

Recompute:

  • Attention scores + softmax: 4096 MB saved
  • FFN activations: 640 MB saved

Memory per layer: \[M_{\text{mixed}} = 384 + 128 = 512 \text{ MB}\]

Comparison:

| Strategy | Memory/Layer | Relative | Recompute Overhead |
|---|---|---|---|
| No checkpointing | 5248 MB | 1.00× | 0% |
| Mixed (store QKV) | 512 MB | 0.10× | ~20% |
| Full checkpointing | 128 MB | 0.02× | 33% |

When to use mixed strategy:

class MixedCheckpointAttention(nn.Module):
    """Stores Q/K/V, recomputes attention scores and FFN.

    Assumes self.qkv_proj, self.ffn, self.num_heads, and self.head_dim
    are defined in __init__ (omitted here for brevity).
    """

    def forward(self, x):
        B, S, H = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # Reshape to (B, S, heads, head_dim) for the einsums below
        q, k, v = (t.view(B, S, self.num_heads, self.head_dim) for t in (q, k, v))

        def attn_core(q, k, v):
            scores = torch.einsum('bshd,bthd->bhst', q, k) / math.sqrt(self.head_dim)
            probs = torch.softmax(scores, dim=-1)
            return torch.einsum('bhst,bthd->bshd', probs, v)

        # Recompute attention core in backward (scores not stored)
        attn_out = checkpoint(attn_core, q, k, v, use_reentrant=False)

        # FFN with recomputation (cheap to recompute, 4H-wide to store)
        return checkpoint(self.ffn, attn_out.reshape(B, S, H), use_reentrant=False)

Trade-off analysis:

| Metric | Mixed | Full Ckpt | Advantage |
|---|---|---|---|
| Memory | 0.10× | 0.02× | Full ckpt |
| Compute | 1.20× | 1.33× | Mixed |
| Complexity | Medium | Low | Full ckpt |

Recommendation: Use the mixed strategy when memory is tight but full checkpointing's 33% overhead is unacceptable—it captures roughly 90% of the memory savings at roughly 20% overhead.

  4. Compression analysis: If activations are compressed to FP8 before checkpointing, what's the memory reduction compared to FP16 checkpointing? What's the impact on gradient accuracy?
Solution

Memory reduction:

| Precision | Bytes/element | Relative |
|---|---|---|
| FP32 | 4 | 2.0× |
| FP16/BF16 | 2 | 1.0× (baseline) |
| FP8 | 1 | 0.5× |

Memory reduction with FP8: \[\text{Reduction} = \frac{2}{1} = \boxed{2\times \text{ (50\% savings)}}\]

For a 48-layer model with \(\sqrt{L}\) checkpointing:

| Checkpoint Precision | Memory |
|---|---|
| FP16 | \(14 \cdot M_{\text{layer}}\) |
| FP8 | \(7 \cdot M_{\text{layer}}\) |

Gradient accuracy impact:

FP8 has limited dynamic range and precision:

  • E4M3: range \(\pm 448\), 3 mantissa bits
  • E5M2: range \(\pm 57344\), 2 mantissa bits

Error analysis:

import torch

def analyze_fp8_error(tensor_fp16):
    """Analyze quantization error from FP16 → FP8 → FP16."""
    # Simulate FP8 quantization (E4M3)
    scale = tensor_fp16.abs().max() / 448  # FP8 E4M3 max
    quantized = (tensor_fp16 / scale).clamp(-448, 448)
    quantized = (quantized * 8).round() / 8  # 3 mantissa bits
    dequantized = quantized * scale

    # Error metrics
    abs_error = (tensor_fp16 - dequantized).abs()
    rel_error = abs_error / (tensor_fp16.abs() + 1e-8)

    return {
        'max_abs_error': abs_error.max().item(),
        'mean_abs_error': abs_error.mean().item(),
        'max_rel_error': rel_error.max().item(),
        'mean_rel_error': rel_error.mean().item(),
    }

# Test on typical activation distribution
activations = torch.randn(1024, 4096) * 0.1  # Typical scale
errors = analyze_fp8_error(activations)
# Typical results:
# mean_rel_error: ~2-5%
# max_rel_error: ~10-20%

Gradient accuracy impact:

| Metric | FP16 Baseline | FP8 Checkpoints |
|---|---|---|
| Gradient rel. error | 0% | 1-3% |
| Training stability | Stable | Usually stable |
| Final loss | Baseline | +0.1-0.5% |
| Convergence speed | Baseline | ~Same |

Mitigation strategies:

  1. Stochastic rounding: Reduces bias in quantization
  2. Per-tensor scaling: Maximizes dynamic range usage
  3. Mixed precision checkpoints: FP8 for large tensors, FP16 for small
  4. Gradient scaling: Compensate for reduced precision
A minimal compressed-checkpoint helper can then look like this (requires a PyTorch build that exposes `torch.float8_e4m3fn`):

class FP8Checkpoint:
    """Checkpoint with FP8 compression and a per-tensor scale."""

    @staticmethod
    def save(tensor):
        # Scale so the max magnitude lands at the E4M3 limit (±448)
        scale = tensor.abs().max() / 448
        quantized = (tensor / scale).to(torch.float8_e4m3fn)
        return quantized, scale

    @staticmethod
    def load(quantized, scale):
        # Dequantize back to FP16 before recomputation uses it
        return quantized.to(torch.float16) * scale
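The same scale-then-quantize round trip can be sketched without torch, using a fixed-step grid as a crude stand-in for E4M3 (the helper names here are illustrative, not a real API):

```python
def fp8_save(values, q_max=448.0, steps=8):
    """Per-tensor scale, then coarse rounding (rough FP8 stand-in, not real E4M3)."""
    scale = max(abs(v) for v in values) / q_max
    quantized = [round(v / scale * steps) / steps for v in values]
    return quantized, scale

def fp8_load(quantized, scale):
    return [q * scale for q in quantized]

vals = [0.10, -0.25, 0.03]
restored = fp8_load(*fp8_save(vals))
# round-trip error stays small because the scale uses the full quantization range
```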

Summary:

| Aspect | Impact |
|---|---|
| Memory | 2× reduction |
| Gradient accuracy | 1-3% relative error |
| Training stability | Generally maintained |
| Recommended for | Memory-constrained training |
  1. Combined optimization: A 70B model is trained with ZeRO-3 on 64 GPUs with \(\sqrt{L}\) checkpointing. Calculate total memory per GPU including model state and activations.
Solution

70B model architecture (typical):

| Parameter | Value |
|---|---|
| Total parameters | 70B |
| Layers (\(L\)) | 80 |
| Hidden dimension (\(H\)) | 8192 |
| Attention heads (\(A\)) | 64 |
| FFN intermediate | 28672 (3.5× H) |

Model state memory (ZeRO-3, 64 GPUs):

\[M_{\text{state}} = \frac{16\Psi}{P} = \frac{16 \times 70 \times 10^9}{64} = \boxed{17.5\text{ GB}}\]
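The arithmetic, as a sanity check:

```python
psi = 70e9               # parameters
gpus = 64
bytes_per_param = 16     # FP16 weights + grads (4 B) + FP32 optimizer state (12 B)
state_gb = psi * bytes_per_param / gpus / 1e9
print(state_gb)  # → 17.5
```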

Activation memory with \(\sqrt{L}\) checkpointing:

Checkpoint interval: \(k = \sqrt{80} \approx 9\) layers
Number of checkpoints: \(\lceil 80/9 \rceil = 9\)

Assume training config: \(B = 1\) (micro-batch), \(S = 4096\)

Per-checkpoint activation size:

Layer input: \(B \times S \times H \times 2\) bytes = \(1 \times 4096 \times 8192 \times 2 = 67\) MB

Recomputation buffer (max 9 layers active):

Per-layer activations (approximate):

  • Attention: \(BSH \times 10 \times 2 = 670\) MB
  • FFN: \(BSH \times 5 \times 2 = 335\) MB
  • Total per layer: ~1 GB

Recomputation buffer: \(9 \times 1\) GB = 9 GB

Total activation memory:

\[M_{\text{act}} = 9 \times 67\text{ MB} + 9 \times 1\text{ GB} \approx 0.6 + 9 = \boxed{9.6\text{ GB}}\]
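This total can be checked numerically, using the ~1 GB per-layer recompute estimate from above:

```python
B, S, H = 1, 4096, 8192
ckpts = 9
ckpt_gb = ckpts * B * S * H * 2 / 1e9     # stored layer inputs, FP16 (2 bytes)
recompute_gb = 9 * 1.0                    # ~1 GB per layer in the active segment
total_gb = ckpt_gb + recompute_gb
# total_gb ≈ 9.6
```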

Total memory per GPU:

| Component | Memory |
|---|---|
| Model state (ZeRO-3) | 17.5 GB |
| Activations (\(\sqrt{L}\) ckpt) | 9.6 GB |
| Temporary buffers | ~3 GB |
| CUDA overhead | ~2 GB |
| **Total** | **~32 GB** |

Fits comfortably in 80GB H100!

Comparison without optimizations:

| Configuration | Memory/GPU |
|---|---|
| No ZeRO, no ckpt | \(16\Psi + 80 \times 1\text{ GB} = 1200\) GB |
| ZeRO-3 only | \(17.5 + 80 = 97.5\) GB |
| ZeRO-3 (8 GPUs) + ckpt | \(140 + 9.6 \approx 150\) GB |
| ZeRO-3 (64 GPUs) + \(\sqrt{L}\) ckpt | 32 GB |
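These comparison rows reduce to simple arithmetic (activation figures from the calculation above; the combined total excludes the ~5 GB of buffers and CUDA overhead):

```python
state_gb = 16 * 70            # 16 B/param × 70B params = 1120 GB of model state
act_full = 80 * 1.0           # 80 layers × ~1 GB without checkpointing
act_ckpt = 9.6                # sqrt(L) checkpointing

no_opt = state_gb + act_full              # 1200 GB: hopeless on one GPU
zero3_only = state_gb / 64 + act_full     # 97.5 GB: still over 80 GB
combined = state_gb / 64 + act_ckpt       # 27.1 GB before buffers/overhead
```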

Key insight: Combining ZeRO-3 and checkpointing provides multiplicative benefits—neither alone is sufficient for 70B on 64 GPUs.

  1. Pipeline interaction: In a 4-stage pipeline with 8 micro-batches, how many activation copies are stored at peak? How does checkpointing affect this?
Solution

Pipeline configuration:

  • Stages (\(p\)): 4
  • Micro-batches (\(m\)): 8
  • Schedule: 1F1B (one forward, one backward)

Schedule analysis (schematic; this naive layout runs all warmup forwards before any backward, as in GPipe):

Stage 0: F0 F1 F2 F3 F4 F5 F6 F7 B0 B1 B2 B3 B4 B5 B6 B7
Stage 1:    F0 F1 F2 F3 F4 F5 F6 B0 B1 B2 B3 B4 B5 B6 B7
Stage 2:       F0 F1 F2 F3 F4 F5 B0 B1 B2 B3 B4 B5 B6 B7
Stage 3:          F0 F1 F2 F3 F4 B0 B1 B2 B3 B4 B5 B6 B7
                  ↑ Peak activations

Peak activation storage (naive schedule, without checkpointing):

At the marked instant, just before the first backward starts, each stage holds activations for every forward it has completed:

  • Stage 0: micro-batches 0-7 = 8 copies
  • Stage 1: micro-batches 0-6 = 7 copies
  • Stage 2: micro-batches 0-5 = 6 copies
  • Stage 3: micro-batches 0-4 = 5 copies

1F1B avoids this worst case by starting backwards as early as possible.

1F1B steady state:

After warmup, each stage holds at most \(p\) in-flight micro-batches:

  • Stage 0: up to 4 activation sets
  • Stage 1: up to 4 activation sets
  • etc.

Peak activations per stage: \(\boxed{p = 4 \text{ micro-batches}}\)

But during warmup (bubble fill):

Stage 0 must hold activations until its backward arrives:

  • Warmup micro-batches: \(p - 1 = 3\)
  • Steady-state in-flight: 1

Total peak: \(p = 4\) per stage (steady state bound)

With checkpointing:

Instead of storing full activations, store only checkpoints:

| Strategy | Per-micro-batch Storage | Peak per Stage |
|---|---|---|
| No checkpointing | Full activations (\(L_{stage} \times M\)) | \(4 \times L_{stage} \times M\) |
| Full checkpointing | Input only (\(M\)) | \(4 \times M\) |
| \(\sqrt{L}\) checkpointing | \(2\sqrt{L_{stage}} \times M\) | \(4 \times 2\sqrt{L_{stage}} \times M\) |

Example (70B model, 4 stages, 20 layers/stage):

\(M\) = layer activation ≈ 64 MB (for micro-batch size 1, S=4096, H=8192)

| Strategy | Memory per Stage |
|---|---|
| No checkpointing | \(4 \times 20 \times 64\) MB = 5.12 GB |
| Full checkpointing | \(4 \times 64\) MB = 256 MB |
| \(\sqrt{L}\) (\(k \approx 4.5\)) | \(4 \times 2 \times 4.5 \times 64\) MB ≈ 2.3 GB |
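A quick check of these per-stage numbers (decimal GB, with \(\sqrt{20} \approx 4.5\)):

```python
p = 4                    # in-flight micro-batches per stage
layers = 20              # layers per pipeline stage
M_mb = 64                # per-layer activation size, MB

no_ckpt_gb = p * layers * M_mb / 1000                # 5.12 GB
full_ckpt_mb = p * M_mb                              # 256 MB
sqrt_ckpt_gb = p * 2 * layers ** 0.5 * M_mb / 1000   # ≈ 2.29 GB
```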

Interleaved Pipeline (more micro-batches in flight):

With an interleaved schedule, each GPU hosts \(v\) virtual stages. Each virtual stage is still bounded by \(p\) in-flight micro-batches, so the per-virtual-stage peak is unchanged, but the GPU's total activation memory grows with \(v\) because it holds activations for all \(v\) virtual stages at once.

Memory formula:

\[M_{\text{peak}} = p \times \frac{L}{p} \times M_{\text{layer}} \times \text{ckpt\_factor}\]

Where \(\text{ckpt\_factor}\) is:

  • No checkpointing: \(1.0\)
  • \(\sqrt{L}\): \(\frac{2\sqrt{L/p}}{L/p} = \frac{2}{\sqrt{L/p}}\)
  • Full: \(\frac{1}{L/p}\)

Summary:

| Metric | No Ckpt | \(\sqrt{L}\) Ckpt | Full Ckpt |
|---|---|---|---|
| Peak copies | \(4 \times 20 = 80\) | \(4 \times 9 = 36\) | \(4 \times 1 = 4\) |
| Memory | 5.12 GB | 2.3 GB | 256 MB |
| Recompute overhead | 0% | 33% | 100% |

Key Takeaways

  1. Activations dominate memory: Can exceed model size by 10×+ for large batch/sequence.

  2. Checkpointing trades compute for memory: 33% compute overhead for \(\sqrt{L}\) memory reduction.

  3. \(\sqrt{L}\) is optimal: Checkpoint every \(\sqrt{L}\) layers for best memory-compute trade-off.

  4. Attention scores are expensive: \(O(S^2)\) per layer; prime candidates for recomputation.

  5. Composable with other techniques: Works with ZeRO, TP, PP for multiplicative savings.

  6. Use non-reentrant checkpointing: More robust, correctly handles RNG state.
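Takeaway 3 can be verified numerically: with a checkpoint every \(k\) layers, peak activation memory is roughly \(L/k\) stored checkpoints plus \(k\) recomputed layers in flight, which is minimized at \(k = \sqrt{L}\):

```python
import math

def peak_units(L, k):
    """Activation memory in layer-units: L/k checkpoints + k layers in flight."""
    return L / k + k

L = 64
best_k = min(range(1, L + 1), key=lambda k: peak_units(L, k))
assert best_k == int(math.sqrt(L))   # k = 8
# memory falls from L = 64 layer-units to 2*sqrt(L) = 16
```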