38 Numerical Precision
The Hidden Variable in Every Computation
Every floating-point number is a lie. Some lies are close enough to the truth. Understanding which lies matter—and which don’t—is essential for both correctness and performance.
38.1 The Performance-Precision Trade-off
Modern hardware offers a choice: compute with more bits (accurate but slow) or fewer bits (fast but approximate). The A100 GPU demonstrates this dramatically:
NVIDIA A100 Peak Performance by Precision:
┌─────────────────────────────────────────┐
│ FP64: 9.7 TFLOPS │
│ FP32: 19.5 TFLOPS │
│ TF32: 156 TFLOPS │
│ FP16: 312 TFLOPS │
│ INT8: 624 TOPS │
│ (FP8 arrives with Hopper/H100, not A100)│
└─────────────────────────────────────────┘
A 32× speedup from FP64 to FP16. A 64× jump to INT8. But you can’t just switch precisions blindly—the math breaks in specific, predictable ways.
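These are peak numbers; real kernels land lower. A rough way to see the gap on your own hardware is to time a large matmul at each dtype. The sketch below makes some assumptions: the matrix size and iteration counts are arbitrary, and you may need to shrink n to fit GPU memory.

import time
import torch

def bench_matmul(dtype, n=8192, iters=20):
    """Rough TFLOPS estimate for an n×n matmul at the given dtype."""
    a = torch.randn(n, n, device="cuda", dtype=dtype)
    b = torch.randn(n, n, device="cuda", dtype=dtype)
    for _ in range(3):                      # warm-up so setup cost isn't timed
        torch.matmul(a, b)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        torch.matmul(a, b)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print(f"{dtype}: {2 * n**3 * iters / elapsed / 1e12:.1f} TFLOPS")

for dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16):
    bench_matmul(dtype)

Note that torch.backends.cuda.matmul.allow_tf32 controls whether the FP32 line silently uses TF32 tensor cores, which changes that number substantially.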
38.2 IEEE 754: What Your Numbers Actually Are
Before optimizing precision, you need to understand what a floating-point number is.
38.2.1 The Anatomy of a Float
IEEE 754 floats have three components:
Single Precision (FP32): 32 bits total
┌─────┬──────────────────┬─────────────────────────────────────────┐
│ S │ Exponent (8) │ Mantissa (23) │
│ 1 │ 01111100 │ 01000000000000000000000 │
└─────┴──────────────────┴─────────────────────────────────────────┘
   ↓            ↓                          ↓
 Sign: +    2^(124-127) = 2^-3       1.25 (implicit leading 1)
Value = (-1)^S × 1.mantissa × 2^(exponent-bias)
= +1 × 1.25 × 2^-3
= 0.15625
The key insight: exponent determines range, mantissa determines precision.
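You can check this decomposition directly by unpacking the bits yourself. A small sketch using Python’s struct module; it handles only normal numbers and ignores zeros, subnormals, infinities, and NaNs:

import struct

def decode_fp32(x):
    """Split an FP32 value into its sign, exponent, and mantissa fields."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF          # 8 bits, biased by 127
    mantissa = bits & 0x7FFFFF              # 23 bits, implicit leading 1
    value = (-1) ** sign * (1 + mantissa / 2**23) * 2 ** (exponent - 127)
    return sign, exponent, mantissa, value

print(decode_fp32(0.15625))   # (0, 124, 2097152, 0.15625)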
38.2.2 Precision Formats Compared
| Format | Sign | Exponent | Mantissa | Range | Precision (decimal) |
|---|---|---|---|---|---|
| FP64 | 1 | 11 | 52 | ±10^308 | ~16 digits |
| FP32 | 1 | 8 | 23 | ±10^38 | ~7 digits |
| TF32 | 1 | 8 | 10 | ±10^38 | ~3 digits |
| FP16 | 1 | 5 | 10 | ±65504 | ~3 digits |
| BF16 | 1 | 8 | 7 | ±10^38 | ~2 digits |
| FP8 E4M3 | 1 | 4 | 3 | ±448 | ~1 digit |
| FP8 E5M2 | 1 | 5 | 2 | ±57344 | ~0.5 digit |
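The range and precision columns can be verified with torch.finfo. The FP8 lines assume a recent PyTorch build that ships the experimental float8 dtypes; older versions don’t have them.

import torch

# TF32 is a compute mode, not a storage dtype, so it has no finfo entry
for dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e}  eps={info.eps:.3e}  tiny={info.tiny:.3e}")

# With a recent PyTorch build:
# torch.finfo(torch.float8_e4m3fn).max  -> 448.0
# torch.finfo(torch.float8_e5m2).max    -> 57344.0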
38.2.3 Why BF16 Exists
BF16 (Brain Float 16) was designed specifically for deep learning. The key insight: neural network gradients need range more than precision.
# FP16 overflow: gradients can exceed the FP16 max of 65504
gradient = 100000.0 # Common in early training
fp16_grad = np.float16(gradient) # → inf (overflow!)
# BF16 handles large values: same exponent range as FP32
bf16_grad = torch.tensor(gradient, dtype=torch.bfloat16) # → 99840.0 (finite; small rounding error)
The trade-off is clear:
- FP16: More precision (10-bit mantissa), limited range (max ~65504)
- BF16: Less precision (7-bit mantissa), full FP32 range (max ~10^38)
For training, BF16 usually wins because:
- Gradients can be large (especially early in training)
- Gradient accumulation smooths out precision loss
- Optimizer state is typically kept in FP32 anyway
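The precision cost of BF16’s shorter mantissa is just as easy to demonstrate as FP16’s range limit. A quick check, round-tripping 0.1 through both formats (the printed values are approximate and depend on the input chosen):

import torch

x = torch.tensor(0.1, dtype=torch.float32)

# FP16: 10 mantissa bits, relative error around 2e-4 for this value
print(x.half().float().item())      # ≈ 0.0999755859375
# BF16: 7 mantissa bits, relative error around 1e-3 for this value
print(x.bfloat16().float().item())  # ≈ 0.10009765625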
38.3 Mixed Precision Training: The Standard Approach
Pure low-precision training often fails. Mixed precision—computing in FP16/BF16 but accumulating in FP32—gives you speed without sacrificing convergence.
38.3.1 The Three-Part Recipe
import torch
from torch.cuda.amp import autocast, GradScaler

# 1. Create a scaler to prevent gradient underflow
scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()

    # 2. Autocast: run forward pass in FP16
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # 3. Scale loss, backward, unscale, step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
38.3.2 Why Loss Scaling is Necessary
FP16’s limited range causes gradient underflow: small gradients lose precision and, eventually, become exactly zero.
FP16 smallest positive normal:    2^-14 ≈ 6.1 × 10^-5
FP16 smallest positive subnormal: 2^-24 ≈ 6.0 × 10^-8
Gradient magnitudes in training:
- Large gradients:  10^-1 to 10^2   ✓ Representable
- Medium gradients: 10^-3 to 10^-1  ✓ Representable
- Small gradients:  10^-5 to 10^-3  ⚠️ Entering the subnormal range; precision degrades
- Tiny gradients:   < 6 × 10^-8     ✗ Underflow to zero
Loss scaling multiplies the loss (and thus all gradients) by a large factor before backward:
# Without scaling: very small gradients flush to zero
loss.backward() # a gradient of 1e-8 becomes 0 in FP16 (below the smallest subnormal)
# With scaling: gradients stay representable
scaled_loss = loss * 65536 # Scale factor
scaled_loss.backward() # a gradient of 1e-8 → ~6.6e-4 in FP16
# Unscale after: 6.6e-4 / 65536 → 1e-8 (recovered)
The GradScaler does this automatically (a hand-rolled sketch of the same logic follows this list):
- Starts with a large scale (e.g., 65536)
- Monitors for inf/nan in gradients
- If overflow detected: skip update, halve the scale
- If no overflow for N steps: double the scale
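Here is that logic written out by hand, just to make the mechanics concrete. This is a sketch, not GradScaler’s actual implementation: the real scaler is more careful about unscaling and optimizer interaction, and the scale factor, growth interval, and growth/backoff factors below are illustrative choices; model, optimizer, criterion, and dataloader are placeholders.

import torch

scale = 65536.0
growth_interval = 2000    # steps without overflow before growing the scale
good_steps = 0

for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    (loss * scale).backward()                 # gradients come out scaled

    # Unscale in place and check for inf/nan
    overflow = False
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(scale)
            if not torch.isfinite(p.grad).all():
                overflow = True

    if overflow:
        scale /= 2.0          # overflow: skip this update, back off the scale
        good_steps = 0
    else:
        optimizer.step()
        good_steps += 1
        if good_steps >= growth_interval:
            scale *= 2.0      # stable for a while: try a larger scale again
            good_steps = 0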
38.3.3 Operations That Need FP32
Some operations are numerically unstable in FP16:
# These are automatically run in FP32 by autocast:
- Softmax
- Layer normalization
- Loss functions (cross-entropy, etc.)
- Exponentials and logarithms
- Reductions (sum, mean over large tensors)
# Safe in FP16:
- Matrix multiplications (Tensor cores!)
- Convolutions
- Element-wise operations (ReLU, add, mul)
PyTorch’s autocast knows which operations need higher precision:
with autocast():
    # matmul runs in FP16 (fast)
    hidden = input @ weights

    # softmax automatically promoted to FP32 (accurate)
    probs = torch.softmax(hidden, dim=-1)

    # back to FP16 for next matmul
    output = probs @ value_weights
38.4 FP8: The New Frontier
NVIDIA’s Hopper (H100) and AMD’s MI300 introduced FP8 support, promising another 2× speedup over FP16.
38.4.1 The Two FP8 Formats
E4M3: More precision, less range
┌─────┬────────┬───────────┐
│ S │ Exp(4) │ Mant(3) │
└─────┴────────┴───────────┘
Range: ±448, Precision: ~3 bits
E5M2: More range, less precision
┌─────┬─────────┬──────────┐
│ S │ Exp(5) │ Mant(2) │
└─────┴─────────┴──────────┘
Range: ±57344, Precision: ~2 bits
The strategy: use different formats for different tensors.
# Typical FP8 assignment:
# - Weights: E4M3 (need precision, values are bounded)
# - Activations: E4M3 (need precision for forward pass)
# - Gradients: E5M2 (need range, gradients can be large)
38.4.2 FP8 Training Recipe
FP8 training requires per-tensor scaling because the dynamic range is so limited:
# Conceptual FP8 training (simplified)
def fp8_matmul(A, B, A_scale, B_scale):
    # Quantize inputs to FP8
    A_fp8 = quantize_to_fp8(A / A_scale)
    B_fp8 = quantize_to_fp8(B / B_scale)

    # Compute in FP8 (uses tensor cores)
    C_fp8 = matmul_fp8(A_fp8, B_fp8)

    # Dequantize result
    return C_fp8 * A_scale * B_scale

# Scale factors are computed per-tensor based on tensor statistics
def compute_scale(tensor, target_format='e4m3'):
    amax = tensor.abs().max()
    if target_format == 'e4m3':
        max_fp8 = 448.0
    else:  # e5m2
        max_fp8 = 57344.0
    return amax / max_fp8
Libraries like Transformer Engine handle this automatically:
import torch.nn as nn
import transformer_engine.pytorch as te

# Replace nn.Linear with the FP8-aware version
class FP8Model(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = te.Linear(hidden_size, hidden_size)

    def forward(self, x):
        with te.fp8_autocast():
            return self.linear(x)
38.5 Numerical Instability: Diagnosis and Prevention
38.5.1 Classic Failure Modes
1. Catastrophic Cancellation
When subtracting nearly equal numbers:
# Bad: loses precision
a = 1.0000001
b = 1.0000000
diff = a - b # Expected: 1e-7; in FP32 you get 1.1920929e-07 (≈19% error)

# In FP16: even worse
a_fp16 = torch.tensor(1.0001, dtype=torch.float16)
b_fp16 = torch.tensor(1.0000, dtype=torch.float16)
diff_fp16 = a_fp16 - b_fp16 # → 0.0: the difference vanishes entirely
2. Accumulation Error
Summing many small values:
# Bad: naive low-precision accumulation stalls
total = np.float32(0.0)
for _ in range(100_000_000):
    total += np.float32(1e-8)
# total ends up stuck at ≈0.25, not 1.0: once the running sum reaches 0.25,
# each 1e-8 addend is less than half an ulp and rounds away to nothing.
# (Python's built-in sum accumulates in FP64 and mostly hides the problem.)
# Kahan summation fixes this:
def kahan_sum(values):
    total = 0.0
    compensation = 0.0
    for x in values:
        y = x - compensation              # remove the error carried from the last step
        t = total + y
        compensation = (t - total) - y    # recover what was lost in this addition
        total = t
    return total
3. Softmax Overflow/Underflow
Raw softmax is numerically dangerous:
# Bad: exp(1000) overflows
def naive_softmax(x):
    exp_x = torch.exp(x)  # Overflow!
    return exp_x / exp_x.sum()

# Good: subtract max for stability
def stable_softmax(x):
    x_max = x.max()
    exp_x = torch.exp(x - x_max)  # Now exp of non-positive numbers
    return exp_x / exp_x.sum()
38.5.2 Debugging Numerical Issues
Step 1: Detect
def check_numerical_health(tensor, name=""):
    """Check for numerical problems."""
    has_nan = torch.isnan(tensor).any()
    has_inf = torch.isinf(tensor).any()
    if has_nan or has_inf:
        print(f"{name}: NaN={has_nan}, Inf={has_inf}")
        print(f"  Range: [{tensor.min():.2e}, {tensor.max():.2e}]")
        print(f"  Mean: {tensor.mean():.2e}, Std: {tensor.std():.2e}")
        return False
    return True

# Use hooks to monitor during training
def register_nan_hooks(model):
    def check_hook(module, input, output):
        if isinstance(output, torch.Tensor):
            if torch.isnan(output).any():
                raise RuntimeError(f"NaN in {module.__class__.__name__}")
    for module in model.modules():
        module.register_forward_hook(check_hook)
Step 2: Isolate
# Walk the layers in order to find the first one that produces NaN
def find_nan_layer(model, input_data):
    """Find which layer first produces NaN (assumes leaf modules compose sequentially)."""
    x = input_data
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # Leaf module
            x = module(x)
            if torch.isnan(x).any():
                print(f"NaN first appears in: {name}")
                return name, module
    return None, None
Step 3: Fix
Common fixes for each failure mode:
# Gradient explosion → Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Loss NaN → Check for log(0) or division by zero
loss = F.cross_entropy(logits, targets)
# Add epsilon: F.cross_entropy uses log_softmax internally, but custom losses need care

# Softmax NaN → Use log_softmax instead
log_probs = F.log_softmax(logits, dim=-1)  # Stable
probs = log_probs.exp()  # If you need probabilities

# Layer norm NaN → Check for zero variance
# PyTorch's LayerNorm has an eps parameter (default 1e-5)
38.6 Precision and Determinism
38.6.1 The Reproducibility Problem
Same code, same data, different results:
# Run 1: loss = 2.3456789
# Run 2: loss = 2.3456791

# Why? Non-deterministic operations:
# - Atomic additions in reductions (order depends on timing)
# - cuDNN algorithm selection
# - Parallel reductions with different orderings
38.6.2 Making Training Deterministic
import torch
import random
import numpy as np

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Force deterministic algorithms
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# PyTorch 1.8+: enable deterministic mode
torch.use_deterministic_algorithms(True)
Warning: Determinism often costs performance:
- cuDNN’s fastest algorithms may be non-deterministic
- Deterministic reductions are slower
- Some operations don’t have deterministic implementations
The trade-off:
Benchmark: ResNet-50 training throughput
┌────────────────────────────────────────┐
│ Non-deterministic: 1000 images/sec │
│ Deterministic: 750 images/sec │
│ Slowdown: 25% │
└────────────────────────────────────────┘
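One practical detail: with torch.use_deterministic_algorithms(True), certain cuBLAS operations will error out unless a workspace configuration is set through an environment variable, as described in PyTorch’s reproducibility notes. A minimal sketch:

import os

# Set before any CUDA work happens (or export it in the shell)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

# warn_only=True logs a warning instead of raising for ops
# that have no deterministic implementation
torch.use_deterministic_algorithms(True, warn_only=True)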
38.7 Practical Precision Guidelines
38.7.1 Choosing Your Precision
Training:
| Situation | Recommended | Why |
|---|---|---|
| Starting out | FP32 | Debug first, optimize later |
| Standard training | BF16 mixed precision | Best speed/stability trade-off |
| Memory constrained | FP16 + loss scaling | Works for most models |
| Cutting edge | FP8 + Transformer Engine | Maximum performance, needs tuning |
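For the standard-training row, a minimal BF16 mixed-precision loop looks like the FP16 recipe from 38.3.1 minus the GradScaler, since BF16 shares FP32’s exponent range. A sketch, with model, optimizer, criterion, and dataloader as placeholders:

import torch

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    loss.backward()   # no loss scaling needed: BF16 gradients rarely underflow
    optimizer.step()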
Inference:
| Situation | Recommended | Why |
|---|---|---|
| Accuracy critical | FP32 or FP16 | Minimal quality loss |
| Throughput focused | INT8 quantization | 2× speedup, small accuracy drop |
| Edge deployment | INT4/INT8 | Fits in limited memory |
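For the INT8 row, PyTorch’s dynamic quantization is the lowest-effort entry point: it converts linear-layer weights to INT8 and quantizes activations on the fly. A sketch (model is assumed to be a trained FP32 model; dynamic quantization targets CPU inference, and Chapter 12 covers the broader quantization toolbox):

import torch
import torch.nn as nn

quantized_model = torch.ao.quantization.quantize_dynamic(
    model,                 # module to quantize
    {nn.Linear},           # layer types to convert
    dtype=torch.qint8,     # INT8 weights
)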
38.7.2 The Decision Tree
Is training stable in FP32?
│
├─ No → Debug first (numerical issues exist regardless of precision)
│
└─ Yes → Does your hardware support BF16?
          │
          ├─ Yes → Use BF16 mixed precision
          │         └─ Still slow? → Consider FP8
          │
          └─ No → Use FP16 with loss scaling
                   └─ Training unstable?
                        └─ Increase loss scale,
                           check for ops that need FP32
38.7.3 Verification Checklist
Before deploying with reduced precision:
import copy
import numpy as np
import torch

def verify_precision_safety(model, test_loader, precision='fp16'):
    """Compare reduced precision to the FP32 baseline."""
    model_fp32 = copy.deepcopy(model).float()
    if precision == 'fp16':
        model_reduced = copy.deepcopy(model).half()
    elif precision == 'bf16':
        model_reduced = copy.deepcopy(model).bfloat16()
    reduced_dtype = next(model_reduced.parameters()).dtype

    errors = []
    for batch in test_loader:
        with torch.no_grad():
            out_fp32 = model_fp32(batch.float())
            out_reduced = model_reduced(batch.to(reduced_dtype))

        # Compare in FP32
        relative_error = (out_fp32 - out_reduced.float()).abs() / (out_fp32.abs() + 1e-8)
        errors.append(relative_error.mean().item())

    mean_error = np.mean(errors)
    max_error = np.max(errors)
    print(f"Mean relative error: {mean_error:.2e}")
    print(f"Max relative error: {max_error:.2e}")

    # Thresholds depend on your application
    assert mean_error < 1e-3, "Average error too high"
    assert max_error < 1e-2, "Max error too high"
38.8 Key Takeaways
- Know your formats: BF16 for range, FP16 for precision, FP8 for speed
- Mixed precision is standard: FP16/BF16 compute, FP32 accumulation
- Loss scaling prevents underflow: Essential for FP16, less critical for BF16
- Some ops need FP32: Softmax, layer norm, reductions—let autocast handle it
- Debug in FP32 first: If it’s broken in FP32, lower precision won’t fix it
- Verify before deploying: Compare reduced precision outputs to FP32 baseline
38.9 Connection to Other Chapters
- Chapter 4 (GPU Architecture): Tensor cores provide the hardware speedup for reduced precision
- Chapter 12 (Quantization): Post-training quantization for inference complements training-time precision choices
- Chapter 21 (Triton): Writing kernels that correctly handle mixed precision