20 Distributed Training
When One GPU Isn’t Enough
GPT-3 has 175 billion parameters. At FP16, that’s 350GB. The largest single GPU has 80GB of memory.
The math doesn’t work. Training large models requires distribution across multiple GPUs—and that introduces fundamental challenges in parallelism, communication, and memory.
This chapter exploits associativity—the first property from our Algebraic Framework.
Gradient averaging is associative: \((g_1 + g_2) + g_3 = g_1 + (g_2 + g_3)\). This enables ring-allreduce, tree reductions, and arbitrary partitioning of gradient synchronization. Matrix multiplications in tensor parallelism also exploit associativity—we can split along different dimensions because the underlying algebra permits regrouping.
Without associativity, distributed training would require centralized coordination for every operation.
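To make the role of associativity concrete, here is a toy, single-process simulation of ring all-reduce (the algorithm commonly used for gradient synchronization). It is an illustrative sketch, not a distributed implementation: the "GPUs" are just entries in a Python list, and the chunk bookkeeping mirrors the reduce-scatter and all-gather phases.

import torch

def ring_allreduce(grads):
    """Toy ring all-reduce over a list of equal-length gradient tensors (one per simulated GPU)."""
    n = len(grads)
    chunks = [list(g.clone().chunk(n)) for g in grads]   # each worker splits its gradient into n chunks
    # Reduce-scatter phase: each chunk circulates around the ring, accumulating partial sums.
    for step in range(n - 1):
        for rank in range(n):
            c = (rank - step) % n
            chunks[(rank + 1) % n][c] = chunks[(rank + 1) % n][c] + chunks[rank][c]
    # All-gather phase: the fully reduced chunks circulate so every worker ends up with all of them.
    for step in range(n - 1):
        for rank in range(n):
            c = (rank + 1 - step) % n
            chunks[(rank + 1) % n][c] = chunks[rank][c]
    return [torch.cat(c) for c in chunks]

grads = [torch.randn(8) for _ in range(4)]
expected = sum(grads)                      # the centralized sum
for result in ring_allreduce(grads):
    assert torch.allclose(result, expected, atol=1e-6)

Because addition can be regrouped freely, each chunk accumulates its partial sums in whatever order the ring dictates and still arrives at the same total a centralized reduction would produce.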
20.1 The Three Dimensions of Parallelism
Distributed training can parallelize along three axes:
┌─────────────────────────────────────┐
│ DATA PARALLEL │
│ Each GPU: same model, different │
│ batch │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│ TENSOR PARALLEL │
│ Each GPU: part of each layer │
└─────────────────────────────────────┘
┌─────────────────────────────────────┐
│ PIPELINE PARALLEL │
│ Each GPU: different layers │
└─────────────────────────────────────┘
Each has different tradeoffs. Most large-scale training uses all three simultaneously.
20.2 Data Parallelism
The simplest approach: replicate the model on each GPU, give each a different batch.
# Conceptually:
# GPU 0: model(batch_0)
# GPU 1: model(batch_1)
# GPU 2: model(batch_2)
# GPU 3: model(batch_3)
# After forward + backward:
# Synchronize gradients across all GPUs
# Update model parameters

20.2.1 Naive Data Parallelism
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")   # one process per GPU, e.g. launched with torchrun
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
model = MyTransformer().cuda()
model = DDP(model)
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)        # Forward pass
    loss.backward()            # Backward pass; DDP all-reduces gradients here
    optimizer.step()           # Update weights

Memory requirement: Full model (plus gradients and optimizer state) on each GPU.
Communication: Gradients all-reduced once per backward pass.
Scaling: Near-linear until communication becomes the bottleneck.
20.2.2 The Communication Bottleneck
For a model with P parameters:
- Forward/backward compute: O(P × batch_size)
- Gradient communication: O(P)
As you add GPUs:
- Compute per GPU decreases (smaller per-GPU batches)
- Communication per GPU stays roughly constant (every GPU still needs the full gradient)
Eventually, communication dominates.
Example: a GPT-3-scale model (175B parameters, FP16 gradients ≈ 350GB). Treating all-reduce time as roughly gradient size divided by bandwidth:
- 8 GPUs, NVLink-class bandwidth (600 GB/s): 350GB / 600 GB/s ≈ 0.58s
- 64 GPUs, cross-node InfiniBand (≈200 GB/s aggregate): 350GB / 200 GB/s ≈ 1.75s
At 64 GPUs, if your batch takes 2 seconds to compute, you spend nearly half your time in communication.
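A back-of-the-envelope helper makes the arithmetic above reproducible. It uses the chapter's simplified model (total gradient bytes divided by effective bandwidth); a real ring all-reduce moves roughly 2(N - 1)/N of the data per GPU and overlaps with compute, so treat the results as rough estimates, not measurements:

def allreduce_seconds(param_count, bytes_per_param, bandwidth_gb_per_s):
    grad_gb = param_count * bytes_per_param / 1e9
    return grad_gb / bandwidth_gb_per_s

print(allreduce_seconds(175e9, 2, 600))   # ~0.58 s at NVLink-class bandwidth
print(allreduce_seconds(175e9, 2, 200))   # ~1.75 s at the assumed cross-node bandwidth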
20.2.3 ZeRO: Zero Redundancy Optimizer
The observation: we don’t need the full model on every GPU if we’re clever about when we materialize it.
ZeRO Stage 1: Partition optimizer states
- Adam keeps momentum and variance, roughly 2× the parameter memory
- 175B params → 350GB parameters + 700GB optimizer state = 1050GB
- Split the optimizer state across N GPUs
- Memory per GPU: 350GB + (700GB / N)

ZeRO Stage 2: Partition gradients too
- Each GPU keeps only the gradients for the parameter shard whose optimizer state it owns
- Memory per GPU: 350GB + (700GB / N) + (350GB / N)

ZeRO Stage 3: Partition parameters
- Each GPU stores only a shard of the model parameters
- Needed parameters are all-gathered just-in-time during forward and backward
- Memory per GPU: (350GB / N) + (700GB / N) + (350GB / N) = 1400GB / N
For 16 GPUs:
- Naive DDP: 1050GB per GPU (doesn't fit)
- ZeRO-3: 87.5GB per GPU (still just above an 80GB card; at 32 GPUs the shard drops to 43.75GB and fits comfortably)

The cost: more communication (parameters are all-gathered during forward and backward, and gradients are reduce-scattered).
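The per-GPU numbers can be reproduced with a small calculator that mirrors the chapter's simplified accounting (FP16 parameters and gradients, optimizer state equal to 2× the parameter memory; the naive and Stage 1 figures follow the text in omitting the gradient buffer). This is a sketch of the bookkeeping, not DeepSpeed's actual memory model:

def zero_memory_gb(params=175e9, n_gpus=16, stage=3):
    p = params * 2 / 1e9          # FP16 parameters
    g = params * 2 / 1e9          # FP16 gradients
    o = 2 * p                     # Adam momentum + variance
    if stage == 0:                # naive DDP, as counted in the text: params + optimizer
        return p + o
    if stage == 1:                # shard optimizer states
        return p + o / n_gpus
    if stage == 2:                # shard optimizer states + gradients
        return p + (g + o) / n_gpus
    return (p + g + o) / n_gpus   # stage 3: shard everything

print(zero_memory_gb(stage=0))    # 1050.0 GB per GPU
print(zero_memory_gb(stage=3))    # 87.5 GB per GPU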
20.2.4 FSDP: Fully Sharded Data Parallelism
PyTorch’s implementation of ZeRO-3 concepts:
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Assumes torch.distributed is already initialized (one process per GPU),
# as in the DDP example above.
model = MyTransformer()
model = FSDP(model,
             sharding_strategy=ShardingStrategy.FULL_SHARD,
             mixed_precision=MixedPrecision(param_dtype=torch.float16))
optimizer = torch.optim.AdamW(model.parameters())  # construct after wrapping

# Training loop looks identical
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()

FSDP automatically:
- Shards parameters across GPUs
- All-gathers parameters just-in-time for forward (and again for backward)
- Discards the full parameters after use
- Reduce-scatters gradients across GPUs
- Keeps only the local gradient shard for the optimizer update
Tradeoff: Memory for communication.
20.3 Tensor Parallelism
Instead of replicating the model, split each tensor across GPUs.
For a linear layer \(Y = XW\) where \(W\) is \([d_{in} \times d_{out}]\):

Column-parallel: Split W column-wise (along \(d_{out}\))

GPU 0: Y_0 = X @ W_0    (W_0 is [d_in × d_out/4])
GPU 1: Y_1 = X @ W_1
GPU 2: Y_2 = X @ W_2
GPU 3: Y_3 = X @ W_3
Output: concatenate([Y_0, Y_1, Y_2, Y_3])

Row-parallel: Split W row-wise (along \(d_{in}\)), with X split along its last dimension

GPU 0: Y_0 = X_0 @ W_0  (X_0 is [..., d_in/4], W_0 is [d_in/4 × d_out])
GPU 1: Y_1 = X_1 @ W_1
GPU 2: Y_2 = X_2 @ W_2
GPU 3: Y_3 = X_3 @ W_3
Output: Y_0 + Y_1 + Y_2 + Y_3 (all-reduce)
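Both splits rely on the same algebraic fact: a matrix product can be regrouped along either dimension. A single-process sketch (tensor chunks standing in for GPUs, arbitrary illustrative shapes) verifies that the two schemes reconstruct the full result:

import torch

n = 4                               # number of simulated GPUs
X = torch.randn(8, 16)              # [batch, d_in]
W = torch.randn(16, 32)             # [d_in, d_out]
Y = X @ W                           # reference, computed on one device

# Column-parallel: split W along d_out; each shard produces a slice of Y, then concatenate.
Y_col = torch.cat([X @ Wi for Wi in W.chunk(n, dim=1)], dim=1)

# Row-parallel: split W along d_in and X along its last dimension; partial outputs are
# summed, which is exactly what the all-reduce does on real hardware.
Y_row = sum(Xi @ Wi for Xi, Wi in zip(X.chunk(n, dim=1), W.chunk(n, dim=0)))

assert torch.allclose(Y, Y_col, atol=1e-5)
assert torch.allclose(Y, Y_row, atol=1e-5)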
20.3.1 Megatron-Style Tensor Parallelism
For a transformer block:
┌─────────────────┐
Input ───> │ Attention QKV │ (column-parallel)
└─────────────────┘
↓ (no communication; outputs stay sharded by head)
┌─────────────────┐
│ Attention Compute│ (local)
└─────────────────┘
↓
┌─────────────────┐
│ Attention Out │ (row-parallel)
└─────────────────┘
↓ (all-reduce)
┌─────────────────┐
│ MLP 1 │ (column-parallel)
└─────────────────┘
↓ (no communication; activations stay sharded)
┌─────────────────┐
│ GeLU │ (local)
└─────────────────┘
↓
┌─────────────────┐
│ MLP 2 │ (row-parallel)
└─────────────────┘
↓ (all-reduce)
Output
Two all-reduces per transformer block.
Memory: Model size / N per GPU
Communication: Two all-reduces of the activation tensor per block in the forward pass (plus two more in backward)
Scaling: Limited by communication bandwidth; typically 2-8 GPUs, kept within a single node.
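The reason only two all-reduces are needed is that a column-parallel layer followed by a row-parallel layer requires no communication in between: each GPU keeps its slice of the hidden dimension end to end, and only the final partial sums must be combined. A single-process sketch of the MLP half of the block (illustrative shapes, tensor chunks standing in for GPUs):

import torch
import torch.nn.functional as F

n = 4                                # tensor-parallel degree (simulated)
x = torch.randn(8, 64)               # [tokens, d_model]
W1 = torch.randn(64, 256)            # first MLP projection, split column-wise (along d_ff)
W2 = torch.randn(256, 64)            # second MLP projection, split row-wise (along d_ff)

reference = F.gelu(x @ W1) @ W2

# Each simulated GPU keeps its slice of the hidden (d_ff) dimension end to end:
# a column shard of W1, the matching row shard of W2, and a local GeLU in between.
partials = [F.gelu(x @ w1) @ w2
            for w1, w2 in zip(W1.chunk(n, dim=1), W2.chunk(n, dim=0))]
output = sum(partials)               # the single all-reduce in a real distributed run

assert torch.allclose(reference, output, rtol=1e-3, atol=1e-2)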
20.4 Pipeline Parallelism
Split the model vertically: different layers on different GPUs.
GPU 0: Layers 0-7   ─→   GPU 1: Layers 8-15
                                 ↓
GPU 3: Layers 24-31  ←─  GPU 2: Layers 16-23
20.4.1 The Naive Problem: Bubbles
Time ─→
GPU 0: [F1][F2][F3][F4]                        [B4][B3][B2][B1]
GPU 1:     [F1][F2][F3][F4]                [B4][B3][B2][B1]
GPU 2:         [F1][F2][F3][F4]        [B4][B3][B2][B1]
GPU 3:             [F1][F2][F3][F4][B4][B3][B2][B1]
       └─ bubbles ─┘                               └─ bubbles ─┘
Large sections where GPUs are idle (“pipeline bubbles”).
Utilization: only about 57% (\(1 - \frac{3}{7}\)) for a 4-stage pipeline with 4 microbatches
20.4.2 GPipe: Microbatches
Split each batch into microbatches:
GPU 0: [F1][F2][F3][F4][F5][F6][F7][F8]           ...[B8][B7][B6]...[B1]
GPU 1:     [F1][F2][F3][F4][F5][F6][F7][F8]       ...[B8][B7][B6]...[B1]
GPU 2:         [F1][F2][F3][F4][F5][F6][F7][F8]   ...[B8][B7][B6]...[B1]
GPU 3:             [F1][F2][F3][F4][F5][F6][F7][F8][B8][B7][B6]...[B1]
       └─ small ─┘                                        └─ small ─┘
With M microbatches and N pipeline stages:
Bubble time fraction: \(\frac{N - 1}{M + N - 1}\)
For M = 32, N = 4: Bubble = 3/35 = 8.6%
Tradeoff: More microbatches → less bubble time, but activation memory grows linearly with microbatches.
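The bubble formula is easy to tabulate; a tiny helper (assuming equal-cost microbatches, as the formula does) shows how quickly the overhead shrinks as microbatches are added:

def bubble_fraction(num_stages, num_microbatches):
    return (num_stages - 1) / (num_microbatches + num_stages - 1)

for m in (4, 8, 32, 128):
    print(m, f"{bubble_fraction(4, m):.1%}")   # 42.9%, 27.3%, 8.6%, 2.3%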
20.4.3 PipeDream: 1F1B Schedule
Alternate forward and backward passes:
GPU 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
GPU 1:     [F1][F2][F3][B1][F4][B2][F5][B3][F6][B4]...
GPU 2:         [F1][F2][B1][F3][B2][F4][B3][F5][B4]...
GPU 3:             [F1][B1][F2][B2][F3][B3][F4][B4]...
Key insight: Release activation memory as soon as backward completes.
Memory: Only store activations for in-flight microbatches (constant with respect to total microbatches).
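The per-stage ordering can be written down directly. The sketch below generates a plausible 1F1B order for one stage under the usual assumptions (equal-cost microbatches, a warm-up of N - 1 - stage forwards); it illustrates the schedule's shape rather than reproducing PipeDream's actual scheduler:

def one_f_one_b(stage, num_stages, num_microbatches):
    """Order of forward (F) / backward (B) microbatches on one pipeline stage."""
    warmup = min(num_stages - 1 - stage, num_microbatches)
    schedule = [("F", m) for m in range(warmup)]
    f, b = warmup, 0
    while b < num_microbatches:
        if f < num_microbatches:
            schedule.append(("F", f))
            f += 1
        schedule.append(("B", b))    # backward frees that microbatch's activations
        b += 1
    return schedule

print(one_f_one_b(stage=0, num_stages=4, num_microbatches=8))
print(one_f_one_b(stage=3, num_stages=4, num_microbatches=8))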
20.5 Combining All Three: 3D Parallelism
Real large-scale training uses data parallel + tensor parallel + pipeline parallel simultaneously.
Example: training a GPT-3-scale model (illustrative configuration)
Model: 175B parameters
Hardware: 1024 A100 GPUs (80GB each)
Strategy:
- Pipeline parallel: 8 stages (different layers)
- Tensor parallel: 8-way (within each stage)
- Data parallel: 16-way (across pipeline replicas)
8 × 8 × 16 = 1024 GPUs ✓
Per-GPU memory:
- Model: 175B / (8 × 8) = 2.7B params per GPU × 2 bytes ≈ 5.4GB
- Gradients + optimizer states (Adam): 5.4GB × 3 ≈ 16.2GB
- Activations: ~40GB (depends on microbatch size)
- Total: ~62GB (fits in 80GB)
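The memory line items above can be checked with a few lines of arithmetic. The function mirrors the chapter's accounting (2 bytes per FP16 parameter, a 3× multiplier for gradients plus Adam states, a fixed activation budget); all defaults are the illustrative numbers from this example, not measured values:

def per_gpu_memory_gb(params=175e9, tp=8, pp=8, bytes_per_param=2,
                      grad_and_optimizer_multiple=3, activations_gb=40):
    shard = params / (tp * pp)                          # parameters held by one GPU
    model_gb = shard * bytes_per_param / 1e9            # ≈ 5.4 GB
    opt_gb = model_gb * grad_and_optimizer_multiple     # ≈ 16.4 GB (gradients + Adam states)
    return model_gb + opt_gb + activations_gb

print(per_gpu_memory_gb())   # ≈ 61.9 GB, within an 80GB A100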
20.6 Communication Primitives
All distributed training relies on collective operations:
20.6.1 All-Reduce
GPU 0: [a, b, c]                        GPU 0: [a+d+g, b+e+h, c+f+i]
GPU 1: [d, e, f]   ─→ All-Reduce ─→     GPU 1: [a+d+g, b+e+h, c+f+i]
GPU 2: [g, h, i]                        GPU 2: [a+d+g, b+e+h, c+f+i]
Used for: Gradient synchronization in data parallelism
Time: roughly 2 × data_size / bandwidth with the ring algorithm (nearly independent of the number of GPUs)
20.6.2 All-Gather
GPU 0: [a, b]                           GPU 0: [a, b, c, d, e, f]
GPU 1: [c, d]   ─→ All-Gather ─→        GPU 1: [a, b, c, d, e, f]
GPU 2: [e, f]                           GPU 2: [a, b, c, d, e, f]
Used for: Materializing full parameters in ZeRO/FSDP
20.6.3 Reduce-Scatter
GPU 0: [a, b, c, d, e, f]                             GPU 0: [a+g+m, b+h+n]
GPU 1: [g, h, i, j, k, l]   ─→ Reduce-Scatter ─→      GPU 1: [c+i+o, d+j+p]
GPU 2: [m, n, o, p, q, r]                             GPU 2: [e+k+q, f+l+r]
Used for: Distributing gradients in ZeRO/FSDP
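For reference, the three collectives map directly onto torch.distributed calls. A minimal sketch, assuming the script is launched with torchrun so that one process per GPU exists and NCCL is available; tensor sizes are arbitrary:

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
world = dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

x = torch.full((4,), float(rank), device=device)

# All-reduce: every rank ends up with the elementwise sum across ranks.
summed = x.clone()
dist.all_reduce(summed, op=dist.ReduceOp.SUM)

# All-gather: every rank ends up with a copy of every rank's tensor.
gathered = [torch.empty_like(x) for _ in range(world)]
dist.all_gather(gathered, x)

# Reduce-scatter: each rank receives one reduced shard of the per-rank inputs.
inputs = [torch.full((4,), float(rank), device=device) for _ in range(world)]
output = torch.empty(4, device=device)
dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)

dist.destroy_process_group()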
20.7 The Interconnect Hierarchy
Communication speed depends on which GPUs are communicating:
NVLink (within a node):
- 8× A100: 600 GB/s bidirectional per GPU
- Very fast, all-to-all connectivity

InfiniBand (across nodes):
- 200-800 Gb/s per node (25-100 GB/s)
- Higher latency than NVLink

Strategy: Minimize cross-node communication
- Tensor parallel: within nodes (tight coupling, high bandwidth requirement)
- Pipeline parallel: across nodes (sequential hand-offs, latency-tolerant)
- Data parallel: across nodes (large but infrequent all-reduce)
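In recent PyTorch releases (2.2+), this placement can be expressed with a DeviceMesh so that parallelism APIs pick the right process groups. A sketch under the assumption of 2 nodes × 8 GPUs launched with torchrun; the dimension names are illustrative:

from torch.distributed.device_mesh import init_device_mesh

# Ranks are laid out so the trailing "tp" dimension maps to the 8 GPUs inside a node
# (NVLink), while "dp" spans nodes (InfiniBand).
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp", "tp"))
# Sharding and tensor-parallel APIs that accept a DeviceMesh can then be pointed at
# the appropriate dimension, keeping tensor-parallel traffic off the slower links.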
20.8 Practical Considerations
20.8.1 Gradient Accumulation
To train with larger effective batch sizes:
accumulation_steps = 4

for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()                          # Accumulate gradients
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()                     # Update weights
        optimizer.zero_grad()

Memory: Same as the smaller per-step batch (only one microbatch of activations at a time)
Compute: Same total FLOPs as large batch
Tradeoff: Slightly less efficient (more overhead), but enables training with limited memory
20.8.2 Mixed Precision
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with autocast():                  # mixed-precision (FP16) forward; backward follows the same dtypes
        loss = model(batch)
    scaler.scale(loss).backward()     # scale the loss so small FP16 gradients don't underflow
    scaler.step(optimizer)            # unscales gradients, then FP32 optimizer update
    scaler.update()

Benefits:
- 2× memory reduction (FP16 vs FP32 activations and gradients)
- 2-3× speedup on tensor cores
- Communication bandwidth halved
Challenge: Numerical stability (gradient scaling addresses this)
20.8.3 Activation Checkpointing
Trade compute for memory:
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        # Don't store intermediate activations during the forward pass;
        # recompute them during backward
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        ...  # the block's usual attention + MLP computation

Memory saved: Activations of the checkpointed layers
Compute cost: ~33% slowdown (recomputation during backward)
When to use: When activation memory is the bottleneck
20.9 Key Tradeoffs
| Approach | Memory/GPU | Communication | Scaling | Complexity |
|---|---|---|---|---|
| Naive DDP | O(P) | Low | Good (8-64) | Low |
| ZeRO/FSDP | O(P/N) | Medium | Good (64-512) | Medium |
| Tensor Parallel | O(P/N) | High | Limited (2-8) | Medium |
| Pipeline Parallel | O(P/N) | Low | Excellent | High |
| 3D Parallel | O(P / (TP × PP)) | Medium | Excellent | Very High |

Where:
- P = total parameters
- N = degree of the given parallelism dimension (number of GPUs it spans)
- TP = tensor-parallel degree
- PP = pipeline-parallel degree
20.10 Connections
Chapter 3 (Parallelism): Amdahl’s Law governs distributed training—communication is the sequential portion that limits scaling.
Chapter 7 (Locality): Keeping data on-device vs. communicating it is a locality decision at distributed scale.
Chapter 8 (Fusion): Fusing operations reduces activation memory, which enables larger microbatches in pipeline parallelism.
20.11 Key Takeaways
No one-size-fits-all: Choose parallelism strategy based on model size, GPU count, and interconnect.
Memory vs. communication: ZeRO/FSDP trade memory for communication. Worth it when memory-bound.
3D parallelism for scale: Largest models need all three parallelism dimensions simultaneously.
Interconnect matters: Place communication-heavy operations (tensor parallel) within nodes, not across.
Microbatches reduce bubbles: Essential for efficient pipeline parallelism.
Mixed precision is essential: 2× memory and 2× speed with careful gradient scaling.
20.12 References

- Shoeybi et al. (2019). “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.”
- Rajbhandari et al. (2020). “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.”
- Narayanan et al. (2021). “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.”
- PyTorch FSDP documentation: https://pytorch.org/docs/stable/fsdp.html