Batch Size and Learning Dynamics
Batch size is not just a memory parameter—it fundamentally affects learning dynamics. There's a critical batch size beyond which returns diminish rapidly. Understanding this is essential for scaling to thousands of GPUs.
The Question: You want to scale from 8 GPUs to 10,000 GPUs. The naive approach: increase batch size 1,250×. But models trained with batch size 1M often fail to converge. What's the limit, and how do we push past it?
Gradient Noise and Batch Size¶
Each minibatch provides a noisy estimate of the true gradient:

\[
\hat{g}_B = g + \epsilon_B
\]

Where \(\epsilon_B\) is the noise with variance:

\[
\operatorname{Var}(\epsilon_B) = \frac{\sigma^2}{B}
\]

Here \(\sigma^2\) is the per-sample gradient variance (a property of the data and model). Larger batch → lower noise → more reliable gradient direction.
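A quick simulation makes the \(1/B\) scaling concrete. This is an illustrative sketch with synthetic scalar gradients; all numbers are made up for the demo:

```python
import numpy as np

rng = np.random.default_rng(0)
true_grad = 1.0   # the "true" gradient g
sigma2 = 4.0      # per-sample gradient variance sigma^2

for B in (8, 64, 512):
    # Minibatch gradient = mean of B noisy per-sample gradients
    samples = true_grad + rng.normal(0.0, np.sqrt(sigma2), size=(10_000, B))
    batch_grads = samples.mean(axis=1)
    print(f"B={B:4d}  empirical Var={batch_grads.var():.4f}  sigma^2/B={sigma2 / B:.4f}")
```

The empirical variance of the minibatch gradient tracks \(\sigma^2/B\) closely at every batch size.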
But here's the key insight: noise isn't always bad.
The Beneficial Role of Noise¶
Gradient noise:

1. Helps escape sharp minima (which generalize poorly)
2. Provides implicit regularization
3. Enables exploration of the loss landscape
Too little noise → may converge to sharp, non-generalizing minima.
The Critical Batch Size¶
McCandlish et al. (2018) derived the critical batch size \(B_{\text{crit}}\) (their "noise scale"):

\[
B_{\text{crit}} = \frac{\operatorname{tr}(H\Sigma)}{G^{\top} H\, G}
\]

Where:

- \(G = \mathbb{E}[\nabla L]\): the true (full-batch) gradient, with \(G^2 = \|G\|^2\) its squared norm
- \(H = \nabla^2 L\): the Hessian, so \(G^{\top} H G\) is the curvature along the gradient direction
- \(\Sigma\): the per-sample gradient covariance

Equivalently, when curvature is roughly uniform across directions, this reduces to the simple noise scale:

\[
B_{\text{crit}} \approx \frac{\operatorname{tr}(\Sigma)}{G^2}
\]
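The simple noise scale can be estimated in practice by comparing gradient norms at two batch sizes, following the two-measurement trick from McCandlish et al. A sketch (the function name is ours, and single estimates are very noisy):

```python
def simple_noise_scale(g_small, g_big, b_small, b_big):
    """Estimate B_crit ~ tr(Sigma)/G^2 from gradients at two batch sizes.

    Uses E[|g_B|^2] = G^2 + tr(Sigma)/B: two equations, two unknowns.
    g_small, g_big are flattened gradient vectors (e.g. numpy arrays)
    computed with batch sizes b_small < b_big. In practice, average the
    squared norms over many steps before forming the ratio.
    """
    s2 = float((g_small ** 2).sum())  # |g_{B_small}|^2
    b2 = float((g_big ** 2).sum())    # |g_{B_big}|^2
    g_sq = (b_big * b2 - b_small * s2) / (b_big - b_small)  # G^2 estimate
    tr_sigma = (s2 - b2) / (1.0 / b_small - 1.0 / b_big)    # tr(Sigma) estimate
    return tr_sigma / max(g_sq, 1e-12)
```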
Interpretation¶
- \(B < B_{\text{crit}}\): Training is noise-dominated. Each step is small due to noisy gradients. Need many steps.
- \(B \approx B_{\text{crit}}\): Optimal balance. Steps are reliable but not wasting compute.
- \(B > B_{\text{crit}}\): Training is curvature-dominated. Extra samples provide diminishing returns. Compute is wasted.
Empirical Values¶
Critical batch size varies by task and training stage:
| Domain | Typical \(B_{\text{crit}}\) |
|---|---|
| ImageNet (early training) | 2K - 8K |
| ImageNet (late training) | 16K - 64K |
| Language models (small) | 256 - 2K |
| Language models (large) | 2M - 8M |
Note: \(B_{\text{crit}}\) increases during training as the model approaches a minimum and curvature decreases.
The Perfect Scaling Law¶
Below \(B_{\text{crit}}\), training scales perfectly:

\[
S(B) = S_0 \cdot \frac{B_0}{B}
\]

Where \(S_0\) is the number of steps at the baseline batch size \(B_0\).

Doubling batch size → halving steps → same wall-clock time per step → 2× faster training.

Total compute (samples processed) stays constant:

\[
C = S(B) \cdot B = S_0 \cdot B_0
\]
The Diminishing Returns Regime¶
Above \(B_{\text{crit}}\), the relationship becomes:

\[
S(B) = S_{\min}\left(1 + \frac{B_{\text{crit}}}{B}\right)
\]

Where \(S_{\min}\) is the minimum number of steps regardless of batch size (the curvature limit).

As \(B \to \infty\): \(S(B) \to S_{\min}\).

You can't reduce steps below \(S_{\min}\) no matter how large the batch.

Compute waste: total compute \(C(B) = S(B) \cdot B = S_{\min}(B + B_{\text{crit}})\) grows linearly in \(B\) while the steps saved shrink toward zero.
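A few lines of Python make both regimes visible. This is a sketch of the step/compute model above; the parameter values are borrowed from Exercise 1 later in this section:

```python
def steps_needed(batch, s_min, b_crit):
    """Step/batch trade-off model: S(B) = S_min * (1 + B_crit / B)."""
    return s_min * (1.0 + b_crit / batch)

def total_compute(batch, s_min, b_crit):
    """Total samples processed: C(B) = S(B) * B = S_min * (B + B_crit)."""
    return steps_needed(batch, s_min, b_crit) * batch

# Steps fall quickly below B_crit ~ 6144, then flatten; compute keeps growing
for b in (512, 2_048, 8_192, 32_768):
    print(f"B={b:6d}  steps={steps_needed(b, 4_000, 6_144):8.0f}"
          f"  compute={total_compute(b, 4_000, 6_144):.3g}")
```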
Learning Rate Scaling¶
When increasing batch size, you must adjust learning rate. The question is: how?
Linear Scaling Rule (Goyal et al., 2017)¶
\[
\eta = \eta_0 \cdot \frac{B}{B_0}
\]

Intuition: If the batch size doubles, the gradient estimate is twice as reliable, so we can take a step twice as large.
Valid when:
- \(B \leq B_{\text{crit}}\)
- Using SGD with momentum
Derivation: Compare \(k\) steps with batch \(B_0\) against one step with batch \(kB_0\):

Small batch (\(k\) steps):

\[
w_{t+k} = w_t - \eta_0 \sum_{i=0}^{k-1} g_i
\]

Large batch (one step on the averaged gradient):

\[
w_{t+1} = w_t - \eta' \cdot \frac{1}{k} \sum_{i=0}^{k-1} g_i
\]

For the two updates to match (assuming the gradient changes slowly over the \(k\) small steps): \(\eta' = k\eta_0\)
Square Root Scaling¶
Intuition: The gradient noise standard deviation scales as \(1/\sqrt{B}\), so the learning rate should scale with the noise reduction:

\[
\eta = \eta_0 \cdot \sqrt{\frac{B}{B_0}}
\]
Valid when:
- Beyond \(B_{\text{crit}}\)
- Loss landscape is more complex
Derivation: From the perspective of SGD convergence rate in convex optimization, the optimal learning rate is \(\eta \propto 1/\sqrt{B}\) for noisy gradients.
Which to Use?¶
| Regime | Scaling Rule |
|---|---|
| \(B \ll B_{\text{crit}}\) | Linear |
| \(B \approx B_{\text{crit}}\) | Between linear and sqrt |
| \(B > B_{\text{crit}}\) | Square root |
| \(B \gg B_{\text{crit}}\) | Constant (no benefit to increasing) |
Practical approach: use linear scaling with warmup, then fall back toward square-root scaling if training is unstable.
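As a sketch of the table above, here is one way to encode these rules in a helper. The crossover points and the cap are illustrative assumptions, not values from the literature, and any scaled LR should still be validated empirically:

```python
import math

def scaled_lr(base_lr, base_batch, batch, b_crit):
    """Heuristic LR scaling: linear below B_crit, sqrt above, capped far past it."""
    if batch <= b_crit:
        return base_lr * batch / base_batch           # linear regime
    lr_at_crit = base_lr * b_crit / base_batch        # linear rule up to B_crit
    # Square-root growth past B_crit, capped: no benefit far beyond it
    return lr_at_crit * min(math.sqrt(batch / b_crit), 2.0)
```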
Warmup¶
Large learning rates at the start of training cause divergence. Solution: warmup.
Linear Warmup¶

\[
\eta_t = \eta_{\text{peak}} \cdot \frac{t}{T_{\text{warmup}}}
\]

for \(t \leq T_{\text{warmup}}\); after warmup, \(\eta_t\) follows the main schedule.
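A minimal PyTorch sketch of this schedule, assuming the optimizer's configured learning rate is the post-warmup peak (the helper name is ours; real recipes usually chain this with a decay schedule such as cosine):

```python
import torch

def linear_warmup(optimizer: torch.optim.Optimizer, warmup_steps: int):
    """Ramp LR linearly from ~0 to the optimizer's configured LR, then hold.

    Call scheduler.step() once per optimizer step.
    """
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
    )
```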
Why Warmup Helps¶
Early in training:

1. Gradients are large and noisy
2. Loss landscape curvature is high
3. The model is far from any minimum
Large steps cause:
- Gradient explosion
- Catastrophic updates
- Divergence
Warmup allows the model to "find its footing" before taking large steps.
Warmup Duration¶
Rule of thumb: scale warmup length with the batch size,

\[
T_{\text{warmup}} = T_0 \cdot \frac{B}{B_0}
\]

Where \(T_0\) is the number of warmup steps at the baseline batch size \(B_0\).
For very large batches (>64K), longer warmup may be needed.
Layer-wise Adaptive Learning Rates¶
Different layers have different gradient magnitudes. Standard learning rate works poorly for very deep or very wide networks at large batch sizes.
LARS (You et al., 2017)¶
Layer-wise Adaptive Rate Scaling for SGD. Each layer \(l\) gets its own effective learning rate \(\eta_l = \eta \cdot \phi_l\),

Where the trust ratio is:

\[
\phi_l = \frac{\|w_l\|}{\|\nabla L(w_l)\| + \lambda \|w_l\|}
\]

Here \(\lambda\) is the weight decay coefficient (typically 0.0001 to 0.001).

Intuition: Scale the learning rate by the ratio of weight norm to gradient norm. This prevents any layer from updating too much relative to its current scale.

The update becomes:

\[
w_l \leftarrow w_l - \eta \, \phi_l \left( \nabla L(w_l) + \lambda w_l \right)
\]

(In practice this is combined with momentum.)
LARS enabled ImageNet training at batch size 32K without loss of accuracy, a regime where SGD with the linear scaling rule alone breaks down (it tops out around 8K).
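A bare-bones, momentum-free sketch of the LARS update above (the function name is ours; production implementations add momentum and typically skip the trust ratio for biases and normalization parameters):

```python
import torch

@torch.no_grad()
def lars_step(params, lr, weight_decay=1e-4, eps=1e-9):
    """Apply one momentum-free LARS step to each parameter tensor."""
    for w in params:
        if w.grad is None:
            continue
        w_norm = w.norm()
        g_norm = w.grad.norm()
        # Trust ratio: phi = ||w|| / (||grad|| + lambda * ||w||)
        phi = w_norm / (g_norm + weight_decay * w_norm + eps)
        # Update: w <- w - lr * phi * (grad + lambda * w)
        w.sub_(lr * phi * (w.grad + weight_decay * w))
```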
LAMB (You et al., 2019)¶
Layer-wise Adaptive Moments for Batch training combines LARS-style trust ratios with Adam. First compute Adam's update direction from the bias-corrected moments:

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2
\]

\[
r_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\]

Where \(\epsilon\) is a small constant for numerical stability (typically \(10^{-8}\)).

Then apply a LARS-style trust ratio layer by layer:

\[
w_l \leftarrow w_l - \eta \cdot \frac{\|w_l\|}{\|r_t^{(l)} + \lambda w_l\|} \left( r_t^{(l)} + \lambda w_l \right)
\]
LAMB enabled training BERT with batch size 65K in 76 minutes (vs. 3 days with Adam).
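For illustration, a single-tensor sketch of the update equations above; `lamb_step` and `state` are our names, and a real optimizer would loop over parameter groups and handle edge cases:

```python
import torch

@torch.no_grad()
def lamb_step(w, state, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One LAMB step for a single tensor w.

    state = {"m": torch.zeros_like(w), "v": torch.zeros_like(w), "t": 0}
    """
    g = w.grad
    state["t"] += 1
    state["m"].mul_(betas[0]).add_(g, alpha=1 - betas[0])
    state["v"].mul_(betas[1]).addcmul_(g, g, value=1 - betas[1])
    m_hat = state["m"] / (1 - betas[0] ** state["t"])    # bias correction
    v_hat = state["v"] / (1 - betas[1] ** state["t"])
    r = m_hat / (v_hat.sqrt() + eps) + weight_decay * w  # Adam direction + decay
    trust = w.norm() / (r.norm() + 1e-12)                # layer-wise trust ratio
    w.sub_(lr * trust * r)
```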
Comparison¶
| Method | Base Optimizer | Max Batch (ImageNet) | Max Batch (BERT) |
|---|---|---|---|
| SGD + Linear LR | SGD | ~8K | N/A |
| LARS | SGD | 32K | ~8K |
| Adam + Linear LR | Adam | ~16K | ~16K |
| LAMB | Adam | ~32K | 65K |
The Batch Size vs. Time Trade-off¶
Larger batch sizes enable data parallelism across more GPUs. But returns diminish.
Scaling Efficiency¶
Define the scaling efficiency as the ratio of ideal compute to actual compute:

\[
E(B) = \frac{S_0 \cdot B_0}{S(B) \cdot B}
\]

- \(E = 1\): Perfect scaling (linear speedup)
- \(E < 1\): Sub-linear scaling
- \(E \to 0\): Wasted compute
Below \(B_{\text{crit}}\): \(E \approx 1\)
Above \(B_{\text{crit}}\): \(E\) drops rapidly
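The formula is trivial to compute; a one-function sketch, evaluated on the numbers from Exercise 3 later in this section:

```python
def scaling_efficiency(s0, b0, s, b):
    """E = ideal compute / actual compute = (S_0 * B_0) / (S(B) * B)."""
    return (s0 * b0) / (s * b)

# 16x larger batch, but only 8.3x fewer steps -> efficiency ~ 0.52
print(scaling_efficiency(s0=50_000, b0=512, s=6_000, b=8_192))
```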
When to Scale¶
Scale batch when:
- Wall-clock time is the constraint
- You're below \(B_{\text{crit}}\)
- GPU utilization is high
Don't scale batch when:
- Already above \(B_{\text{crit}}\)
- Final quality matters more than speed
- Hyperparameter tuning is difficult
Gradient Accumulation¶
Can't fit large batch in memory? Accumulate gradients:
```python
optimizer.zero_grad()
for i, batch in enumerate(mini_batches):
    # Scale so the accumulated gradient equals the large-batch average
    loss = model(batch) / accumulation_steps
    loss.backward()  # gradients accumulate in param.grad
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
Mathematically equivalent to large batch, but:
- More forward/backward passes
- Same memory as small batch
- Slower than true data parallelism
Use when: GPU memory limits effective batch size, but \(B_{\text{crit}}\) hasn't been reached.
Dynamic Batch Sizing¶
\(B_{\text{crit}}\) increases during training. Optimal strategy: increase batch size during training.
LLaMA's Approach¶
LLaMA 2 (and later LLaMA 3) increased batch size mid-training:
- Start: smaller batch size (e.g., 2M tokens for LLaMA 2)
- Ramp up to 4M tokens after initial training phase
Note: LLaMA 1 used a constant batch size of ~4M tokens throughout training. The batch size ramp was introduced in LLaMA 2.
Benefits:

1. Early training: smaller batch for exploration
2. Late training: larger batch for efficiency
3. Total steps reduced
Adaptive Scaling¶
Monitor the gradient noise scale during training:

\[
B_{\text{noise}} = \frac{\operatorname{tr}(\Sigma)}{G^2}
\]

The noise scale grows as the gradient shrinks. When the measured noise scale rises well above the current batch size, it is safe to increase the batch.
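A heuristic sketch of such a trigger; the function name, `factor`, and `headroom` are our assumptions, and `measured_noise_scale` would come from an estimator like `simple_noise_scale` above, smoothed over many steps:

```python
def maybe_grow_batch(batch, measured_noise_scale, factor=2, headroom=1.5):
    """Double the batch once the noise scale comfortably exceeds it.

    `headroom` guards against acting on a noisy single estimate.
    """
    if measured_noise_scale > headroom * batch:
        return batch * factor
    return batch
```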
Practical Recipe for Large-Batch Training¶
1. Establish baseline: Train with a small batch (256-1024), find the optimal \(\eta_0\).
2. Estimate \(B_{\text{crit}}\): Double the batch, check whether the steps halve. Stop when they don't.
3. Scale with the linear rule: \(\eta = \eta_0 \cdot B/B_0\) up to \(B_{\text{crit}}\).
4. Use warmup: \(T_{\text{warmup}} \propto B/B_0\).
5. Consider LARS/LAMB: Typically essential for \(B > 8\text{K}\).
6. Monitor carefully:
   - Loss spikes → reduce LR or lengthen warmup
   - Slow convergence → may have exceeded \(B_{\text{crit}}\)
   - Layer-wise gradient norms → check for imbalance
7. Dynamic batch: Increase the batch size as training progresses.
Exercises¶
- Critical batch size: A model trains in 100K steps with batch size 256. With batch size 1024, it trains in 28K steps. With batch size 4096, it trains in 15K steps. Estimate \(B_{\text{crit}}\).
Solution
Given data:
| Batch Size | Steps to Convergence |
|---|---|
| 256 | 100,000 |
| 1,024 | 28,000 |
| 4,096 | 15,000 |
Using the diminishing-returns model:

\[
S(B) = S_{\min} + \frac{c}{B}, \qquad c = S_{\min} \cdot B_{\text{crit}}
\]

Check perfect scaling from B=256 to B=1024 (4× increase):

If scaling were perfect: \(S(1024) = 100,000 / 4 = 25,000\)

Actual: 28,000 steps → already seeing diminishing returns.

Set up equations:

From \(S(256) = 100,000\):

\[
S_{\min} + \frac{c}{256} = 100,000
\]

From \(S(1024) = 28,000\):

\[
S_{\min} + \frac{c}{1024} = 28,000
\]

Solve. Subtracting the second equation from the first:

\[
c\left(\frac{1}{256} - \frac{1}{1024}\right) = 72,000 \quad\Rightarrow\quad \frac{3c}{1024} = 72,000 \quad\Rightarrow\quad c = 24,576,000
\]

Substitute back:

\[
S_{\min} = 100,000 - \frac{24,576,000}{256} = 100,000 - 96,000 = 4,000
\]

Verify with B=4096:

\[
S(4096) = 4,000 + \frac{24,576,000}{4096} = 4,000 + 6,000 = 10,000
\]

Actual: 15,000 steps (the discrepancy suggests the two-parameter model is only approximate).

Estimate \(B_{\text{crit}}\):

The critical batch size is where the noise and curvature contributions to \(S(B)\) are equal, i.e. where \(c/B = S_{\min}\):

\[
B_{\text{crit}} = \frac{c}{S_{\min}} = \frac{24,576,000}{4,000} \approx 6,144
\]
Interpretation: Batch sizes above ~6K will show significant diminishing returns. The data shows we're already past \(B_{\text{crit}}\) at 4096, confirming the estimate is in the right range.
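The two-equation solve is easy to sanity-check numerically; a short script reproducing the fit:

```python
import numpy as np

# Solve S_min + c/B = S for the two data points, then B_crit = c / S_min
A = np.array([[1.0, 1 / 256], [1.0, 1 / 1024]])
y = np.array([100_000.0, 28_000.0])
s_min, c = np.linalg.solve(A, y)
print(s_min, c, c / s_min)  # 4000.0  24576000.0  6144.0
```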
- Learning rate scaling: You scale from batch 256 with \(\eta = 0.001\) to batch 4096. What learning rate should you use under (a) linear scaling, (b) square root scaling?
Solution
Given:
- Base batch: \(B_0 = 256\)
- Base learning rate: \(\eta_0 = 0.001\)
- Target batch: \(B = 4096\)
- Scaling factor: \(B/B_0 = 16\)
(a) Linear scaling:

\[
\eta = \eta_0 \cdot \frac{B}{B_0} = 0.001 \times 16 = 0.016
\]

(b) Square root scaling:

\[
\eta = \eta_0 \cdot \sqrt{\frac{B}{B_0}} = 0.001 \times 4 = 0.004
\]
Which to use?
| Batch Size | Relative to \(B_{\text{crit}}\) | Recommended Scaling |
|---|---|---|
| 4,096 | Below 6K (from Exercise 1) | Linear (0.016) |
| 8,192 | Above 6K | Square root |
| 16,384 | Well above | Constant or sqrt |
Practical note: Start with linear scaling (0.016) but use warmup. If training is unstable, fall back to square root (0.004).
- Compute efficiency: With batch 512, training takes 50K steps. With batch 8192, training takes 6K steps. Calculate the scaling efficiency \(E(8192)\).
Solution
Given:
- Base batch: \(B_0 = 512\), steps \(S_0 = 50,000\)
- Target batch: \(B = 8192\), steps \(S = 6,000\)
Scaling efficiency formula:

\[
E(B) = \frac{S_0 \cdot B_0}{S(B) \cdot B}
\]

Calculate:

\[
E(8192) = \frac{50,000 \times 512}{6,000 \times 8,192} = \frac{25,600,000}{49,152,000} \approx 0.52
\]
Interpretation:
| Metric | Value |
|---|---|
| Batch increase | 16× |
| Step reduction | 8.33× |
| Scaling efficiency | 52% |
| "Wasted" compute | 48% |
Perfect scaling would give:

\[
S_{\text{perfect}} = \frac{50,000}{16} = 3,125 \text{ steps}
\]
Actual: 6,000 steps → 1.92× more steps than perfect scaling.
Conclusion: At batch 8192, we're well past \(B_{\text{crit}}\). Almost half the compute is "wasted" in the sense that it doesn't reduce training steps. However, if wall-clock time is the constraint, this may still be worthwhile.
- LARS derivation: Show that the LARS trust ratio \(\phi = ||w||/||\nabla w||\) ensures that the relative update \(||\Delta w||/||w||\) is approximately constant across layers.
Solution
LARS update rule:

\[
\Delta w_l = -\eta \, \phi_l \, \nabla w_l
\]

Where the trust ratio is:

\[
\phi_l = \frac{\|w_l\|}{\|\nabla w_l\|}
\]

(ignoring weight decay for simplicity)

Relative update magnitude:

\[
\frac{\|\Delta w_l\|}{\|w_l\|} = \frac{\eta \, \phi_l \, \|\nabla w_l\|}{\|w_l\|} = \eta \cdot \frac{\|w_l\|}{\|\nabla w_l\|} \cdot \frac{\|\nabla w_l\|}{\|w_l\|} = \eta
\]
Key insight: The relative update \(||\Delta w||/||w||\) equals \(\eta\) for all layers, regardless of:
- Weight magnitude \(||w_l||\)
- Gradient magnitude \(||\nabla w_l||\)
Why this matters:
| Layer | Without LARS | With LARS |
|---|---|---|
| Small weights, large gradients | Huge relative update | \(\eta\) |
| Large weights, small gradients | Tiny relative update | \(\eta\) |
| Any layer | \(\eta \cdot \|\nabla w_l\| / \|w_l\|\) (varies wildly) | \(\eta\) |
Consequence: All layers update at the same relative rate, preventing:

- Early layers from updating too slowly
- Output layers from updating too aggressively
- Training instability at large batch sizes
This is why LARS enables training with batch sizes of 32K+ where standard SGD fails.
- Gradient accumulation: You have 8 GPUs with batch 32 each (256 total) but need effective batch 2048. How many accumulation steps? If each forward-backward takes 100ms, and all-reduce takes 20ms, what's the time per effective step?
Solution
Given:
- GPUs: 8
- Per-GPU batch: 32
- Current effective batch: \(8 \times 32 = 256\)
- Target effective batch: 2,048
- Forward-backward time: 100ms
- All-reduce time: 20ms
Accumulation steps:

\[
k = \frac{2,048}{256} = 8
\]
Time breakdown per effective step:
| Phase | Count | Time Each | Total |
|---|---|---|---|
| Forward-backward passes | 8 | 100ms | 800ms |
| All-reduce (only at end) | 1 | 20ms | 20ms |
| Total | | | 820ms |
Comparison with true data parallelism (64 GPUs):
If we had 64 GPUs with batch 32 each:

- Forward-backward: 100ms
- All-reduce: ~40ms (slightly higher with more GPUs)
- Total: ~140ms
Efficiency comparison: the accumulated step takes \(820 / 140 \approx 5.9\times\) longer than a true 64-GPU data-parallel step.
Key insight: Gradient accumulation trades time for memory. It's useful when:

1. \(B_{\text{crit}}\) hasn't been reached yet
2. GPU memory limits batch size
3. More GPUs aren't available
- Dynamic batching: You want to train for 1M tokens/step initially, ramping to 4M tokens/step. If you switch at the midpoint of training, how many fewer gradient updates do you perform compared to constant 1M tokens/step?
Solution
Setup:

Let total training be \(D\) tokens (e.g., 2T tokens).

Constant batch size (1M tokens/step):

\[
S_{\text{const}} = \frac{D}{1\text{M}}
\]

Dynamic batch size:

- First half (\(D/2\) tokens) at 1M tokens/step: \(\frac{D/2}{1\text{M}} = \frac{D}{2\text{M}}\) updates
- Second half (\(D/2\) tokens) at 4M tokens/step: \(\frac{D/2}{4\text{M}} = \frac{D}{8\text{M}}\) updates
- Total dynamic updates: \(\frac{D}{2\text{M}} + \frac{D}{8\text{M}} = \frac{5D}{8\text{M}}\)

Reduction:

\[
\frac{S_{\text{const}} - S_{\text{dyn}}}{S_{\text{const}}} = 1 - \frac{5}{8} = \frac{3}{8} = 37.5\%
\]
Concrete example with D = 2T tokens:
| Strategy | Updates | Wall-clock (relative) |
|---|---|---|
| Constant 1M | 2,000,000 | 1.0× |
| Dynamic (1M→4M) | 1,250,000 | ~0.75× |
| Reduction | 750,000 | 25% faster |
Why is the wall-clock gain only ~25% when we perform 37.5% fewer updates?

Because the total number of tokens processed (and hence total FLOPs) is unchanged: a 4M-token step costs roughly 4× the compute of a 1M-token step. The savings come from fixed per-step costs (optimizer step, gradient communication, kernel launches) being amortized over more tokens.

The actual wall-clock speedup depends on:

- The communication-overhead fraction
- GPU utilization at each batch size
- Whether we're below \(B_{\text{crit}}\) at both batch sizes

Summary: Dynamic batching reduces gradient updates by 37.5%, translating to roughly a 20-30% wall-clock speedup depending on system characteristics.
Key Takeaways¶
- Critical batch size exists: Beyond \(B_{\text{crit}}\), compute is wasted.
- Linear scaling works below \(B_{\text{crit}}\): Learning rate \(\propto\) batch size.
- Warmup is essential: Larger batches need longer warmup.
- LARS/LAMB enable extreme scaling: Layer-wise adaptation for 30K+ batch sizes.
- \(B_{\text{crit}}\) increases during training: Dynamic batch sizing can exploit this.
- Wall-clock vs. compute trade-off: Larger batches are faster but less efficient above \(B_{\text{crit}}\).