20  Distributed Training

When One GPU Isn’t Enough

GPT-3 has 175 billion parameters. At FP16, that’s 350GB for the weights alone. A single A100 or H100 GPU has 80GB of memory.

The math doesn’t work. Training large models requires distribution across multiple GPUs—and that introduces fundamental challenges in parallelism, communication, and memory.

Property Spotlight: Associativity

This chapter exploits associativity—the first property from our Algebraic Framework.

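Gradient summation is associative: \((g_1 + g_2) + g_3 = g_1 + (g_2 + g_3)\), and averaging is just that sum divided by the worker count. This enables ring all-reduce, tree reductions, and arbitrary partitioning of gradient synchronization. (Floating-point addition is only approximately associative, so different reduction orders can differ in the last bits.) Matrix multiplications in tensor parallelism also exploit associativity: we can split along different dimensions because the underlying algebra permits regrouping.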

Without associativity, distributed training would require centralized coordination for every operation.
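As a minimal sketch of what regrouping buys, the snippet below reduces the same per-worker values left to right and as a pairwise tree. The function names are illustrative only, and integers stand in for gradients so the comparison is exact.

```python
from functools import reduce


def sequential_sum(grads):
    """Reduce left to right: ((g1 + g2) + g3) + ..."""
    return reduce(lambda a, b: a + b, grads)


def tree_sum(grads):
    """Reduce pairwise, the way a tree reduction across workers would."""
    if len(grads) == 1:
        return grads[0]
    mid = len(grads) // 2
    return tree_sum(grads[:mid]) + tree_sum(grads[mid:])


# One integer "gradient" per worker.
grads = [3, -1, 4, 1, -5, 9, 2, -6]
assert sequential_sum(grads) == tree_sum(grads)
```

Because the grouping does not matter, each worker can combine whatever partial sums reach it first, with no central coordinator deciding the order.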

20.1 The Three Dimensions of Parallelism

Distributed training can parallelize along three axes:

┌─────────────────────────────────────┐
│         DATA PARALLEL               │
│  Each GPU: same model, different    │
│            batch                     │
└─────────────────────────────────────┘

┌─────────────────────────────────────┐
│        TENSOR PARALLEL              │
│  Each GPU: part of each layer       │
└─────────────────────────────────────┘

┌─────────────────────────────────────┐
│       PIPELINE PARALLEL             │
│  Each GPU: different layers         │
└─────────────────────────────────────┘

Each has different tradeoffs. Most large-scale training uses all three simultaneously.

20.2 Data Parallelism

The simplest approach: replicate the model on each GPU, give each a different batch.

# Conceptually:
# GPU 0: model(batch_0)
# GPU 1: model(batch_1)
# GPU 2: model(batch_2)
# GPU 3: model(batch_3)

# After forward + backward:
# Synchronize gradients across all GPUs
# Update model parameters

20.2.1 Naive Data Parallelism

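import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes launch via torchrun (one process per GPU); MyTransformer and
# dataloader are defined elsewhere. The optimizer choice is illustrative.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

model = MyTransformer().to(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()   # DDP all-reduces gradients during the backward pass
    optimizer.step()  # Every replica applies the same update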

Memory requirement: Full model on each GPU.

Communication: Gradients are all-reduced once per backward pass, overlapped with the backward computation.

Scaling: Near-linear until communication becomes the bottleneck.

20.2.2 The Communication Bottleneck

For a model with P parameters:

  • Forward/backward compute: O(P × batch_size)
  • Gradient communication: O(P)

As you add GPUs:

  • Compute per GPU decreases (smaller batches)
  • Communication per GPU stays constant (still need all gradients)

Eventually, communication dominates.

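Example: a GPT-3-scale model (175B parameters, FP16 = 350GB of gradients). As a rough estimate, divide the gradient size by an assumed effective bandwidth, ignoring the all-reduce algorithm’s constant factor and any overlap with compute: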

8 GPUs:  350GB / 8-GPU NVLink bandwidth (600 GB/s) = 0.58s
64 GPUs: 350GB / 64-GPU IB bandwidth (200 GB/s) = 1.75s

At 64 GPUs, if your batch takes 2 seconds to compute, you spend nearly half your time in communication.
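The estimate above can be reproduced, and refined, with a few lines of arithmetic. This is a bandwidth-only sketch: the function names are hypothetical, latency and compute overlap are ignored, and the 2(N-1)/N factor is the per-GPU traffic of a ring all-reduce.

```python
def ring_allreduce_seconds(data_gb, num_gpus, bandwidth_gb_per_s):
    """Bandwidth-only estimate: each GPU moves 2(N-1)/N of the data."""
    traffic_gb = 2 * (num_gpus - 1) / num_gpus * data_gb
    return traffic_gb / bandwidth_gb_per_s


def communication_fraction(compute_s, comm_s):
    """Share of step time spent communicating when nothing overlaps."""
    return comm_s / (compute_s + comm_s)


simple = 350 / 200  # the simplified estimate: data size / bandwidth
ring = ring_allreduce_seconds(350, 64, 200)
print(f"simple estimate: {simple:.2f}s, ring estimate: {ring:.2f}s")
print(f"fraction at 2s compute: {communication_fraction(2.0, simple):.0%}")
```

With the ring factor included, the communication time roughly doubles, which only strengthens the conclusion that communication dominates at scale.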

20.2.3 ZeRO: Zero Redundancy Optimizer

The observation: we don’t need the full model on every GPU if we’re clever about when we materialize it.

ZeRO Stage 1: Partition optimizer states

  • Adam keeps two extra values per parameter (momentum and variance)
  • 175B params → 350GB parameters + 350GB gradients + 700GB optimizer state = 1400GB (FP16 accounting throughout, for simplicity)
  • Split optimizer state across GPUs
  • Memory per GPU: 350GB + 350GB + (700GB / N)

ZeRO Stage 2: Partition gradients too

  • Each GPU only needs the gradients for the parameters it updates
  • Memory per GPU: 350GB + (350GB / N) + (700GB / N)

ZeRO Stage 3: Partition parameters

  • Each GPU only stores a shard of model parameters
  • All-gather needed parameters just-in-time during forward/backward
  • Memory per GPU: (350GB / N) + (350GB / N) + (700GB / N) = 1400GB / N

For 16 GPUs:

  • Naive DDP: 1400GB per GPU (doesn’t fit)
  • ZeRO-3: 87.5GB per GPU (still over an 80GB GPU, before counting activations)

For 32 GPUs, ZeRO-3 needs 43.75GB per GPU, which fits.

The cost: more communication (all-gather during forward, reduce-scatter during backward).
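The per-stage arithmetic can be captured in a small calculator. This is a sketch under the section’s simplified accounting (350GB parameters, 350GB gradients, 700GB optimizer state); it counts gradients at every stage, excludes activations, and the function name is hypothetical.

```python
def zero_memory_gb(params_gb, grads_gb, optim_gb, num_gpus, stage):
    """Per-GPU model-state memory for ZeRO stage 0 (plain DDP) through 3."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    # Each stage shards one more component across the data-parallel group.
    optim = optim_gb / num_gpus if stage >= 1 else optim_gb
    grads = grads_gb / num_gpus if stage >= 2 else grads_gb
    params = params_gb / num_gpus if stage >= 3 else params_gb
    return params + grads + optim


for stage in range(4):
    per_gpu = zero_memory_gb(350, 350, 700, num_gpus=16, stage=stage)
    print(f"stage {stage}: {per_gpu:.2f} GB per GPU")
```

Only stage 3 scales all three terms with N, which is why it is the stage that makes very large models feasible at all.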

20.2.4 FSDP: Fully Sharded Data Parallelism

PyTorch’s implementation of ZeRO-3 concepts:

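import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

# Assumes the process group is already initialized (e.g. via torchrun).
model = MyTransformer()
model = FSDP(model,
             sharding_strategy=ShardingStrategy.FULL_SHARD,
             mixed_precision=MixedPrecision(...))  # dtypes elided
optimizer = torch.optim.AdamW(model.parameters())  # build after wrapping

# Training loop looks identical
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()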

FSDP automatically:

  • Shards parameters across GPUs
  • All-gathers parameters just-in-time for forward
  • Discards full parameters after use
  • Reduces gradients across GPUs
  • Shards gradients for optimizer update

Tradeoff: Less memory per GPU in exchange for more communication.

20.3 Tensor Parallelism

Instead of replicating the model, split each tensor across GPUs.

For a linear layer \(Y = XW\) where \(W\) is \(d_{in} \times d_{out}\):

Column-parallel: Split W column-wise (along d_out)

GPU 0: Y_0 = X @ W_0  (W_0 is [d_in × d_out/4])
GPU 1: Y_1 = X @ W_1
GPU 2: Y_2 = X @ W_2
GPU 3: Y_3 = X @ W_3

Output: concatenate([Y_0, Y_1, Y_2, Y_3])

Row-parallel: Split W row-wise (along d_in), and split X to match

GPU 0: Y_0 = X_0 @ W_0  (X_0 is [..., d_in/4], W_0 is [d_in/4 × d_out])
GPU 1: Y_1 = X_1 @ W_1
GPU 2: Y_2 = X_2 @ W_2
GPU 3: Y_3 = X_3 @ W_3

Output: Y_0 + Y_1 + Y_2 + Y_3  (all-reduce)
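Both splits can be checked on a tiny example. The sketch below uses plain Python lists and two shards instead of four; `matmul` and `add` are small helpers written for this illustration. Splitting W along its output dimension produces slices of Y that are concatenated, while splitting along its input dimension produces partial results that are summed.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


X = [[1, 2, 3, 4], [5, 6, 7, 8]]  # [batch=2, d_in=4]
W = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1], [1, 2, 3, 0]]  # [d_in=4, d_out=4]
full = matmul(X, W)

# Split W along d_out: each shard yields a slice of Y's columns.
w_out = [[row[:2] for row in W], [row[2:] for row in W]]
parts = [matmul(X, w) for w in w_out]
concat = [p0 + p1 for p0, p1 in zip(*parts)]

# Split W along d_in: X is split to match, and partial results are summed.
w_in = [W[:2], W[2:]]
x_in = [[row[:2] for row in X], [row[2:] for row in X]]
summed = add(matmul(x_in[0], w_in[0]), matmul(x_in[1], w_in[1]))

assert concat == full and summed == full
```

The concatenation needs no arithmetic across shards, whereas the sum is exactly the operation an all-reduce performs.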

20.3.1 Megatron-Style Tensor Parallelism

For a transformer block:

           ┌─────────────────┐
Input ───> │  Attention QKV  │ (column-parallel)
           └─────────────────┘
                 ↓ (no communication; output stays sharded)
           ┌─────────────────┐
           │ Attention Compute│ (local)
           └─────────────────┘
                 ↓
           ┌─────────────────┐
           │  Attention Out  │ (row-parallel)
           └─────────────────┘
                 ↓ (all-reduce)
           ┌─────────────────┐
           │       MLP 1     │ (column-parallel)
           └─────────────────┘
                 ↓ (no communication; output stays sharded)
           ┌─────────────────┐
           │     GeLU        │ (local)
           └─────────────────┘
                 ↓
           ┌─────────────────┐
           │       MLP 2     │ (row-parallel)
           └─────────────────┘
                 ↓ (all-reduce)
           Output

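Two all-reduces per transformer block in the forward pass, and two more in the backward pass.

Memory: Model size / N per GPU

Communication: 2 all-reduces per block per pass, each over the block’s activations

Scaling: Limited by communication bandwidth. Typically 2-8 GPUs, kept within a single node.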

20.4 Pipeline Parallelism

Split the model vertically: different layers on different GPUs.

GPU 0: Layers 0-7    ─→  GPU 1: Layers 8-15
                          ↓
GPU 3: Layers 24-31  ←─  GPU 2: Layers 16-23

20.4.1 The Naive Problem: Bubbles

Time ─→

GPU 0: [F1][F2][F3][F4]                        [B4][B3][B2][B1]
GPU 1:     [F1][F2][F3][F4]                [B4][B3][B2][B1]
GPU 2:         [F1][F2][F3][F4]        [B4][B3][B2][B1]
GPU 3:             [F1][F2][F3][F4][B4][B3][B2][B1]

       (blank space in a row = idle GPU)

Large sections where GPUs are idle (“pipeline bubbles”).

Utilization: about 57% for this 4-stage pipeline (each GPU is busy for 8 of 14 time slots)

20.4.2 GPipe: Microbatches

Split each batch into microbatches:

GPU 0: [F1][F2][F3][F4][F5][F6][F7][F8]...[B8][B7]...[B1]
GPU 1:     [F1][F2][F3][F4][F5][F6][F7]...[B7][B6]...[B1]
GPU 2:         [F1][F2][F3][F4][F5][F6]...[B6][B5]...[B1]
GPU 3:             [F1][F2][F3][F4][F5]...[B5][B4]...[B1]

       └─small─┘                     └─small─┘

With M microbatches and N pipeline stages:

Bubble time fraction: \(\frac{N - 1}{M + N - 1}\)

For M = 32, N = 4: Bubble = 3/35 = 8.6%

Tradeoff: More microbatches → less bubble time, but activation memory grows linearly with microbatches.
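To see how quickly the bubble shrinks, the formula can be tabulated for a fixed pipeline depth. A minimal sketch using exact fractions; the function name is illustrative.

```python
from fractions import Fraction


def bubble_fraction(num_microbatches, num_stages):
    """Idle share of the schedule: (N - 1) / (M + N - 1)."""
    return Fraction(num_stages - 1, num_microbatches + num_stages - 1)


for m in (1, 4, 8, 32, 128):
    frac = bubble_fraction(m, 4)
    print(f"M={m:>3}: bubble = {frac} = {float(frac):.1%}")
```

The returns diminish: going from 1 to 8 microbatches removes most of the bubble, while each further doubling buys less and keeps adding activation memory.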

20.4.3 PipeDream: 1F1B Schedule

Alternate forward and backward passes:

GPU 0: [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
GPU 1:     [F1][F2][F3][B1][F4][B2][F5][B3][F6][B4]...
GPU 2:         [F1][F2][B1][F3][B2][F4][B3][F5][B4]...
GPU 3:             [F1][B1][F2][B2][F3][B3][F4][B4]...

Key insight: Release activation memory as soon as backward completes.

Memory: Only store activations for in-flight microbatches: at most N (the number of pipeline stages) on any GPU, independent of the total number of microbatches M.
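The memory difference can be made concrete by generating each stage’s operation order and counting how many microbatches have run forward but not yet backward. This is a sketch of a non-interleaved 1F1B ordering under the assumption that stage s runs N - s - 1 warm-up forwards before alternating; the function names are hypothetical and timing is not modeled.

```python
def gpipe_order(num_microbatches):
    """All forwards, then all backwards (GPipe-style)."""
    fwd = [("F", i) for i in range(num_microbatches)]
    bwd = [("B", i) for i in reversed(range(num_microbatches))]
    return fwd + bwd


def one_f_one_b_order(stage, num_stages, num_microbatches):
    """Warm-up forwards, then alternate one forward with one backward."""
    warmup = min(num_stages - stage - 1, num_microbatches)
    order = [("F", i) for i in range(warmup)]
    next_f, next_b = warmup, 0
    while next_f < num_microbatches:
        order.append(("F", next_f))
        next_f += 1
        order.append(("B", next_b))
        next_b += 1
    while next_b < num_microbatches:  # drain the remaining backwards
        order.append(("B", next_b))
        next_b += 1
    return order


def peak_in_flight(order):
    """Largest number of microbatches whose activations are held at once."""
    live = peak = 0
    for op, _ in order:
        live += 1 if op == "F" else -1
        peak = max(peak, live)
    return peak


print(peak_in_flight(gpipe_order(32)))
print([peak_in_flight(one_f_one_b_order(s, 4, 32)) for s in range(4)])
```

Under these assumptions the first stage holds the most activations and the last stage holds one microbatch at a time, regardless of how many microbatches the batch is split into.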

20.5 Combining All Three: 3D Parallelism

Real large-scale training uses data parallel + tensor parallel + pipeline parallel simultaneously.

Example: Training a GPT-3-scale model (an illustrative configuration)

Model: 175B parameters
Hardware: 1024 A100 GPUs (80GB each)

Strategy:
- Pipeline parallel: 8 stages (different layers)
- Tensor parallel: 8-way (within each stage)
- Data parallel: 16-way (across pipeline replicas)

8 × 8 × 16 = 1024 GPUs ✓

Per-GPU memory:
- Model: 175B / (8 × 8) ≈ 2.7B params × 2 bytes = 5.4GB
- Gradients (FP16): 2.7B params × 2 bytes = 5.4GB
- Optimizer (AdamW): 2.7B params × 12 bytes = 32.4GB
  (FP32 master weights: 10.8GB + momentum: 10.8GB + variance: 10.8GB)
- Activations: ~40GB (depends on microbatch size)
- Total: ~83GB (over the 80GB budget until activation checkpointing shrinks the activation term)

20.6 Communication Primitives

All distributed training relies on collective operations:

20.6.1 All-Reduce

GPU 0: [a, b, c]
GPU 1: [d, e, f]  ─→  All-Reduce  ─→  GPU 0: [a+d+g, b+e+h, c+f+i]
GPU 2: [g, h, i]                       GPU 1: [a+d+g, b+e+h, c+f+i]
                                        GPU 2: [a+d+g, b+e+h, c+f+i]

Used for: Gradient synchronization in data parallelism

Time complexity: O(data_size / bandwidth) with the ring algorithm. The key insight of ring all-reduce is that the bandwidth cost is nearly independent of N: each GPU sends and receives 2(N-1)/N × data_size, which approaches 2 × data_size as N grows. The latency term still grows with N, since the ring takes 2(N-1) steps.
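A small simulation makes the 2(N-1)/N figure tangible. The sketch below runs the two phases of a ring all-reduce (reduce-scatter, then all-gather) over plain Python lists and counts the elements each rank sends. It assumes the buffer length is divisible by the number of ranks; `ring_all_reduce` is an illustrative name, not a library call.

```python
def ring_all_reduce(buffers):
    """Simulate ring all-reduce over equal-length lists, one per rank.

    Returns the reduced buffers and the number of elements each rank sent.
    """
    n = len(buffers)
    size = len(buffers[0])
    assert size % n == 0, "sketch assumes the buffer splits evenly into n chunks"
    chunk = size // n
    data = [list(b) for b in buffers]
    sent = [0] * n

    def exchange(step, offset, combine):
        # All ranks send at once, so snapshot the outgoing chunks first.
        messages = []
        for rank in range(n):
            c = (rank + offset - step) % n
            messages.append((c, data[rank][c * chunk:(c + 1) * chunk]))
            sent[rank] += chunk
        # Each rank receives from its left neighbour in the ring.
        for rank, (c, payload) in enumerate(messages):
            dst = (rank + 1) % n
            for i, value in enumerate(payload):
                idx = c * chunk + i
                data[dst][idx] = combine(data[dst][idx], value)

    for step in range(n - 1):  # reduce-scatter: accumulate partial sums
        exchange(step, 0, lambda old, new: old + new)
    for step in range(n - 1):  # all-gather: circulate the finished chunks
        exchange(step, 1, lambda old, new: new)
    return data, sent


result, sent = ring_all_reduce([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(result[0], sent)
```

Each rank sends one chunk per step for 2(N-1) steps, so its traffic is 2(N-1)/N of the buffer no matter how many ranks join the ring.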

20.6.2 All-Gather

GPU 0: [a, b]
GPU 1: [c, d]  ─→  All-Gather  ─→  GPU 0: [a, b, c, d, e, f]
GPU 2: [e, f]                       GPU 1: [a, b, c, d, e, f]
                                     GPU 2: [a, b, c, d, e, f]

Used for: Materializing full parameters in ZeRO/FSDP

20.6.3 Reduce-Scatter

GPU 0: [a, b, c, d, e, f]
GPU 1: [g, h, i, j, k, l]  ─→  Reduce-Scatter  ─→  GPU 0: [a+g+m, b+h+n]
GPU 2: [m, n, o, p, q, r]                           GPU 1: [c+i+o, d+j+p]
                                                     GPU 2: [e+k+q, f+l+r]

Used for: Distributing gradients in ZeRO/FSDP
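The three collectives are closely related: an all-reduce gives the same result as a reduce-scatter followed by an all-gather. The reference functions below describe only the semantics on lists of per-rank buffers (no real communication happens), and their names are chosen for this illustration.

```python
def all_reduce(buffers):
    """Every rank ends with the element-wise sum of all buffers."""
    total = [sum(vals) for vals in zip(*buffers)]
    return [list(total) for _ in buffers]


def all_gather(shards):
    """Every rank ends with the concatenation of all shards."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]


def reduce_scatter(buffers):
    """Rank r ends with the r-th chunk of the element-wise sum."""
    n = len(buffers)
    total = [sum(vals) for vals in zip(*buffers)]
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]


grads = [
    [1, 2, 3, 4, 5, 6],
    [10, 20, 30, 40, 50, 60],
    [100, 200, 300, 400, 500, 600],
]
assert all_gather(reduce_scatter(grads)) == all_reduce(grads)
```

This decomposition is what lets ZeRO/FSDP replace one all-reduce with a reduce-scatter for gradients and an all-gather for parameters.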

20.7 The Interconnect Hierarchy

Communication speed depends on which GPUs are communicating:

NVLink (within a node):

  • 8× A100 (NVLink 3): ~600 GB/s bidirectional per GPU (aggregate across links)
  • Very fast, all-to-all connectivity

InfiniBand (across nodes):

  • 200-800 Gb/s aggregate per node depending on NIC count (25-100 GB/s)
  • Higher latency than NVLink

Strategy: Minimize cross-node communication

  • Tensor parallel: within nodes (tight, high-bandwidth requirement)
  • Pipeline parallel: across nodes (sequential, latency-tolerant)
  • Data parallel: across nodes (large, infrequent all-reduce)
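One way to realize this placement is to order ranks so that the tensor-parallel index varies fastest. The sketch below assumes ranks are assigned consecutively within 8-GPU nodes and uses the 8 × 8 × 16 layout from the 3D parallelism example; `rank_to_coords` and `node_of` are hypothetical helpers, and real frameworks may order the axes differently.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp) indices with tp varying fastest."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


def node_of(rank, gpus_per_node=8):
    """Node index, assuming consecutive ranks fill each node in turn."""
    return rank // gpus_per_node


TP, PP, DP = 8, 8, 16

# Collect the set of nodes touched by each tensor-parallel group,
# i.e. all ranks sharing the same (dp, pp) indices.
groups = {}
for rank in range(TP * PP * DP):
    dp_idx, pp_idx, _ = rank_to_coords(rank, TP, PP, DP)
    groups.setdefault((dp_idx, pp_idx), set()).add(node_of(rank))

# Every tensor-parallel group sits on exactly one node.
assert all(len(nodes) == 1 for nodes in groups.values())
```

With this ordering the bandwidth-hungry tensor-parallel all-reduces stay on NVLink, and only pipeline and data-parallel traffic crosses the slower inter-node links.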

20.8 The Network Layer’s Hidden Assumptions

Hardware bandwidth isn’t the only communication bottleneck. The network stack itself embeds assumptions that break for ML workloads.

20.8.1 The Fairness Paradox

Consider TCP congestion control. Its core goal is fairness—dividing bandwidth equally among competing flows. This is universally considered good behavior. For web traffic, where flows have independent, competing goals, it is.

DNN training has different goals. Collective operations like all-reduce are synchronized—all workers must complete before any can proceed to the next iteration. When multiple training jobs share a cluster, fairness becomes the enemy.

Why fairness hurts:

Training jobs communicate in bursts at iteration boundaries. When multiple jobs hit their communication phase simultaneously:

  1. Standard congestion control sees competing flows
  2. It fairly divides bandwidth among all flows
  3. All jobs slow down together
  4. The slowest job determines iteration time for everyone

The result: congestion precisely when it hurts most, at synchronization points.

20.8.2 Training-Aware Congestion Control

The fix isn’t more bandwidth—it’s recognizing that training flows are cooperative, not competitive.

MLTCP [1] augments standard congestion control with a simple insight: scale congestion windows based on bytes sent per training iteration, not per-flow fairness.

Standard congestion control:
  Flow A: [========]─────────────────
  Flow B: [========]─────────────────
          ↑ Both slow, both synchronized → worst case

Training-aware congestion control:
  Flow A: [================]─────────
  Flow B: ─────────[================]
          ↑ Interleaved → both fast

By allowing flows to interleave rather than collide, MLTCP reports iteration-time speedups of up to 2× on average and up to 4× at the tail on the same hardware.

Results from MLTCP:

  • Average training iteration time: up to 2× faster
  • 99th percentile iteration time: up to 4× faster
  • Stabilizes competing flows within a few iterations
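The benefit of interleaving can be illustrated with a deliberately simple model. This is a toy calculation, not MLTCP’s algorithm: it assumes identical jobs sharing one link, each alternating a compute phase with a communication phase, and the function name is hypothetical.

```python
def iteration_time(compute_s, comm_s, num_jobs, interleaved):
    """Toy per-iteration time for identical jobs sharing one link.

    comm_s is a job's communication time when it has the link to itself.
    """
    if interleaved and comm_s * (num_jobs - 1) <= compute_s:
        # Each job communicates while the others compute: no sharing.
        return compute_s + comm_s
    # All jobs communicate at once and split the link evenly
    # (also used as a pessimistic bound when interleaving cannot fit).
    return compute_s + comm_s * num_jobs


collide = iteration_time(compute_s=1.0, comm_s=0.5, num_jobs=2, interleaved=False)
offset = iteration_time(compute_s=1.0, comm_s=0.5, num_jobs=2, interleaved=True)
print(f"synchronized: {collide:.2f}s, interleaved: {offset:.2f}s")
```

Even in this toy setting the total bytes sent are identical in both cases; only the timing of the communication phases changes.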

20.8.3 The Broader Lesson

Every layer of the stack encodes assumptions:

  • Congestion control assumes flows have independent, competing goals
  • Schedulers assume tasks want fair CPU time
  • Memory allocators assume allocation patterns are unpredictable
  • Caches assume recently-accessed data will be accessed again

When performance disappoints despite adequate hardware, ask: what assumption does this layer make that my workload violates?

For distributed ML training, the network layer’s assumption of competitive flows is often the hidden variable limiting performance.

20.9 Practical Considerations

20.9.1 Gradient Accumulation

To train with larger effective batch sizes:

accumulation_steps = 4

for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()  # Accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # Update weights
        optimizer.zero_grad()

Memory: Same as smaller batch (only activations for one microbatch)

Compute: Same total FLOPs as large batch

Tradeoff: Slightly less efficient (more overhead), but enables training with limited memory
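Dividing each microbatch loss by accumulation_steps is what makes the accumulated gradient match the large-batch gradient. The sketch below checks this exactly for a one-parameter least-squares model using fractions; the model, the data, and `mse_grad` are made up for this illustration, and the microbatches are assumed to be equal in size.

```python
from fractions import Fraction


def mse_grad(w, batch):
    """d/dw of mean((w*x - y)^2) over the (x, y) pairs in batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


pairs = [(1, 2), (2, 3), (3, 7), (4, 8), (5, 9), (6, 13), (7, 15), (8, 16)]
data = [(Fraction(x), Fraction(y)) for x, y in pairs]
w = Fraction(3, 2)
accumulation_steps = 4
micro = len(data) // accumulation_steps

accumulated = Fraction(0)
for i in range(accumulation_steps):
    microbatch = data[i * micro:(i + 1) * micro]
    # Gradient of (microbatch loss / accumulation_steps), summed across
    # steps the way repeated .backward() calls accumulate into .grad.
    accumulated += mse_grad(w, microbatch) / accumulation_steps

assert accumulated == mse_grad(w, data)
```

Without the division, the accumulated gradient would be accumulation_steps times too large, which acts like silently multiplying the learning rate.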

20.9.2 Mixed Precision

from torch.amp import autocast, GradScaler

scaler = GradScaler()

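for batch in dataloader:
    optimizer.zero_grad()
    with autocast('cuda'):  # Forward pass runs eligible ops in FP16
        loss = model(batch)

    scaler.scale(loss).backward()  # Scaled loss keeps small gradients from underflowing
    scaler.step(optimizer)  # Unscales gradients; skips the step on inf/NaN
    scaler.update()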

Benefit:

  • Up to 2× memory reduction for activations and gradients (FP16 vs FP32)
  • 2-3× speedup on tensor cores
  • Communication bandwidth halved

Challenge: Numerical stability (gradient scaling addresses this)
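The underflow that gradient scaling guards against is easy to reproduce with the standard library’s half-precision format. In this sketch the gradient value and the scale factor of 1024 are illustrative choices, and `to_fp16` is a helper written for the example.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


tiny_grad = 1e-8
scale = 1024.0

unscaled = to_fp16(tiny_grad)        # too small for FP16: flushes to zero
scaled = to_fp16(tiny_grad * scale)  # representable after scaling
recovered = scaled / scale           # unscale in higher precision

print(unscaled, recovered)
```

Scaling the loss scales every gradient by the same factor, so dividing by that factor before the optimizer update restores the intended magnitude, up to FP16 rounding.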

20.9.3 Activation Checkpointing

Trade compute for memory:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        # Don't store activations during forward
        # Recompute during backward
        # (self._forward holds the block's actual computation)
        return checkpoint(self._forward, x, use_reentrant=False)

Memory saved: Activations for checkpointed layers

Compute cost: ~33% slowdown (recomputation during backward)

When to use: When activation memory is the bottleneck
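The ~33% figure follows from a simple cost model: if the backward pass costs about two forward passes, recomputing every layer adds one more forward to a budget of three. A sketch of that arithmetic, with the 2:1 ratio as a stated assumption and a hypothetical function name:

```python
def checkpointing_overhead(backward_to_forward_ratio=2.0):
    """Relative step time with full recomputation vs. without.

    Assumes the backward pass costs `backward_to_forward_ratio` forward
    passes and that every checkpointed layer is recomputed exactly once.
    """
    baseline = 1.0 + backward_to_forward_ratio  # forward + backward
    with_recompute = baseline + 1.0             # one extra forward
    return with_recompute / baseline


print(f"{checkpointing_overhead() - 1:.0%} extra compute")
```

Checkpointing only some layers lowers the overhead proportionally, at the price of keeping more activations in memory.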

20.10 Key Tradeoffs

Approach           Memory/GPU      Communication  Scaling         Complexity
Naive DDP          O(P)            Low            Good (8-64)     Low
ZeRO/FSDP          O(P/N)          Medium         Good (64-512)   Medium
Tensor Parallel    O(P/N)          High           Limited (2-8)   Medium
Pipeline Parallel  O(P/N)          Low            Excellent       High
3D Parallel        O(P/(T × NP))   Medium         Excellent       Very High

Where:

  • P = total parameters
  • N = number of GPUs
  • T = tensor parallel degree
  • NP = pipeline parallel degree

20.11 Connections

Parallelism: Amdahl’s Law governs distributed training—communication is the sequential portion that limits scaling.

Locality: Keeping data on-device vs. communicating it is a locality decision at distributed scale.

Fusion: Fusing operations reduces activation memory, which enables larger microbatches in pipeline parallelism.

20.12 Key Takeaways

  1. No one-size-fits-all: Choose parallelism strategy based on model size, GPU count, and interconnect.

  2. Memory vs. communication: ZeRO/FSDP trade memory for communication. Worth it when memory-bound.

  3. 3D parallelism for scale: Largest models need all three parallelism dimensions simultaneously.

  4. Interconnect matters: Place communication-heavy operations (tensor parallel) within nodes, not across.

  5. Microbatches reduce bubbles: Essential for efficient pipeline parallelism.

  6. Mixed precision is essential: up to 2× less memory for activations and gradients and 2-3× speed on tensor cores, with careful gradient scaling.

  7. Question every layer’s assumptions: The network stack assumes competitive flows; training jobs are cooperative. When performance disappoints despite good hardware, look for assumption mismatches.

Further Reading
  • Distributed Training from First Principles - A comprehensive deep-dive into distributed training, covering everything from data parallelism to expert parallelism
  • Shoeybi et al. (2019). “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism”
  • Rajbhandari et al. (2020). “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models”
  • Narayanan et al. (2021). “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM”
  • Rajasekaran et al. (2022). “MLTCP: Congestion Control for DNN Training” - Training-aware network protocols
  • PyTorch FSDP documentation: https://pytorch.org/docs/stable/fsdp.html
[1] S. Rajasekaran, M. Narang, S. Kadekodi, and A. Akella, “MLTCP: Congestion control for DNN training,” in Proceedings of the 21st ACM workshop on hot topics in networks (HotNets), 2022.