24 Mastering GPU Memory
Understanding Allocation, Fragmentation, and Optimization
“CUDA out of memory” is the most common training failure.
Understanding how GPU memory actually works—not just how much you have—is the difference between running your model and staring at error messages.
24.1 The Memory Landscape
A typical training run uses GPU memory for:
LLaMA-7B training breakdown (batch size 1, seq 2048):
Model parameters: 14.0 GB (FP16)
Gradients: 14.0 GB (FP16)
Optimizer state: 56.0 GB (FP32 Adam momentum + variance)
Activations: ~8.0 GB (depends on checkpointing)
Temporary buffers: ~2.0 GB (cuBLAS workspace, etc.)
─────────────────────────────────────────────────────────
Total: ~94 GB
A100 80GB: Doesn't fit without the optimizations in this chapter
A100 40GB: Doesn't fit
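These totals are easy to reproduce; here is a minimal back-of-the-envelope helper (hypothetical, assuming FP16 weights and gradients, FP32 Adam moments, and the activation/buffer estimates above):
def estimate_training_memory_gb(n_params, activations_gb=8.0, buffers_gb=2.0):
    """Rough FP16 + Adam training footprint; ignores master weights and overheads."""
    params_gb = n_params * 2 / 1e9      # FP16 parameters
    grads_gb = n_params * 2 / 1e9       # FP16 gradients
    optim_gb = n_params * 8 / 1e9       # FP32 momentum + variance (4 + 4 bytes)
    return params_gb + grads_gb + optim_gb + activations_gb + buffers_gb

print(f"{estimate_training_memory_gb(7e9):.0f} GB")  # ~94 GB for a 7B model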
This chapter explains each component and how to optimize it.
24.2 PyTorch CUDA Memory Allocator
24.2.1 The Caching Allocator
PyTorch doesn’t call cudaMalloc for every tensor. That would be slow:
cudaMalloc latency: 1-10 ms
Tensor operation: 0.01-1 ms
If we malloc per tensor: Memory management >> actual compute
Instead, PyTorch uses a caching allocator:
# Conceptual model of PyTorch's allocator
class CachingAllocator:
def __init__(self):
self.free_blocks = {} # size -> list of free blocks
self.allocated_blocks = {} # ptr -> (size, block)
def allocate(self, size):
# Round up to allocation granularity
size = round_up(size, 512) # 512-byte granularity
# Try to find a cached free block
if size in self.free_blocks and self.free_blocks[size]:
block = self.free_blocks[size].pop()
self.allocated_blocks[block.ptr] = (size, block)
return block.ptr
# No cached block—actually allocate from CUDA
ptr = cudaMalloc(size)
block = Block(ptr, size)
self.allocated_blocks[ptr] = (size, block)
return ptr
def free(self, ptr):
size, block = self.allocated_blocks.pop(ptr)
# Don't actually free—cache for reuse
if size not in self.free_blocks:
self.free_blocks[size] = []
self.free_blocks[size].append(block)
24.2.2 Inspecting Memory State
import torch
# Current allocations
allocated = torch.cuda.memory_allocated() # Tensors you're using
reserved = torch.cuda.memory_reserved() # Total held by allocator
print(f"Allocated: {allocated / 1e9:.2f} GB")
print(f"Reserved: {reserved / 1e9:.2f} GB")
print(f"Free (in cache): {(reserved - allocated) / 1e9:.2f} GB")
# Detailed statistics
stats = torch.cuda.memory_stats()
print(f"Current allocated: {stats['allocated_bytes.all.current'] / 1e9:.2f} GB")
print(f"Peak allocated: {stats['allocated_bytes.all.peak'] / 1e9:.2f} GB")
print(f"Num allocs: {stats['allocation.all.current']}")24.2.3 Memory Snapshots
For debugging OOM, capture a memory snapshot:
# Enable memory history
torch.cuda.memory._record_memory_history(max_entries=100000)
# Run your code
try:
train_step()
except torch.cuda.OutOfMemoryError:
# Capture snapshot
torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
# Visualize with:
# python -m torch.cuda.memory_viz oom_snapshot.pickle -o snapshot.html
# (or load the pickle in the interactive viewer at https://pytorch.org/memory_viz)
The visualization shows:
- Timeline of allocations
- What tensors are live at OOM
- Stack traces for each allocation
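For a quick human-readable report without recording history, torch.cuda.memory_summary() prints the allocator's statistics as a table:
import torch
# Abbreviated per-device summary of allocated, reserved, and inactive memory
print(torch.cuda.memory_summary(device=0, abbreviated=True))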
24.3 Memory Fragmentation
24.3.1 How Fragmentation Happens
Scenario: Allocate and free tensors of varying sizes
Step 1: Allocate A (1GB), B (2GB), C (1GB)
Memory: [A:1GB][B:2GB][C:1GB][free:4GB]
Step 2: Free B
Memory: [A:1GB][free:2GB][C:1GB][free:4GB]
Step 3: Try to allocate D (3GB)
Problem: Have 6GB free, but no contiguous 3GB block!
PyTorch options:
1. Fail with OOM (even though total free > requested)
2. Call cudaMalloc for new block (may fail at CUDA level)
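The gap between reserved and allocated memory is easy to observe directly; a minimal repro (assumes a CUDA GPU with at least ~5 GB free):
import torch

# Allocate 1 GB, 2 GB, 1 GB, then free the middle block
blocks = [torch.empty(int(gb * 1e9) // 4, device="cuda") for gb in (1, 2, 1)]
del blocks[1]  # the 2 GB block goes back to the cache, not to CUDA

print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")  # ~2 GB live
print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")   # ~4 GB still held
# A 3 GB request now cannot reuse the cached 2 GB block, so the allocator
# must call cudaMalloc again (or raise OOM) despite 2 GB sitting idle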
24.3.2 Detecting Fragmentation
def check_fragmentation():
"""Diagnose memory fragmentation."""
allocated = torch.cuda.memory_allocated()
reserved = torch.cuda.memory_reserved()
total = torch.cuda.get_device_properties(0).total_memory
fragmentation = (reserved - allocated) / reserved if reserved > 0 else 0
print(f"Allocated: {allocated / 1e9:.2f} GB")
print(f"Reserved: {reserved / 1e9:.2f} GB")
print(f"Total: {total / 1e9:.2f} GB")
print(f"Fragmentation: {fragmentation * 100:.1f}%")
print(f"Truly free: {(total - reserved) / 1e9:.2f} GB")
if fragmentation > 0.3:
print("WARNING: High fragmentation. Consider torch.cuda.empty_cache()")24.3.3 Mitigating Fragmentation
1. Consistent tensor sizes
# Bad: Variable sizes cause fragmentation
for seq_len in [128, 256, 512, 1024, 2048]:
x = torch.randn(batch, seq_len, hidden) # Different sizes
process(x)
# Better: Pad to consistent size
max_len = 2048
for seq_len in [128, 256, 512, 1024, 2048]:
x = torch.randn(batch, max_len, hidden) # Same size
x = x[:, :seq_len, :] # View (no new allocation)
process(x)
2. Pre-allocate buffers
class PreallocatedModel(nn.Module):
def __init__(self, max_batch, max_seq, hidden):
super().__init__()
# Pre-allocate work buffers
self.register_buffer(
'work_buffer',
torch.empty(max_batch, max_seq, hidden)
)
def forward(self, x):
batch, seq, _ = x.shape
# Use pre-allocated buffer
work = self.work_buffer[:batch, :seq, :]
work.copy_(x)
# ... process work ...
3. Empty cache strategically
# Clear cached memory (returns to CUDA, may be slow)
torch.cuda.empty_cache()
# Use sparingly—once per epoch, not per batch
for epoch in range(num_epochs):
train_epoch()
torch.cuda.empty_cache() # Defragment between epochs
4. Memory-efficient operations
# Bad: Creates intermediate tensor
y = x + 1
y = y * 2 # x + 1 is still allocated until y * 2 completes
# Better: In-place operations (careful: these mutate x and can break autograd)
y = x.add_(1).mul_(2) # No intermediate allocation
# Best: Fused operation
@torch.compile
def fused_op(x):
return (x + 1) * 2 # Compiler fuses automatically
24.4 Activation Memory
24.4.1 Why Activations Dominate
Forward pass saves activations for backward:
Transformer layer memory (batch=8, seq=2048, hidden=4096):
Input: 8 × 2048 × 4096 × 2 = 128 MB
After attention: 128 MB
After FFN: 128 MB
LayerNorm saves: 64 MB
Attention scores: 8 × 32 × 2048 × 2048 × 2 = 2 GB (!)
Per layer: ~2.5 GB
32 layers (LLaMA-7B): ~80 GB
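A rough estimator reproduces these numbers (hypothetical helper; assumes FP16 activations and a fully materialized attention matrix, i.e., no FlashAttention):
def activation_memory_gb(batch, seq, hidden, heads, layers, bytes_per=2):
    """Very rough activation estimate for standard transformer blocks."""
    hidden_states = 3 * batch * seq * hidden * bytes_per    # input, post-attention, post-FFN
    norms = batch * seq * hidden * bytes_per                # LayerNorm saves
    attn_scores = batch * heads * seq * seq * bytes_per     # full attention score matrix
    return layers * (hidden_states + norms + attn_scores) / 1e9

print(f"{activation_memory_gb(8, 2048, 4096, 32, 32):.0f} GB")  # ~86 GB, in line with the estimate above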
This is why we need activation checkpointing.
24.4.2 Activation Checkpointing
Trade compute for memory—don’t save activations, recompute during backward:
from torch.utils.checkpoint import checkpoint
class CheckpointedTransformerBlock(nn.Module):
def forward(self, x):
# Checkpoint attention
x = x + checkpoint(self.attention, x, use_reentrant=False)
# Checkpoint FFN
x = x + checkpoint(self.ffn, x, use_reentrant=False)
return x
Memory saved: Activations only stored at checkpoint boundaries
Compute cost: ~33% more FLOPs (recompute forward during backward)
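For a plain stack of layers, torch.utils.checkpoint.checkpoint_sequential handles the segmenting for you; a minimal sketch (the layer stack and segment count are illustrative, PyTorch 2.x assumed):
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(32)])
x = torch.randn(8, 4096, requires_grad=True)

# Split the stack into 4 segments; only segment-boundary activations are kept,
# everything in between is recomputed during backward
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)
out.sum().backward()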
24.4.3 Selective Checkpointing
Not all layers need checkpointing:
def apply_selective_checkpointing(model, checkpoint_ratio=0.5):
"""
Checkpoint only some layers.
Early layers: High memory (close to input), checkpoint these
Late layers: Low memory (near output), don't checkpoint
"""
layers = list(model.layers)
n_checkpoint = int(len(layers) * checkpoint_ratio)
for i, layer in enumerate(layers):
if i < n_checkpoint:
layer.use_checkpoint = True
else:
layer.use_checkpoint = False
24.4.4 Optimal Checkpointing
The optimal strategy minimizes peak memory:
import math

def compute_optimal_checkpoints(layer_memories, budget):
"""
Dynamic programming to find optimal checkpoint placement.
layer_memories: Memory used by each layer's activations
budget: Maximum allowed memory
"""
# Use DP to find minimum checkpoints needed
n = len(layer_memories)
# dp[i] = min checkpoints for layers 0..i
# Simplified heuristic (ignores budget): checkpointing every sqrt(n) layers is near-optimal
checkpoint_interval = int(math.sqrt(n))
checkpoints = list(range(0, n, checkpoint_interval))
return checkpoints
24.5 Optimizer State Memory
24.5.1 Adam’s Memory Cost
Adam stores two additional tensors per parameter:
# Adam update:
# m = beta1 * m + (1 - beta1) * grad # First moment
# v = beta2 * v + (1 - beta2) * grad^2 # Second moment
# param -= lr * m / (sqrt(v) + eps)
# Memory:
# Parameters: P × sizeof(dtype)
# Gradients: P × sizeof(dtype)
# m (momentum): P × sizeof(float32) — always FP32!
# v (variance): P × sizeof(float32) — always FP32!
# For 7B params, FP16 training:
# Params: 14 GB
# Grads: 14 GB
# m: 28 GB (FP32)
# v: 28 GB (FP32)
# Total optimizer state: 56 GB
24.5.2 Memory-Efficient Optimizers
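To compare the options below before committing, a rough sizing helper (hypothetical; state bytes per parameter are approximate and ignore per-tensor overhead):
def optimizer_state_gb(n_params, state_bytes_per_param):
    return n_params * state_bytes_per_param / 1e9

n = 7e9  # LLaMA-7B
print(f"Adam (FP32 m + v):     {optimizer_state_gb(n, 8):.0f} GB")  # 56 GB
print(f"8-bit Adam (INT8 m+v): {optimizer_state_gb(n, 2):.0f} GB")  # 14 GB
print(f"SGD with momentum:     {optimizer_state_gb(n, 4):.0f} GB")  # 28 GB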
8-bit Adam (bitsandbytes):
import bitsandbytes as bnb
# Replace Adam with 8-bit version
optimizer = bnb.optim.Adam8bit(
model.parameters(),
lr=1e-4,
betas=(0.9, 0.999)
)
# Memory: m and v stored in INT8
# 7B model: 56 GB → 14 GB optimizer state
# Quality: Minimal degradation with proper scaling
Adafactor:
from transformers.optimization import Adafactor
# Adafactor uses factored second moments
optimizer = Adafactor(
model.parameters(),
lr=1e-3,
relative_step=False,
scale_parameter=False
)
# Instead of full v matrix, stores row + column factors
# Memory: O(rows + cols) instead of O(rows × cols)
# Works well for large matrices
SGD with momentum:
# Sometimes the simplest solution
optimizer = torch.optim.SGD(
model.parameters(),
lr=0.1,
momentum=0.9
)
# Memory: Only one extra tensor (momentum)
# 7B model: 28 GB optimizer state (vs 56 GB for Adam)
# Downside: May need more tuning, different learning dynamics
24.6 Memory Offloading
When GPU memory is insufficient, offload to CPU or NVMe.
24.6.1 CPU Offloading
# DeepSpeed ZeRO-Offload
import deepspeed
ds_config = {
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": True # Faster CPU-GPU transfer
}
}
}
model, optimizer, _, _ = deepspeed.initialize(
model=model,
config=ds_config
)
How it works:
1. Optimizer states live on CPU
2. Gradients computed on GPU, sent to CPU
3. Optimizer step on CPU
4. Updated params sent back to GPU
Performance: 2-3x slower than pure GPU, but enables larger models.
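The pin_memory: True setting above matters because page-locked host memory transfers over PCIe considerably faster than pageable memory; a quick benchmark sketch (the tensor size is arbitrary):
import time
import torch

def h2d_gbps(host_tensor):
    """Measure host-to-device copy bandwidth for one tensor."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    host_tensor.to("cuda", non_blocking=True)
    torch.cuda.synchronize()
    gb = host_tensor.numel() * host_tensor.element_size() / 1e9
    return gb / (time.perf_counter() - start)

pageable = torch.empty(256 * 1024 * 1024)              # ~1 GB of ordinary host memory
pinned = torch.empty(256 * 1024 * 1024).pin_memory()   # page-locked host memory

print(f"pageable: {h2d_gbps(pageable):.1f} GB/s")
print(f"pinned:   {h2d_gbps(pinned):.1f} GB/s")  # usually noticeably higher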
24.6.2 NVMe Offloading
For even larger models:
ds_config = {
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "nvme",
"nvme_path": "/local_nvme"
},
"offload_param": {
"device": "nvme",
"nvme_path": "/local_nvme"
}
}
}
Requirements:
- Fast NVMe SSD (>3 GB/s read/write)
- Sufficient CPU RAM for the working set
- Tolerance for 10-100x slowdown vs. pure GPU
24.6.3 When to Offload
Use the decision flow below to choose a strategy:
flowchart TD
A[Model fits in GPU<br/>with batch size 1?] -->|Yes| B[Try gradient checkpointing]
A -->|No| C[Need offloading]
B --> D{Still OOM?}
D -->|No| E[Done!]
D -->|Yes| F[Smaller batch +<br/>gradient accumulation]
F --> G{Still OOM?}
G -->|No| E
G -->|Yes| H[CPU offload optimizer]
C --> I[Optimizer-only offload<br/>ZeRO-2 + offload<br/>~2x slowdown]
C --> J[Full offload<br/>ZeRO-3 + offload<br/>~5x slowdown]
C --> K[NVMe offload<br/>~10x+ slowdown<br/>fits huge models]
style E fill:#dcfce7,stroke:#16a34a
style H fill:#fef3c7,stroke:#d97706
style I fill:#e0f2fe,stroke:#0284c7
style J fill:#f3e8ff,stroke:#9333ea
style K fill:#fee2e2,stroke:#dc2626
24.7 Mixed Precision Memory Savings
24.7.1 The Automatic Savings
from torch.cuda.amp import autocast
with autocast():
# Activations stored in FP16: 2x memory reduction
output = model(input)
# Comparison (batch=8, seq=2048, hidden=4096):
# FP32 activations: 256 MB
# FP16 activations: 128 MB
24.7.2 Master Weights Pattern
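In day-to-day training, torch.cuda.amp handles this for you: parameters (and the optimizer's copy of them) stay in FP32 while autocast runs eligible ops in FP16, and GradScaler guards against FP16 gradient underflow. A minimal loop for comparison (model, optimizer, loss_fn, and dataloader are assumed to exist); the class below then spells the master-weights idea out by hand:
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for inputs, targets in dataloader:           # assumed iterable of (inputs, targets)
    optimizer.zero_grad()
    with autocast():                          # eligible ops run in FP16
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()             # scale loss to avoid FP16 underflow
    scaler.step(optimizer)                    # unscales grads; skips step on inf/nan
    scaler.update()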
Keep FP32 copy for optimizer, FP16 for forward/backward:
import copy

class MixedPrecisionWrapper:
def __init__(self, model):
# FP16 model for forward/backward
self.model_fp16 = copy.deepcopy(model).half()
# FP32 copy for optimizer updates
self.model_fp32 = model
def forward(self, x):
return self.model_fp16(x.half())
def optimizer_step(self, optimizer):
# Copy FP16 grads to FP32
for p16, p32 in zip(self.model_fp16.parameters(),
self.model_fp32.parameters()):
if p16.grad is not None:
p32.grad = p16.grad.float()
# Update in FP32
optimizer.step()
# Copy back to FP16
for p16, p32 in zip(self.model_fp16.parameters(),
self.model_fp32.parameters()):
p16.data.copy_(p32.data)
24.8 Debugging OOM
24.8.1 Systematic Approach
def diagnose_oom():
"""Run before training to diagnose memory issues."""
device = torch.device('cuda')
props = torch.cuda.get_device_properties(device)
print(f"GPU: {props.name}")
print(f"Total memory: {props.total_memory / 1e9:.2f} GB")
# Test maximum allocation
test_sizes = [1, 2, 4, 8, 16, 32, 64] # GB
max_allocatable = 0
for size in test_sizes:
try:
x = torch.empty(int(size * 1e9 / 4), dtype=torch.float32, device=device)
max_allocatable = size
del x
torch.cuda.empty_cache()
except torch.cuda.OutOfMemoryError:
break
print(f"Max single allocation: {max_allocatable} GB")
# Check current state
print(f"\nCurrent state:")
print(f" Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f" Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
# Call before training
diagnose_oom()
24.8.2 Common OOM Patterns
1. Accumulating gradients incorrectly
# Bug: Summing micro-batch losses and calling backward once at the end
# keeps every micro-batch's autograd graph alive simultaneously
losses = []
for micro_batch in micro_batches:
    losses.append(model(micro_batch))  # each graph stays in memory
loss = torch.stack(losses).mean()
loss.backward()
# Fix: Call backward per micro-batch; each graph is freed immediately
# and gradients accumulate in .grad
optimizer.zero_grad()
for micro_batch in micro_batches:
    loss = model(micro_batch) / len(micro_batches)
    loss.backward()
optimizer.step()
2. Holding references to tensors
# Bug: Storing tensors in list
all_outputs = []
for batch in dataloader:
output = model(batch)
all_outputs.append(output) # Keeps all outputs in memory!
# Fix: Detach and move to CPU, or process incrementally
all_outputs = []
for batch in dataloader:
output = model(batch)
all_outputs.append(output.detach().cpu())
3. Large intermediate tensors
# Bug: Huge intermediate
def forward(self, x):
# This creates a massive N×N matrix
similarity = x @ x.T # If x is 1M × 1K, this is 1M × 1M = 4TB!
return similarity.softmax(-1) @ x
# Fix: Chunked computation (like FlashAttention)
def forward(self, x):
return chunked_attention(x, x, x, chunk_size=1024)  # chunked_attention: illustrative blockwise helper
24.9 Key Takeaways
Understand the allocator: PyTorch caches memory; reserved != allocated.
Fragmentation is real: Variable tensor sizes cause holes. Use consistent sizes.
Activations dominate: Use gradient checkpointing to trade compute for memory.
Optimizer states are huge: 8-bit Adam or Adafactor can save 4x memory.
Offloading is a last resort: 2-10x slower, but enables larger models.
Debug systematically: Memory snapshots show exactly what’s consuming space.
Measure before optimizing: Know your memory breakdown before applying fixes.
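The last point is cheap to act on: reset the peak-memory counters, run a single step, and read them back (train_step is a placeholder for one training iteration):
import torch

torch.cuda.reset_peak_memory_stats()
train_step()  # placeholder: one forward/backward/optimizer step
print(f"Peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
print(f"Peak reserved:  {torch.cuda.max_memory_reserved() / 1e9:.2f} GB")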
24.10 Further Reading
- PyTorch CUDA Semantics
- PyTorch Memory Management
- Rajbhandari et al. (2020). “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models”
- Dettmers et al. (2022). “8-bit Optimizers via Block-wise Quantization”