20  Distributed Training

When One GPU Isn’t Enough

GPT-3 has 175 billion parameters. At FP16, that’s 350GB. The largest single GPU has 80GB of memory.

The math doesn’t work. Training large models requires distribution across multiple GPUs—and that introduces fundamental challenges in parallelism, communication, and memory.

Property Spotlight: Associativity

This chapter exploits associativity—the first property from our Algebraic Framework.

Gradient averaging is associative: \((g_1 + g_2) + g_3 = g_1 + (g_2 + g_3)\). This enables ring-allreduce, tree reductions, and arbitrary partitioning of gradient synchronization. Matrix multiplications in tensor parallelism also exploit associativity—we can split along different dimensions because the underlying algebra permits regrouping.

Without associativity, distributed training would require centralized coordination for every operation.
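
A tiny PyTorch sketch makes this concrete: because addition is associative (and commutative), the gradient sum can be computed in any grouping, which is exactly what ring and tree reductions exploit. The tensors here are illustrative stand-ins for per-worker gradients.

import torch

# Stand-ins for gradients computed on four workers.
g = [torch.randn(3) for _ in range(4)]

# Associativity lets us regroup the reduction however we like:
left_to_right = ((g[0] + g[1]) + g[2]) + g[3]   # sequential reduction
pairwise_tree = (g[0] + g[1]) + (g[2] + g[3])   # tree reduction
assert torch.allclose(left_to_right, pairwise_tree)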

20.1 The Three Dimensions of Parallelism

Distributed training can parallelize along three axes:

┌─────────────────────────────────────┐
│         DATA PARALLEL               │
│  Each GPU: same model, different    │
│            batch                    │
└─────────────────────────────────────┘

┌─────────────────────────────────────┐
│        TENSOR PARALLEL              │
│  Each GPU: part of each layer       │
└─────────────────────────────────────┘

┌─────────────────────────────────────┐
│       PIPELINE PARALLEL             │
│  Each GPU: different layers         │
└─────────────────────────────────────┘

Each has different tradeoffs. Most large-scale training uses all three simultaneously.

20.2 Data Parallelism

The simplest approach: replicate the model on each GPU, give each a different batch.

# Conceptually:
# GPU 0: model(batch_0)
# GPU 1: model(batch_1)
# GPU 2: model(batch_2)
# GPU 3: model(batch_3)

# After forward + backward:
# Synchronize gradients across all GPUs
# Update model parameters

20.2.1 Naive Data Parallelism

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")              # one process per GPU
local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
torch.cuda.set_device(local_rank)

model = MyTransformer().to(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    loss = model(batch)
    loss.backward()        # backward pass; DDP all-reduces gradients here
    optimizer.step()       # update weights
    optimizer.zero_grad()
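
In practice a script like this is launched with one process per GPU, for example via torchrun (the script name train.py is a placeholder):

torchrun --standalone --nproc_per_node=8 train.py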

Memory requirement: Full model on each GPU.

Communication: Gradients all-reduced during each backward pass (DDP overlaps communication with gradient computation).

Scaling: Near-linear until communication becomes the bottleneck.

20.2.2 The Communication Bottleneck

For a model with P parameters:

  • Forward/backward compute: O(P × batch_size)
  • Gradient communication: O(P)

As you add GPUs:

  • Compute per GPU decreases (smaller batches)
  • Communication per GPU stays constant (still need all gradients)

Eventually, communication dominates.

Example: GPT-3-scale model (175B parameters, FP16 = 350GB gradients):

8 GPUs:  350GB / 8-GPU NVLink bandwidth (600 GB/s) = 0.58s
64 GPUs: 350GB / 64-GPU IB bandwidth (200 GB/s) = 1.75s

At 64 GPUs, if your batch takes 2 seconds to compute, you spend nearly half your time in communication.
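
A back-of-the-envelope helper reproduces these figures. This is a sketch using the same simplification as above, namely that one full copy of the gradients crosses the interconnect; a ring all-reduce actually moves roughly 2 × (N - 1)/N copies of the data.

def allreduce_seconds(grad_bytes: float, bandwidth_bytes_per_s: float) -> float:
    # Simplified model: one full pass of the gradients over the interconnect.
    return grad_bytes / bandwidth_bytes_per_s

grad_bytes = 175e9 * 2                          # 175B params in FP16
print(allreduce_seconds(grad_bytes, 600e9))     # ≈ 0.58 s (NVLink figure above)
print(allreduce_seconds(grad_bytes, 200e9))     # ≈ 1.75 s (cross-node figure above)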

20.2.3 ZeRO: Zero Redundancy Optimizer

The observation: we don’t need the full model on every GPU if we’re clever about when we materialize it.

ZeRO Stage 1: Partition optimizer states

  • Adam has 2× parameters in momentum/variance
  • 175B params → 350GB + 700GB optimizer state = 1050GB
  • Split optimizer state across GPUs
  • Memory per GPU: 350GB + (700GB / N)

ZeRO Stage 2: Partition gradients too

  • Gradients are only needed for the parameters you're updating
  • Memory per GPU: 350GB + (700GB / N) + (350GB / N)

ZeRO Stage 3: Partition parameters

  • Each GPU only stores a shard of model parameters
  • All-gather needed parameters just-in-time during forward/backward
  • Memory per GPU: (350GB / N) + (700GB / N) + (350GB / N) = 1400GB / N

For 16 GPUs:

  • Naive DDP: 1050GB per GPU (doesn't fit)
  • ZeRO-3: 87.5GB per GPU (fits!)

The cost: more communication (all-gather during forward, reduce-scatter during backward).
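
The per-stage formulas above are easy to tabulate; here is a minimal sketch that follows the chapter's accounting (the helper name and default values are illustrative):

def zero_memory_per_gpu_gb(stage: int, n_gpus: int,
                           weights_gb: float = 350,
                           grads_gb: float = 350,
                           opt_gb: float = 700) -> float:
    """Per-GPU memory for ZeRO stages 1-3, GPT-3-scale model at FP16."""
    if stage == 1:
        return weights_gb + opt_gb / n_gpus
    if stage == 2:
        return weights_gb + (opt_gb + grads_gb) / n_gpus
    if stage == 3:
        return (weights_gb + opt_gb + grads_gb) / n_gpus
    raise ValueError("stage must be 1, 2, or 3")

print(zero_memory_per_gpu_gb(stage=3, n_gpus=16))   # 87.5 GB, as above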

20.2.4 FSDP: Fully Sharded Data Parallelism

PyTorch’s implementation of ZeRO-3 concepts:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)

dist.init_process_group("nccl")
model = MyTransformer().cuda()
model = FSDP(model,
             sharding_strategy=ShardingStrategy.FULL_SHARD,
             mixed_precision=MixedPrecision(param_dtype=torch.float16))  # e.g. FP16 params
optimizer = torch.optim.AdamW(model.parameters())

# Training loop looks identical to DDP
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

FSDP automatically:

  • Shards parameters across GPUs
  • All-gathers parameters just-in-time for forward
  • Discards full parameters after use
  • Reduces gradients across GPUs
  • Shards gradients for the optimizer update

Tradeoff: Memory for communication.
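
In practice FSDP is usually told which submodules to shard as units. One common pattern (sketched here, with TransformerBlock standing in for whatever block class your model actually uses) is transformer_auto_wrap_policy:

import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},   # shard one block at a time
)
model = FSDP(MyTransformer().cuda(), auto_wrap_policy=wrap_policy)

Wrapping per block means only one block's full parameters are materialized at a time, which is what bounds peak memory.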

20.3 Tensor Parallelism

Instead of replicating the model, split each tensor across GPUs.

For a linear layer \(Y = XW\) where \(W\) is [d_in × d_out]:

Column-parallel: Split W column-wise (along d_out)

GPU 0: Y_0 = X @ W_0  (W_0 is [d_in × d_out/4])
GPU 1: Y_1 = X @ W_1
GPU 2: Y_2 = X @ W_2
GPU 3: Y_3 = X @ W_3

Output: concatenate([Y_0, Y_1, Y_2, Y_3])

Row-parallel: Split W row-wise (along d_in), with X split to match

GPU 0: Y_0 = X_0 @ W_0  (X_0 is [..., d_in/4], W_0 is [d_in/4 × d_out])
GPU 1: Y_1 = X_1 @ W_1
GPU 2: Y_2 = X_2 @ W_2
GPU 3: Y_3 = X_3 @ W_3

Output: Y_0 + Y_1 + Y_2 + Y_3  (all-reduce)
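
A quick single-process check in plain PyTorch (no GPUs involved; the 4-way split mirrors the diagrams above) confirms that both layouts reproduce the full matmul:

import torch

d_in, d_out, n = 8, 12, 4
X = torch.randn(2, d_in)
W = torch.randn(d_in, d_out)

# Column-parallel: split W along d_out, concatenate the partial outputs.
Y_col = torch.cat([X @ W_k for W_k in W.chunk(n, dim=1)], dim=-1)

# Row-parallel: split W along d_in (and X to match), sum the partial outputs.
Y_row = sum(X_k @ W_k for X_k, W_k in zip(X.chunk(n, dim=-1), W.chunk(n, dim=0)))

assert torch.allclose(Y_col, X @ W, atol=1e-5)
assert torch.allclose(Y_row, X @ W, atol=1e-5)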

20.3.1 Megatron-Style Tensor Parallelism

For a transformer block:

           ┌──────────────────┐
Input ───> │  Attention QKV   │ (column-parallel)
           └──────────────────┘
                 ↓ (no communication; heads stay sharded)
           ┌──────────────────┐
           │ Attention Compute│ (local)
           └──────────────────┘
                 ↓
           ┌──────────────────┐
           │  Attention Out   │ (row-parallel)
           └──────────────────┘
                 ↓ (all-reduce)
           ┌──────────────────┐
           │      MLP 1       │ (column-parallel)
           └──────────────────┘
                 ↓ (no communication; hidden stays sharded)
           ┌──────────────────┐
           │      GeLU        │ (local)
           └──────────────────┘
                 ↓
           ┌──────────────────┐
           │      MLP 2       │ (row-parallel)
           └──────────────────┘
                 ↓ (all-reduce)
           Output

Two all-reduces per transformer block in the forward pass (and two more in the backward pass).

Memory: Model size / N per GPU

Communication: 2 all-reduces per block, each over an activation-sized tensor

Scaling: Limited by communication bandwidth. Typically 2-8 GPUs.

20.4 Pipeline Parallelism

Split the model vertically: different layers on different GPUs.

GPU 0: Layers 0-7    ─→  GPU 1: Layers 8-15
                          ↓
GPU 3: Layers 24-31  ←─  GPU 2: Layers 16-23

20.4.1 The Naive Problem: Bubbles

Time ─→

GPU 0: [F1][F2][F3][F4]                        [B4][B3][B2][B1]
GPU 1:     [F1][F2][F3][F4]                [B4][B3][B2][B1]
GPU 2:         [F1][F2][F3][F4]        [B4][B3][B2][B1]
GPU 3:             [F1][F2][F3][F4][B4][B3][B2][B1]

                       └────── bubbles ──────┘

Large sections where GPUs are idle (“pipeline bubbles”).

Utilization: ~50% for 4-stage pipeline

20.4.2 GPipe: Microbatches

Split each batch into microbatches:

GPU 0: [F1][F2][F3][F4][F5][F6][F7][F8]...[B8][B7]...[B1]
GPU 1:     [F1][F2][F3][F4][F5][F6][F7]...[B7][B6]...[B1]
GPU 2:         [F1][F2][F3][F4][F5][F6]...[B6][B5]...[B1]
GPU 3:             [F1][F2][F3][F4][F5]...[B5][B4]...[B1]

       └─small─┘                     └─small─┘

With M microbatches and N pipeline stages:

Bubble time fraction: \(\frac{N - 1}{M + N - 1}\)

For M = 32, N = 4: Bubble = 3/35 = 8.6%

Tradeoff: More microbatches → less bubble time, but activation memory grows linearly with microbatches.
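
The bubble formula is worth keeping handy; a one-line helper (illustrative only) reproduces the example above:

def bubble_fraction(num_microbatches: int, num_stages: int) -> float:
    """Idle fraction of an ideal GPipe schedule: (N - 1) / (M + N - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)

print(bubble_fraction(32, 4))   # ≈ 0.086, i.e. the 8.6% above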

20.4.3 PipeDream: 1F1B Schedule

Alternate forward and backward passes:

GPU 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
GPU 1:     [F1][F2][F3][B1][F4][B2][F5][B3][F6][B4]...
GPU 2:         [F1][F2][B1][F3][B2][F4][B3][F5][B4]...
GPU 3:             [F1][B1][F2][B2][F3][B3][F4][B4]...

Key insight: Release activation memory as soon as backward completes.

Memory: Only store activations for in-flight microbatches (constant with respect to total microbatches).

20.5 Combining All Three: 3D Parallelism

Real large-scale training uses data parallel + tensor parallel + pipeline parallel simultaneously.

Example: Training GPT-3

Model: 175B parameters
Hardware: 1024 A100 GPUs (80GB each)

Strategy:
- Pipeline parallel: 8 stages (different layers)
- Tensor parallel: 8-way (within each stage)
- Data parallel: 16-way (across pipeline replicas)

8 × 8 × 16 = 1024 GPUs ✓

Per-GPU memory:
- Model: 175B / (8 × 8) = 2.7B params × 2 bytes = 5.4GB
- Optimizer states + gradients (Adam): 5.4GB × 3 = 16.2GB
- Activations: ~40GB (depends on microbatch size)
- Total: ~62GB (fits in 80GB)
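
The per-GPU budget is simple enough to sanity-check in a few lines. This is a sketch: byte counts follow the chapter's FP16 accounting, and the 40GB activation figure is the rough estimate quoted above.

params_total = 175e9
tp, pp = 8, 8                                   # tensor / pipeline parallel degrees
                                                # (data parallelism doesn't shard the model here)
params_per_gpu = params_total / (tp * pp)       # ≈ 2.7B
model_gb = params_per_gpu * 2 / 1e9             # ≈ 5.4 GB of FP16 weights
opt_and_grads_gb = model_gb * 3                 # ≈ 16.2 GB (Adam states + gradients)
activations_gb = 40                             # depends on microbatch size
print(model_gb + opt_and_grads_gb + activations_gb)   # ≈ 62 GB, fits in 80 GB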

20.6 Communication Primitives

All distributed training relies on collective operations:

20.6.1 All-Reduce

GPU 0: [a, b, c]
GPU 1: [d, e, f]  ─→  All-Reduce  ─→  GPU 0: [a+d, b+e, c+f]
GPU 2: [g, h, i]                       GPU 1: [a+d, b+e, c+f]
                                        GPU 2: [a+d, b+e, c+f]

Used for: Gradient synchronization in data parallelism

Time: roughly 2 × data_size × (N - 1)/N / bandwidth with the ring algorithm, i.e. nearly independent of the number of GPUs N

20.6.2 All-Gather

GPU 0: [a, b]
GPU 1: [c, d]  ─→  All-Gather  ─→  GPU 0: [a, b, c, d, e, f]
GPU 2: [e, f]                       GPU 1: [a, b, c, d, e, f]
                                     GPU 2: [a, b, c, d, e, f]

Used for: Materializing full parameters in ZeRO/FSDP

20.6.3 Reduce-Scatter

GPU 0: [a, b, c, d, e, f]
GPU 1: [g, h, i, j, k, l]  ─→  Reduce-Scatter  ─→  GPU 0: [a+g, b+h]
GPU 2: [m, n, o, p, q, r]                           GPU 1: [c+i, d+j]
                                                     GPU 2: [e+k, f+l]

Used for: Distributing gradients in ZeRO/FSDP
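
All three primitives are exposed directly by torch.distributed. A minimal sketch (assuming a NCCL process group is already initialized, each rank owns one GPU, and the tensor shapes are arbitrary; the *_tensor collectives are the modern variants):

import torch
import torch.distributed as dist

rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())

x = torch.full((4,), float(rank), device=device)

# All-reduce: every rank ends up with the elementwise sum.
summed = x.clone()
dist.all_reduce(summed, op=dist.ReduceOp.SUM)

# All-gather: every rank ends up with all shards concatenated.
gathered = torch.empty(4 * world, device=device)
dist.all_gather_into_tensor(gathered, x)

# Reduce-scatter: each rank ends up with one reduced shard.
big = torch.randn(4 * world, device=device)
shard = torch.empty(4, device=device)
dist.reduce_scatter_tensor(shard, big, op=dist.ReduceOp.SUM)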

20.7 The Interconnect Hierarchy

Communication speed depends on which GPUs are communicating:

NVLink (within a node):

  • 8× A100: 600 GB/s bidirectional per GPU
  • Very fast, all-to-all connectivity

InfiniBand (across nodes):

  • 200-800 Gb/s per node (25-100 GB/s)
  • Higher latency than NVLink

Strategy: Minimize cross-node communication

  • Tensor parallel: within nodes (tight, high-bandwidth requirement)
  • Pipeline parallel: across nodes (sequential, latency-tolerant)
  • Data parallel: across nodes (large, infrequent all-reduce)

20.8 Practical Considerations

20.8.1 Gradient Accumulation

To train with larger effective batch sizes:

accumulation_steps = 4

for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()  # Accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # Update weights
        optimizer.zero_grad()

Memory: Same as smaller batch (only activations for one microbatch)

Compute: Same total FLOPs as large batch

Tradeoff: Slightly less efficient (more overhead), but enables training with limited memory
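
One wrinkle when combining this with DDP: by default, every backward pass still triggers a gradient all-reduce. DDP's no_sync() context manager skips the all-reduce on the accumulation steps; a sketch, assuming model is the DDP-wrapped module from earlier:

from contextlib import nullcontext

accumulation_steps = 4

for i, batch in enumerate(dataloader):
    is_update_step = (i + 1) % accumulation_steps == 0
    # Only synchronize gradients on the step that will call optimizer.step().
    maybe_no_sync = nullcontext() if is_update_step else model.no_sync()
    with maybe_no_sync:
        loss = model(batch) / accumulation_steps
        loss.backward()

    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()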

20.8.2 Mixed Precision

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    with autocast():  # FP16 forward/backward
        loss = model(batch)

    scaler.scale(loss).backward()
    scaler.step(optimizer)  # FP32 optimizer update
    scaler.update()

Benefit:

  • 2× memory reduction (FP16 vs FP32)
  • 2-3× speedup on tensor cores
  • Communication volume halved

Challenge: Numerical stability (gradient scaling addresses this)
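
On GPUs with native BF16 support (e.g. A100), a common alternative is bfloat16 autocast: it keeps FP32's dynamic range, so the GradScaler is typically unnecessary. A minimal sketch:

import torch

for batch in dataloader:
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(batch)
    loss.backward()        # no loss scaling needed with bf16
    optimizer.step()
    optimizer.zero_grad()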

20.8.3 Activation Checkpointing

Trade compute for memory:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        # Don't store intermediate activations during forward;
        # recompute them during backward instead
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        # the actual block computation (attention + MLP)
        ...

Memory saved: Activations for checkpointed layers

Compute cost: ~33% slowdown (recomputation during backward)

When to use: When activation memory is the bottleneck

20.9 Key Tradeoffs

Approach            Memory/GPU       Communication   Scaling          Complexity
Naive DDP           O(P)             Low             Good (8-64)      Low
ZeRO/FSDP           O(P/N)           Medium          Good (64-512)    Medium
Tensor Parallel     O(P/N)           High            Limited (2-8)    Medium
Pipeline Parallel   O(P/N)           Low             Excellent        High
3D Parallel         O(P/(T × N_P))   Medium          Excellent        Very High

Where:

  • P = total parameters
  • N = number of GPUs
  • T = tensor parallel degree
  • N_P = pipeline parallel degree

20.10 Connections

Chapter 3 (Parallelism): Amdahl’s Law governs distributed training—communication is the sequential portion that limits scaling.

Chapter 7 (Locality): Keeping data on-device vs. communicating it is a locality decision at distributed scale.

Chapter 8 (Fusion): Fusing operations reduces activation memory, which enables larger microbatches in pipeline parallelism.

20.11 Key Takeaways

  1. No one-size-fits-all: Choose parallelism strategy based on model size, GPU count, and interconnect.

  2. Memory vs. communication: ZeRO/FSDP trade memory for communication. Worth it when memory-bound.

  3. 3D parallelism for scale: Largest models need all three parallelism dimensions simultaneously.

  4. Interconnect matters: Place communication-heavy operations (tensor parallel) within nodes, not across.

  5. Microbatches reduce bubbles: Essential for efficient pipeline parallelism.

  6. Mixed precision is essential: 2× memory and 2× speed with careful gradient scaling.

Further Reading
  • Shoeybi et al. (2019). “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism”
  • Rajbhandari et al. (2020). “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models”
  • Narayanan et al. (2021). “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM”
  • PyTorch FSDP documentation: https://pytorch.org/docs/stable/fsdp.html