Skip to content

Profiling Distributed Training

A distributed training run consumes thousands of GPU-hours. Yet most practitioners have no idea where that time goes. Profiling transforms intuition into data, revealing whether you're compute-bound, communication-bound, or simply waiting.

The Question: Your 64-GPU training run achieves 35% MFU (Model FLOP Utilization). Where is the other 65%? Is it communication? Memory bandwidth? Kernel launch overhead? Idle time? Without measurement, optimization is guesswork.

Notation in this chapter: MFU = Model FLOP Utilization, \(T_{\text{step}}\) = step time. See Notation.

Building On: All Previous Parts

This final part synthesizes everything. You understand rooflines and estimation (Part I), optimal resource allocation (Part II), collective costs (Part III), parallelism strategies (Part IV), memory management (Part V), composition (Part VI), and efficiency techniques (Part VII). Now we apply this knowledge: diagnose real systems, investigate bottlenecks, and analyze state-of-the-art training runs.

Code style in Part VIII

The code examples in Part VIII are pedagogical skeletons — they show essential logic and interfaces but omit error handling, edge cases, and production hardening. They are meant to illustrate profiling strategies, not serve as drop-in tools. For production profiling, use PyTorch Profiler, Nsight Systems, and NCCL debug logging directly.

The Profiling Imperative

At scale, inefficiency compounds. A 10% inefficiency on one GPU becomes ~154 GPU-hours wasted per day on 64 GPUs. Understanding exactly where time goes is essential.

What We Measure

Distributed training profiling examines four domains:

1. Compute - Kernel execution time - Tensor Core utilization - Memory bandwidth utilization - FLOPs achieved vs theoretical peak

2. Communication - Collective operation duration - Network bandwidth utilization - Message sizes and frequencies - Overlap with computation

3. Memory - Peak allocation - Fragmentation - Data movement (HBM ↔ host, host ↔ device) - Cache hit rates

4. Orchestration - Kernel launch overhead - Python/framework overhead - Synchronization waits - Load imbalance across workers

The Time Budget

Every training step has a fixed time budget:

\[T_{\text{step}} = T_{\text{forward}} + T_{\text{backward}} + T_{\text{comm}} + T_{\text{optimizer}} + T_{\text{other}}\]

The goal: minimize \(T_{\text{step}}\) while maintaining model quality. Here, \(T_{\text{compute}}\) typically means forward + backward + optimizer kernel time.

With perfect overlap:

\[T_{\text{step}} = \max(T_{\text{compute}}, T_{\text{comm}})\]

Without overlap:

\[T_{\text{step}} = T_{\text{compute}} + T_{\text{comm}}\]

Profiling reveals where you are on this spectrum.

Profiling Tools

NVIDIA Nsight Systems

The gold standard for GPU profiling. Captures:

  • CUDA kernel execution
  • Memory operations
  • NCCL collective calls
  • CPU activity
  • Inter-GPU communication

Basic Usage:

nsys profile -o trace \
    --trace=cuda,nvtx,osrt,cudnn,cublas \
    --sample=cpu \
    python train.py

For Distributed Training:

# On each rank
nsys profile -o trace_rank${RANK} \
    --trace=cuda,nvtx,osrt \
    --cuda-memory-usage=true \
    torchrun --nproc_per_node=8 train.py

Key Command Options:

# Capture NCCL operations explicitly
nsys profile --trace=cuda,nvtx,osrt,nccl \
    --capture-range=cudaProfilerApi \
    python train.py

# Limit trace duration (avoid huge files)
nsys profile --duration=60 -o trace python train.py

# Export to multiple formats
nsys export -o trace.json --type=json trace.nsys-rep

PyTorch Profiler

Built-in profiling with TensorBoard integration:

import torch.profiler as profiler

with profiler.profile(
    activities=[
        profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA,
    ],
    schedule=profiler.schedule(
        wait=1,      # Skip first step
        warmup=1,    # Warmup step
        active=3,    # Profile 3 steps
        repeat=1
    ),
    on_trace_ready=profiler.tensorboard_trace_handler('./logs'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for step, batch in enumerate(dataloader):
        if step >= 5:
            break
        train_step(batch)
        prof.step()

Distributed-Aware Profiling:

import torch.distributed as dist

def profile_distributed_training(model, dataloader, num_steps=5):
    """Profile distributed training with rank-aware output."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    profiler_config = profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        schedule=profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=profiler.tensorboard_trace_handler(
            f'./logs/rank_{rank}'
        ),
        record_shapes=True,
        profile_memory=True,
        with_flops=True,  # Estimate FLOPs
    )

    with profiler_config as prof:
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break

            # Add NVTX markers for visibility
            with torch.cuda.nvtx.range(f"step_{step}"):
                with torch.cuda.nvtx.range("forward"):
                    output = model(batch)

                with torch.cuda.nvtx.range("backward"):
                    loss = compute_loss(output)
                    loss.backward()

                with torch.cuda.nvtx.range("optimizer"):
                    optimizer.step()
                    optimizer.zero_grad()

            prof.step()

    # Print summary for rank 0
    if rank == 0:
        print(prof.key_averages().table(
            sort_by="cuda_time_total", row_limit=20
        ))

NVTX Annotations

NVIDIA Tools Extension provides manual annotation:

import torch.cuda.nvtx as nvtx

class NVTXAnnotatedModule(nn.Module):
    """Module with NVTX range annotations."""

    def __init__(self, module, name):
        super().__init__()
        self.module = module
        self.name = name

    def forward(self, x):
        with nvtx.range(self.name):
            return self.module(x)

# Annotate model layers
def annotate_model(model):
    """Add NVTX annotations to all layers."""
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            annotate_model(module)
        else:
            setattr(model, name, NVTXAnnotatedModule(module, name))

Fine-Grained Collective Annotation:

import torch.distributed as dist

class ProfiledAllReduce:
    """AllReduce with detailed profiling."""

    def __init__(self, process_group=None):
        self.process_group = process_group
        self.call_count = 0
        self.total_bytes = 0
        self.total_time = 0

    def __call__(self, tensor):
        size_bytes = tensor.numel() * tensor.element_size()
        self.total_bytes += size_bytes

        with nvtx.range(f"AllReduce_{self.call_count}_{size_bytes/1e6:.1f}MB"):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            dist.all_reduce(tensor, group=self.process_group)
            end.record()

            torch.cuda.synchronize()
            elapsed = start.elapsed_time(end)
            self.total_time += elapsed

        self.call_count += 1
        return tensor

    def summary(self):
        return {
            'calls': self.call_count,
            'total_bytes_GB': self.total_bytes / 1e9,
            'total_time_ms': self.total_time,
            'avg_bandwidth_GBps': (self.total_bytes / 1e9) / (self.total_time / 1e3)
        }

NCCL Debug Output

Enable detailed NCCL logging:

# Basic info
export NCCL_DEBUG=INFO

# Detailed subsystem logging
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# Trace all operations
export NCCL_DEBUG=TRACE

# Log to file per rank
export NCCL_DEBUG_FILE=nccl_log_%h_%p.txt

Interpreting NCCL Logs:

# Topology detection
NCCL INFO Ring 00 : 0 -> 1 -> 2 -> 3 -> 0

# Algorithm selection
NCCL INFO Channel 00/08 : 0[0] -> 1[1] via P2P/CUMEM

# Operation timing (with TRACE)
NCCL INFO AllReduce: opCount 42 bytes 104857600 datatype 7 op 0

Memory Profiling

Track memory allocation and fragmentation:

import torch

class MemoryProfiler:
    """Track GPU memory usage during training."""

    def __init__(self, device=None):
        self.device = device or torch.cuda.current_device()
        self.snapshots = []

    def snapshot(self, label=""):
        """Capture current memory state."""
        torch.cuda.synchronize()

        allocated = torch.cuda.memory_allocated(self.device)
        reserved = torch.cuda.memory_reserved(self.device)
        max_allocated = torch.cuda.max_memory_allocated(self.device)

        self.snapshots.append({
            'label': label,
            'allocated_GB': allocated / 1e9,
            'reserved_GB': reserved / 1e9,
            'max_allocated_GB': max_allocated / 1e9,
            'fragmentation': 1 - (allocated / reserved) if reserved > 0 else 0
        })

    def reset_peak(self):
        """Reset peak memory tracking."""
        torch.cuda.reset_peak_memory_stats(self.device)

    def report(self):
        """Print memory timeline."""
        print("\n=== Memory Profile ===")
        print(f"{'Label':<30} {'Allocated':>12} {'Reserved':>12} {'Peak':>12} {'Frag':>8}")
        print("-" * 74)
        for snap in self.snapshots:
            print(f"{snap['label']:<30} "
                  f"{snap['allocated_GB']:>10.2f}GB "
                  f"{snap['reserved_GB']:>10.2f}GB "
                  f"{snap['max_allocated_GB']:>10.2f}GB "
                  f"{snap['fragmentation']:>7.1%}")

# Usage during training
mem_profiler = MemoryProfiler()
mem_profiler.reset_peak()

mem_profiler.snapshot("after_model_init")
# ... training ...
mem_profiler.snapshot("after_forward")
mem_profiler.snapshot("after_backward")
mem_profiler.snapshot("after_optimizer_step")

mem_profiler.report()

Detailed Memory Breakdown:

def memory_breakdown():
    """Get detailed memory breakdown by category."""
    snapshot = torch.cuda.memory_snapshot()

    categories = {}
    for block in snapshot:
        # Categorize by allocation context
        if 'frames' in block and block['frames']:
            context = block['frames'][0]['filename']
        else:
            context = 'unknown'

        if context not in categories:
            categories[context] = 0
        categories[context] += block['total_size']

    # Sort by size
    sorted_cats = sorted(categories.items(), key=lambda x: -x[1])

    print("\n=== Memory by Source ===")
    total = sum(categories.values())
    for source, size in sorted_cats[:10]:
        print(f"{source}: {size/1e9:.2f}GB ({size/total*100:.1f}%)")

Interpreting Traces

Nsight Systems Timeline

A typical distributed training trace shows:

Time →
|--Forward--|--Backward----|--AllReduce--|--Optimizer--|
       |                   |<--Overlap-->|
       |--Compute Stream---|
       |---Comm Stream-----|

Key Patterns to Look For:

  1. Gaps between kernels: Indicates CPU overhead or kernel launch latency
  2. Sequential compute and comm: No overlap, potential for optimization
  3. Long collective tails: Straggler workers or network contention
  4. Memory copy operations: Potential for prefetching or pinned memory

Identifying Bottlenecks

class BottleneckAnalyzer:
    """Analyze profiling data to identify bottlenecks."""

    def __init__(self):
        self.timings = {
            'forward': [],
            'backward': [],
            'allreduce': [],
            'optimizer': [],
            'data_loading': [],
        }

    def record(self, phase, duration_ms):
        self.timings[phase].append(duration_ms)

    def analyze(self):
        """Identify the primary bottleneck."""
        results = {}

        for phase, times in self.timings.items():
            if times:
                results[phase] = {
                    'mean_ms': sum(times) / len(times),
                    'max_ms': max(times),
                    'min_ms': min(times),
                    'variance': self._variance(times),
                }

        total = sum(r['mean_ms'] for r in results.values())

        print("\n=== Bottleneck Analysis ===")
        print(f"{'Phase':<15} {'Mean':>10} {'Max':>10} {'Variance':>10} {'%Total':>10}")
        print("-" * 55)

        for phase, stats in sorted(results.items(), key=lambda x: -x[1]['mean_ms']):
            pct = stats['mean_ms'] / total * 100 if total > 0 else 0
            print(f"{phase:<15} {stats['mean_ms']:>8.2f}ms {stats['max_ms']:>8.2f}ms "
                  f"{stats['variance']:>8.2f}ms {pct:>9.1f}%")

        # Identify bottleneck type
        compute_time = results.get('forward', {}).get('mean_ms', 0) + \
                      results.get('backward', {}).get('mean_ms', 0)
        comm_time = results.get('allreduce', {}).get('mean_ms', 0)

        print("\n=== Diagnosis ===")
        if comm_time > compute_time * 1.2:
            print("COMMUNICATION BOUND: Collective operations dominate.")
            print("  Recommendations:")
            print("  - Increase batch size to amortize communication")
            print("  - Enable gradient bucketing if not already")
            print("  - Check for network bottlenecks")
        elif compute_time > comm_time * 1.2:
            print("COMPUTE BOUND: Forward/backward computation dominates.")
            print("  Recommendations:")
            print("  - Good! System is well-utilized")
            print("  - Consider mixed precision if not already using")
            print("  - May be able to overlap more communication")
        else:
            print("BALANCED: Compute and communication roughly equal.")
            print("  Recommendations:")
            print("  - Verify overlap is enabled and effective")
            print("  - This is often the optimal regime")

    def _variance(self, values):
        if len(values) < 2:
            return 0
        mean = sum(values) / len(values)
        return (sum((x - mean) ** 2 for x in values) / len(values)) ** 0.5

Reading Collective Timelines

AllReduce Timeline (ideal overlap):
GPU 0: |--Compute--|--AllReduce--|        |--Compute--|
GPU 1: |--Compute--|--AllReduce--|        |--Compute--|
                   |<--overlap-->|

AllReduce Timeline (straggler):
GPU 0: |--Compute--|--AllReduce--|---wait---|--Compute--|
GPU 1: |--Compute--|------AllReduce--------|--Compute--|
                              ^
                        Slow GPU 1 affects all

Detecting Stragglers:

import torch.distributed as dist

def detect_stragglers(num_iterations=10):
    """Measure timing variance across ranks."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    local_times = []

    for _ in range(num_iterations):
        torch.cuda.synchronize()
        start = time.time()

        # Simulate work with some variation
        dummy_work()

        torch.cuda.synchronize()
        local_time = time.time() - start
        local_times.append(local_time)

        # Gather all times to rank 0
        all_times = [torch.zeros(1) for _ in range(world_size)]
        dist.all_gather(all_times, torch.tensor([local_time]))

        if rank == 0:
            times = [t.item() for t in all_times]
            mean_time = sum(times) / len(times)
            max_time = max(times)
            straggler_rank = times.index(max_time)

            if max_time > mean_time * 1.1:  # 10% slower
                print(f"Straggler detected: rank {straggler_rank} "
                      f"({max_time:.3f}s vs mean {mean_time:.3f}s)")

MFU and Efficiency Metrics

Model FLOP Utilization

MFU measures actual compute efficiency:

\[\text{MFU} = \frac{\text{Achieved FLOPs}}{\text{Peak FLOPs}}\]

Calculating MFU:

def calculate_mfu(
    model_flops_per_sample: int,
    batch_size: int,
    step_time_seconds: float,
    num_gpus: int,
    peak_flops_per_gpu: float
) -> float:
    """
    Calculate Model FLOP Utilization.

    Args:
        model_flops_per_sample: Forward + backward FLOPs per sample
        batch_size: Global batch size
        step_time_seconds: Time for one training step
        num_gpus: Number of GPUs used
        peak_flops_per_gpu: Theoretical peak FLOPs (e.g., 312 TFLOPs for A100)

    Returns:
        MFU as a fraction (0-1)
    """
    # Total FLOPs for this step
    total_flops = model_flops_per_sample * batch_size

    # Achieved FLOPs
    achieved_flops = total_flops / step_time_seconds

    # Peak system FLOPs
    peak_flops = peak_flops_per_gpu * num_gpus

    return achieved_flops / peak_flops

# Example: GPT-3 175B on 1024 A100s
model_flops = 6 * 175e9 * 2048  # ~2.15e15 FLOPs per sample (6 * Psi * seq_len)
batch_size = 1024
step_time = 60.0  # seconds
num_gpus = 1024
peak_per_gpu = 312e12  # A100 FP16 peak

mfu = calculate_mfu(model_flops, batch_size, step_time, num_gpus, peak_per_gpu)
print(f"MFU: {mfu:.1%}")  # ~11% for these numbers; higher MFU needs shorter step time

Hardware FLOP Utilization (HFU)

HFU includes rematerialization:

\[\text{HFU} = \frac{\text{Achieved FLOPs including recomputation}}{\text{Peak FLOPs}}\]
def calculate_hfu(
    model_flops_per_sample: int,
    batch_size: int,
    step_time_seconds: float,
    num_gpus: int,
    peak_flops_per_gpu: float,
    recomputation_ratio: float = 1.0  # 1.0 = no recomputation, 2.0 = full recomputation
) -> float:
    """
    Calculate Hardware FLOP Utilization (includes recomputation).

    The recomputation_ratio accounts for activation checkpointing.
    """
    # Effective FLOPs including recomputation
    effective_flops = model_flops_per_sample * batch_size * recomputation_ratio

    achieved_flops = effective_flops / step_time_seconds
    peak_flops = peak_flops_per_gpu * num_gpus

    return achieved_flops / peak_flops

Communication Efficiency

def calculate_comm_efficiency(
    bytes_communicated: int,
    comm_time_seconds: float,
    network_bandwidth_bytes_per_sec: float
) -> float:
    """Calculate communication bandwidth efficiency."""
    achieved_bandwidth = bytes_communicated / comm_time_seconds
    return achieved_bandwidth / network_bandwidth_bytes_per_sec

# Example: AllReduce of 1GB gradients over 100Gbps InfiniBand
bytes_comm = 1e9
comm_time = 0.1  # 100ms
network_bw = 12.5e9  # 100Gbps = 12.5 GB/s

efficiency = calculate_comm_efficiency(bytes_comm, comm_time, network_bw)
print(f"Communication efficiency: {efficiency:.1%}")

Overlap Efficiency

Measure how well computation and communication overlap:

class OverlapEfficiencyTracker:
    """Track overlap between compute and communication."""

    def __init__(self):
        self.compute_time = 0
        self.comm_time = 0
        self.total_time = 0

    def record_step(self, compute_ms, comm_ms, step_ms):
        self.compute_time += compute_ms
        self.comm_time += comm_ms
        self.total_time += step_ms

    def overlap_efficiency(self):
        """
        Calculate overlap efficiency.

        Perfect overlap: total = max(compute, comm)
        No overlap: total = compute + comm

        Efficiency = 1 - (actual - theoretical_min) / (theoretical_max - theoretical_min)
        """
        theoretical_min = max(self.compute_time, self.comm_time)
        theoretical_max = self.compute_time + self.comm_time

        if theoretical_max == theoretical_min:
            return 1.0

        return 1 - (self.total_time - theoretical_min) / (theoretical_max - theoretical_min)

    def report(self):
        print(f"\n=== Overlap Efficiency ===")
        print(f"Total compute time: {self.compute_time:.2f}ms")
        print(f"Total comm time: {self.comm_time:.2f}ms")
        print(f"Total wall time: {self.total_time:.2f}ms")
        print(f"Theoretical min (perfect overlap): {max(self.compute_time, self.comm_time):.2f}ms")
        print(f"Theoretical max (no overlap): {self.compute_time + self.comm_time:.2f}ms")
        print(f"Overlap efficiency: {self.overlap_efficiency():.1%}")

The Alpha-Beta Model in Practice

Measuring α and β

The alpha-beta model predicts collective time:

\[T = \alpha + \frac{n}{\beta}\]

Measuring on Your Hardware:

import time

def measure_alpha_beta(
    process_group,
    sizes_bytes: list,
    num_warmup: int = 5,
    num_measure: int = 20
) -> tuple:
    """
    Measure alpha (latency) and beta (bandwidth) for a process group.

    Returns:
        (alpha_seconds, beta_bytes_per_second)
    """
    rank = dist.get_rank()
    times = []

    for size in sizes_bytes:
        tensor = torch.zeros(size // 4, dtype=torch.float32, device='cuda')

        # Warmup
        for _ in range(num_warmup):
            dist.all_reduce(tensor, group=process_group)
            torch.cuda.synchronize()

        # Measure
        elapsed = []
        for _ in range(num_measure):
            torch.cuda.synchronize()
            start = time.perf_counter()
            dist.all_reduce(tensor, group=process_group)
            torch.cuda.synchronize()
            elapsed.append(time.perf_counter() - start)

        avg_time = sum(elapsed) / len(elapsed)
        times.append((size, avg_time))

    if rank == 0:
        # Linear regression: T = alpha + size/beta
        # Solve for alpha and beta using least squares
        import numpy as np
        sizes = np.array([s for s, _ in times])
        measured = np.array([t for _, t in times])

        # Design matrix: [1, size]
        X = np.column_stack([np.ones_like(sizes), sizes])
        # Solve: [alpha, 1/beta] = (X^T X)^{-1} X^T y
        coeffs = np.linalg.lstsq(X, measured, rcond=None)[0]

        alpha = coeffs[0]
        beta = 1 / coeffs[1]

        print(f"\n=== Alpha-Beta Measurement ===")
        print(f"Alpha (latency): {alpha*1e6:.2f} μs")
        print(f"Beta (bandwidth): {beta/1e9:.2f} GB/s")
        print(f"\nPredicted times:")
        for size, actual in times:
            predicted = alpha + size / beta
            error = abs(predicted - actual) / actual * 100
            print(f"  {size/1e6:.1f}MB: predicted={predicted*1e3:.2f}ms, "
                  f"actual={actual*1e3:.2f}ms, error={error:.1f}%")

        return alpha, beta

    return None, None

# Measure with various sizes
sizes = [1024, 64*1024, 256*1024, 1024*1024, 4*1024*1024, 16*1024*1024, 64*1024*1024]
alpha, beta = measure_alpha_beta(dist.group.WORLD, sizes)

Using the Model for Prediction

class CollectiveTimePredictor:
    """Predict collective operation times using alpha-beta model."""

    def __init__(self, alpha_seconds: float, beta_bytes_per_sec: float, world_size: int):
        self.alpha = alpha_seconds
        self.beta = beta_bytes_per_sec
        self.world_size = world_size

    def allreduce_ring(self, size_bytes: int) -> float:
        """Predict ring AllReduce time."""
        # Ring: 2(P-1) messages, each of size n/P
        # Total: 2(P-1)/P * n bytes sent
        effective_size = 2 * (self.world_size - 1) / self.world_size * size_bytes
        num_steps = 2 * (self.world_size - 1)
        return num_steps * self.alpha + effective_size / self.beta

    def allgather(self, size_bytes: int) -> float:
        """Predict AllGather time (size_bytes is per-rank input)."""
        # Each rank receives (P-1) chunks of size size_bytes
        effective_size = (self.world_size - 1) * size_bytes
        num_steps = self.world_size - 1
        return num_steps * self.alpha + effective_size / self.beta

    def reduce_scatter(self, size_bytes: int) -> float:
        """Predict ReduceScatter time (size_bytes is per-rank input)."""
        # Each rank sends/receives (P-1)/P of its input
        effective_size = (self.world_size - 1) / self.world_size * size_bytes
        num_steps = self.world_size - 1
        return num_steps * self.alpha + effective_size / self.beta

    def alltoall(self, size_bytes: int) -> float:
        """Predict AlltoAll time (size_bytes is per-rank total input)."""
        # Each rank sends (P-1)/P of its input
        effective_size = (self.world_size - 1) / self.world_size * size_bytes
        num_steps = self.world_size - 1
        return num_steps * self.alpha + effective_size / self.beta

    def compare_algorithms(self, size_bytes: int):
        """Compare different algorithms for a given size."""
        ring = self.allreduce_ring(size_bytes)

        # Tree algorithm (better for small messages)
        tree_steps = 2 * math.ceil(math.log2(self.world_size))
        tree_time = tree_steps * self.alpha + size_bytes / self.beta * 2

        print(f"\n=== AllReduce Comparison for {size_bytes/1e6:.1f}MB ===")
        print(f"Ring: {ring*1e3:.2f}ms")
        print(f"Tree: {tree_time*1e3:.2f}ms")
        print(f"Recommended: {'Ring' if ring < tree_time else 'Tree'}")

Profiling Different Parallelism Strategies

Data Parallelism Profiling

Key metrics to track:

class DDPProfiler:
    """Profiler specialized for DistributedDataParallel."""

    def __init__(self, model):
        self.model = model
        self.bucket_times = []
        self.hook_overheads = []
        self.gradient_sizes = []

        # Instrument buckets
        self._instrument_buckets()

    def _instrument_buckets(self):
        """Add timing instrumentation to DDP buckets."""
        # Access internal bucket info (PyTorch internals)
        if hasattr(self.model, '_module_copies'):
            # Extract bucket information
            pass

    def profile_step(self, batch):
        """Profile a single training step."""
        timings = {}

        # Forward pass
        torch.cuda.synchronize()
        forward_start = time.perf_counter()
        output = self.model(batch)
        torch.cuda.synchronize()
        timings['forward'] = time.perf_counter() - forward_start

        # Backward pass (triggers AllReduce)
        loss = output.sum()  # Dummy loss
        backward_start = time.perf_counter()
        loss.backward()
        torch.cuda.synchronize()
        timings['backward_with_allreduce'] = time.perf_counter() - backward_start

        return timings

    def analyze_bucket_efficiency(self):
        """Analyze if bucket sizes are optimal."""
        if not self.bucket_times:
            print("No bucket timing data collected")
            return

        avg_bucket_time = sum(self.bucket_times) / len(self.bucket_times)

        print(f"\n=== Bucket Analysis ===")
        print(f"Number of buckets: {len(self.bucket_times)}")
        print(f"Average bucket AllReduce: {avg_bucket_time*1e3:.2f}ms")
        print(f"Total bucket AllReduce: {sum(self.bucket_times)*1e3:.2f}ms")

Tensor Parallelism Profiling

class TPProfiler:
    """Profiler for tensor parallelism."""

    def __init__(self, tp_degree: int):
        self.tp_degree = tp_degree
        self.allreduce_times = []
        self.allgather_times = []
        self.split_times = []

    def profile_layer(self, layer_fn, input_tensor):
        """Profile a tensor-parallel layer."""
        results = {}

        torch.cuda.synchronize()
        start = time.perf_counter()

        output = layer_fn(input_tensor)

        torch.cuda.synchronize()
        results['total_time'] = time.perf_counter() - start

        # Estimate communication fraction
        # For column-parallel linear: AllReduce after
        # For row-parallel linear: no communication

        return results

    def analyze_communication_fraction(self):
        """Analyze fraction of time spent in TP communication."""
        total_ar = sum(self.allreduce_times)
        total_ag = sum(self.allgather_times)

        print(f"\n=== Tensor Parallelism Communication ===")
        print(f"Total AllReduce time: {total_ar*1e3:.2f}ms ({len(self.allreduce_times)} calls)")
        print(f"Total AllGather time: {total_ag*1e3:.2f}ms ({len(self.allgather_times)} calls)")

Pipeline Parallelism Profiling

class PPProfiler:
    """Profiler for pipeline parallelism."""

    def __init__(self, num_stages: int, num_microbatches: int):
        self.num_stages = num_stages
        self.num_microbatches = num_microbatches
        self.stage_times = [[] for _ in range(num_stages)]
        self.bubble_time = 0

    def profile_schedule(self, schedule_fn):
        """Profile a pipeline schedule execution."""
        rank = dist.get_rank()
        stage = rank  # Assuming 1 stage per rank

        results = {
            'forward_times': [],
            'backward_times': [],
            'send_times': [],
            'recv_times': [],
            'idle_time': 0,
        }

        # Execute schedule with timing
        # ... implementation depends on pipeline framework

        return results

    def calculate_bubble_fraction(self):
        """Calculate pipeline bubble fraction."""
        # Bubble = (P - 1) * microbatch_time / total_time
        p = self.num_stages
        m = self.num_microbatches

        # For 1F1B schedule
        bubble_fraction = (p - 1) / (m + p - 1)

        print(f"\n=== Pipeline Bubble Analysis ===")
        print(f"Stages: {p}, Microbatches: {m}")
        print(f"Theoretical bubble fraction: {bubble_fraction:.1%}")
        print(f"Theoretical efficiency: {1 - bubble_fraction:.1%}")

ZeRO Profiling

class ZeROProfiler:
    """Profiler for ZeRO-style sharding."""

    def __init__(self, stage: int, world_size: int):
        self.stage = stage  # 1, 2, or 3
        self.world_size = world_size
        self.gather_times = []
        self.scatter_times = []
        self.optimizer_times = []

    def profile_forward(self, model, input_batch):
        """Profile forward pass with parameter gathering."""
        results = {'gather_time': 0, 'compute_time': 0}

        if self.stage == 3:
            # Track AllGather for each layer
            for name, param in model.named_parameters():
                gather_start = time.perf_counter()
                # AllGather param
                torch.cuda.synchronize()
                results['gather_time'] += time.perf_counter() - gather_start

        compute_start = time.perf_counter()
        output = model(input_batch)
        torch.cuda.synchronize()
        results['compute_time'] = time.perf_counter() - compute_start

        return results

    def analyze_memory_comm_tradeoff(self):
        """Analyze memory savings vs communication overhead."""
        print(f"\n=== ZeRO Stage {self.stage} Analysis ===")

        # Memory reduction factors (for total model state: params + grads + optimizer)
        P = self.world_size
        mem_factors = {
            1: 16 / (4 + 12 / P),
            2: 16 / (2 + 14 / P),
            3: P,
        }

        # Communication overhead factors (rough bandwidth relative to DP)
        comm_factors = {1: 1.0, 2: 2.0, 3: 3.0}

        print(f"Memory reduction: {mem_factors[self.stage]:.1f}x (approx)")
        print(f"Communication overhead: {comm_factors[self.stage]:.1f}x vs DP (approx)")

Distributed Profiling Coordination

Synchronized Profiling Across Ranks

class DistributedProfiler:
    """Coordinate profiling across all ranks."""

    def __init__(self, output_dir: str):
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.output_dir = Path(output_dir) / f"rank_{self.rank}"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def profile_synchronized(self, train_fn, num_steps: int = 5):
        """
        Profile training with synchronized start across ranks.

        Ensures all ranks start profiling at the same time for
        aligned timelines.
        """
        # Barrier to synchronize start
        dist.barrier()

        # Start profiling
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=torch.profiler.tensorboard_trace_handler(str(self.output_dir)),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            for step in range(num_steps):
                with torch.cuda.nvtx.range(f"step_{step}"):
                    train_fn()
                prof.step()

        # Barrier to synchronize end
        dist.barrier()

        if self.rank == 0:
            print(f"Profiling complete. Traces saved to {self.output_dir.parent}")

    def aggregate_statistics(self, local_stats: dict) -> dict:
        """Aggregate statistics across all ranks."""
        # Gather all local stats to rank 0
        all_stats = [None for _ in range(self.world_size)]
        dist.all_gather_object(all_stats, local_stats)

        if self.rank == 0:
            # Aggregate
            aggregated = {}
            for key in local_stats.keys():
                values = [s[key] for s in all_stats if key in s]
                aggregated[key] = {
                    'mean': sum(values) / len(values),
                    'min': min(values),
                    'max': max(values),
                    'std': (sum((v - sum(values)/len(values))**2 for v in values) / len(values)) ** 0.5
                }
            return aggregated
        return None

Cross-Rank Timeline Alignment

def align_timelines(traces_dir: Path):
    """
    Align profiler traces from multiple ranks using barrier timestamps.

    This helps identify relative timing across ranks.
    """
    import json

    traces = []
    for trace_file in traces_dir.glob("rank_*/trace.json"):
        with open(trace_file) as f:
            trace = json.load(f)
            rank = int(trace_file.parent.name.split('_')[1])
            traces.append((rank, trace))

    # Find barrier events in each trace
    barrier_times = {}
    for rank, trace in traces:
        for event in trace['traceEvents']:
            if 'name' in event and 'barrier' in event['name'].lower():
                if rank not in barrier_times:
                    barrier_times[rank] = []
                barrier_times[rank].append(event['ts'])

    # Calculate offsets relative to rank 0
    offsets = {0: 0}
    if 0 in barrier_times:
        ref_time = barrier_times[0][0]
        for rank in barrier_times:
            if rank != 0:
                offsets[rank] = barrier_times[rank][0] - ref_time

    print("Timeline offsets (μs):", offsets)
    return offsets

Practical Profiling Workflow

The Investigation Protocol

Step 1: Quick Baseline
├── Single step timing
├── GPU utilization (nvidia-smi)
└── Basic MFU estimate

Step 2: Identify Bottleneck Category
├── Compute vs Communication ratio
├── Memory pressure indicators
└── Straggler detection

Step 3: Detailed Analysis
├── Full Nsight trace (2-3 steps)
├── Memory breakdown
└── Collective timing breakdown

Step 4: Root Cause
├── Kernel-level analysis
├── Algorithm selection validation
└── Overlap efficiency measurement

Step 5: Validate Fix
├── A/B comparison
├── Confirm improvement
└── Check for regressions

Quick Health Check

def quick_health_check(model, dataloader, num_steps=3):
    """Fast profiling to identify obvious issues."""
    rank = dist.get_rank()

    timings = []
    for i, batch in enumerate(dataloader):
        if i >= num_steps:
            break

        torch.cuda.synchronize()
        start = time.perf_counter()

        # Training step
        with torch.cuda.nvtx.range(f"step_{i}"):
            output = model(batch)
            loss = output.sum()
            loss.backward()

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        timings.append(elapsed)

    # Gather from all ranks
    all_timings = [None] * dist.get_world_size()
    dist.all_gather_object(all_timings, timings)

    if rank == 0:
        print("\n=== Quick Health Check ===")
        for r, times in enumerate(all_timings):
            avg = sum(times) / len(times)
            print(f"Rank {r}: avg={avg*1e3:.1f}ms, times={[f'{t*1e3:.1f}' for t in times]}")

        # Check for stragglers
        avg_times = [sum(t)/len(t) for t in all_timings]
        mean = sum(avg_times) / len(avg_times)
        max_time = max(avg_times)

        if max_time > mean * 1.1:
            slowest = avg_times.index(max_time)
            print(f"\n⚠️  Straggler detected: Rank {slowest} is {max_time/mean:.0%} of mean")
        else:
            print("\n✓ No obvious stragglers")

Comparative Profiling

class ABProfiler:
    """Compare two configurations."""

    def __init__(self):
        self.results_a = []
        self.results_b = []

    def profile_config(self, name: str, setup_fn, train_fn, num_steps: int = 10):
        """Profile a configuration."""
        setup_fn()  # Apply configuration

        timings = []
        for _ in range(num_steps):
            torch.cuda.synchronize()
            start = time.perf_counter()
            train_fn()
            torch.cuda.synchronize()
            timings.append(time.perf_counter() - start)

        if name == 'A':
            self.results_a = timings
        else:
            self.results_b = timings

    def compare(self):
        """Compare A and B configurations."""
        if not self.results_a or not self.results_b:
            print("Need both A and B results")
            return

        avg_a = sum(self.results_a) / len(self.results_a)
        avg_b = sum(self.results_b) / len(self.results_b)

        print("\n=== A/B Comparison ===")
        print(f"Config A: {avg_a*1e3:.2f}ms avg")
        print(f"Config B: {avg_b*1e3:.2f}ms avg")
        print(f"Difference: {(avg_b - avg_a)/avg_a*100:+.1f}%")

        # Statistical significance (simple t-test)
        if len(self.results_a) >= 5 and len(self.results_b) >= 5:
            from scipy import stats
            t_stat, p_value = stats.ttest_ind(self.results_a, self.results_b)
            print(f"p-value: {p_value:.4f} ({'significant' if p_value < 0.05 else 'not significant'})")

Common Issues and Diagnostics

Issue: Low MFU Despite Powerful Hardware

def diagnose_low_mfu():
    """Diagnostic checklist for low MFU."""
    checks = [
        ("Mixed precision enabled?", "torch.cuda.amp.autocast"),
        ("Tensor Cores used?", "Check for TF32/FP16 kernels in trace"),
        ("Batch size sufficient?", "Small batches → kernel launch overhead dominates"),
        ("Memory bandwidth limited?", "Check memory throughput in Nsight"),
        ("Python overhead?", "Check CPU utilization, GIL contention"),
        ("Data loading bottleneck?", "Profile DataLoader separately"),
    ]

    print("\n=== Low MFU Diagnostic Checklist ===")
    for check, how in checks:
        print(f"[ ] {check}")
        print(f"    → {how}")

Issue: Communication Taking Too Long

def diagnose_slow_communication():
    """Diagnostic for communication bottlenecks."""
    checks = [
        "Measure actual vs theoretical bandwidth",
        "Check for network congestion (multiple jobs)",
        "Verify NCCL algorithm selection (NCCL_DEBUG=INFO)",
        "Check for imbalanced work (stragglers)",
        "Verify overlap is working (Nsight timeline)",
        "Check bucket sizes for DDP",
    ]

    print("\n=== Slow Communication Diagnostic ===")
    for i, check in enumerate(checks, 1):
        print(f"{i}. {check}")

Issue: Memory Pressure

def diagnose_memory_pressure():
    """Diagnostic for memory issues."""

    # Check current state
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    max_alloc = torch.cuda.max_memory_allocated() / 1e9

    print(f"\n=== Memory Diagnostic ===")
    print(f"Currently allocated: {allocated:.2f}GB")
    print(f"Reserved by PyTorch: {reserved:.2f}GB")
    print(f"Peak allocation: {max_alloc:.2f}GB")
    print(f"Fragmentation: {1 - allocated/reserved:.1%}")

    if max_alloc > reserved * 0.95:
        print("\n⚠️  Near memory limit - consider:")
        print("   - Gradient checkpointing")
        print("   - Smaller batch size")
        print("   - ZeRO stage 2 or 3")
        print("   - Offloading")

Exercises

  1. MFU Measurement: Implement a complete MFU measurement for your model. Compare theoretical FLOPs (from model architecture) with achieved FLOPs (from step time). What efficiency do you achieve?
Solution

MFU measurement implementation:

import torch
import time
from dataclasses import dataclass

@dataclass
class ModelConfig:
    hidden_dim: int
    num_layers: int
    num_heads: int
    vocab_size: int
    seq_len: int
    batch_size: int

def count_flops_per_token(config: ModelConfig) -> int:
    """Count FLOPs per token for a transformer."""
    H = config.hidden_dim
    L = config.num_layers
    V = config.vocab_size
    S = config.seq_len

    # Per-layer FLOPs
    # Attention: 4H² (QKV proj) + 2S·H (attn scores) + 2S·H (attn output) + H² (output proj)
    # Approximation: 4H² + 2H² = 6H² for projections, plus O(S·H) for attention itself
    attn_flops = 4 * H * H + 2 * S * H + 2 * S * H + H * H  # ~5H² + 4SH

    # MLP: 2 * (H * 4H) + 2 * (4H * H) = 16H²
    mlp_flops = 8 * H * H + 8 * H * H  # 16H²

    # Per-layer total
    layer_flops = attn_flops + mlp_flops

    # All layers
    total_flops = L * layer_flops

    # Embedding and output projection: 2 * V * H
    embedding_flops = 2 * V * H

    # Total per token
    return total_flops + embedding_flops

def measure_mfu(model, config: ModelConfig, num_warmup=5, num_measure=20):
    """Measure Model FLOP Utilization."""
    device = next(model.parameters()).device

    # Theoretical FLOPs per step
    flops_per_token = count_flops_per_token(config)
    tokens_per_step = config.batch_size * config.seq_len

    # Forward: F, Backward: 2F (gradient for weights and activations)
    flops_per_step = 6 * flops_per_token * tokens_per_step  # 6 = 1 fwd + 2 bwd

    # Peak FLOPs for GPU
    # H100: 989 TFLOP/s (dense FP16/BF16)
    # A100: 312 TFLOP/s (FP16 Tensor Core)
    gpu_name = torch.cuda.get_device_name(device)
    if 'H100' in gpu_name:
        peak_flops = 989e12
    elif 'A100' in gpu_name:
        peak_flops = 312e12
    else:
        peak_flops = 100e12  # Conservative estimate

    # Create dummy input
    input_ids = torch.randint(0, config.vocab_size,
                              (config.batch_size, config.seq_len),
                              device=device)

    # Warmup
    for _ in range(num_warmup):
        output = model(input_ids)
        loss = output.sum()
        loss.backward()
        model.zero_grad()

    torch.cuda.synchronize()

    # Measure
    start = time.time()
    for _ in range(num_measure):
        output = model(input_ids)
        loss = output.sum()
        loss.backward()
        model.zero_grad()
        torch.cuda.synchronize()
    elapsed = time.time() - start

    # Calculate MFU
    time_per_step = elapsed / num_measure
    achieved_flops = flops_per_step / time_per_step
    mfu = achieved_flops / peak_flops

    return {
        'time_per_step_ms': time_per_step * 1000,
        'flops_per_step': flops_per_step,
        'achieved_tflops': achieved_flops / 1e12,
        'peak_tflops': peak_flops / 1e12,
        'mfu': mfu
    }

# Example usage
def run_mfu_measurement():
    from transformers import AutoModelForCausalLM, AutoConfig

    config = ModelConfig(
        hidden_dim=4096,
        num_layers=32,
        num_heads=32,
        vocab_size=32000,
        seq_len=2048,
        batch_size=4
    )

    model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = AutoModelForCausalLM.from_config(model_config).cuda().half()

    results = measure_mfu(model, config)

    print(f"Time per step: {results['time_per_step_ms']:.2f} ms")
    print(f"Achieved: {results['achieved_tflops']:.1f} TFLOP/s")
    print(f"Peak: {results['peak_tflops']:.1f} TFLOP/s")
    print(f"MFU: {results['mfu']:.1%}")

# run_mfu_measurement()

Expected results (7B model, H100):

Metric Value
Time per step ~180 ms
Achieved ~450 TFLOP/s
Peak 989 TFLOP/s
MFU ~23%

Note: Single-GPU MFU is typically lower (20-30%) than multi-GPU with tensor parallelism (40-50%) due to memory bandwidth limitations.

\[\boxed{\text{Typical single-GPU MFU: 20-30\%; optimized multi-GPU: 40-50\%}}\]
  1. Alpha-Beta Calibration: Measure α and β for your cluster. Use multiple message sizes (1KB to 1GB). Plot predicted vs actual times. How accurate is the linear model?
Solution

Alpha-beta calibration implementation:

import torch
import torch.distributed as dist
import numpy as np
import time
from scipy import stats

def calibrate_alpha_beta(process_group=None, sizes=None, num_trials=50):
    """Calibrate alpha-beta model for collective communication."""

    if sizes is None:
        # Log-spaced sizes from 1KB to 1GB
        sizes = [int(s) for s in np.logspace(10, 30, 20, base=2)]  # 1KB to 1GB

    device = torch.device('cuda')
    results = []

    for size in sizes:
        # Create buffer
        tensor = torch.ones(size // 4, dtype=torch.float32, device=device)

        # Warmup
        for _ in range(5):
            dist.all_reduce(tensor, group=process_group)
        torch.cuda.synchronize()

        # Measure
        times = []
        for _ in range(num_trials):
            torch.cuda.synchronize()
            start = time.time()
            dist.all_reduce(tensor, group=process_group)
            torch.cuda.synchronize()
            elapsed = time.time() - start
            times.append(elapsed)

        median_time = np.median(times)
        results.append({
            'size': size,
            'time': median_time,
            'time_std': np.std(times)
        })

    # Fit linear model: T = alpha + size / beta
    sizes_arr = np.array([r['size'] for r in results])
    times_arr = np.array([r['time'] for r in results])

    # Linear regression: T = alpha + size/beta
    # Rewrite as: T = alpha + (1/beta) * size
    slope, intercept, r_value, _, _ = stats.linregress(sizes_arr, times_arr)

    alpha = intercept  # Latency (seconds)
    beta = 1 / slope    # Bandwidth (bytes/second)

    # Compute R² and prediction accuracy
    predicted = alpha + sizes_arr / beta
    relative_errors = np.abs(predicted - times_arr) / times_arr

    return {
        'alpha_us': alpha * 1e6,  # Convert to microseconds
        'beta_gbps': beta * 8 / 1e9,  # Convert to Gbps
        'r_squared': r_value ** 2,
        'mean_relative_error': np.mean(relative_errors),
        'max_relative_error': np.max(relative_errors),
        'measurements': results
    }

def plot_calibration(results):
    """Plot predicted vs actual times."""
    import matplotlib.pyplot as plt

    sizes = [r['size'] for r in results['measurements']]
    times = [r['time'] * 1000 for r in results['measurements']]  # ms

    alpha = results['alpha_us'] / 1000  # ms
    beta = results['beta_gbps'] * 1e9 / 8  # bytes/s
    predicted = [alpha + s / beta * 1000 for s in sizes]

    plt.figure(figsize=(10, 6))
    plt.loglog(sizes, times, 'o-', label='Measured', markersize=8)
    plt.loglog(sizes, predicted, '--', label=f'Predicted (α={results["alpha_us"]:.1f}μs, β={results["beta_gbps"]:.1f}Gbps)')
    plt.xlabel('Message Size (bytes)')
    plt.ylabel('Time (ms)')
    plt.title(f'AllReduce Alpha-Beta Calibration (R²={results["r_squared"]:.4f})')
    plt.legend()
    plt.grid(True, which='both', ls='-', alpha=0.2)
    plt.savefig('alpha_beta_calibration.png', dpi=150)
    print("Saved plot to alpha_beta_calibration.png")

def run_calibration():
    dist.init_process_group('nccl')
    rank = dist.get_rank()

    results = calibrate_alpha_beta()

    if rank == 0:
        print(f"\nAlpha-Beta Calibration Results:")
        print(f"  α (latency): {results['alpha_us']:.1f} μs")
        print(f"  β (bandwidth): {results['beta_gbps']:.1f} Gbps")
        print(f"  R²: {results['r_squared']:.4f}")
        print(f"  Mean error: {results['mean_relative_error']:.1%}")
        print(f"  Max error: {results['max_relative_error']:.1%}")

        # plot_calibration(results)

# run_calibration()

Expected results (8×H100 with NVLink):

Parameter Value
α (latency) 5-15 μs
β (bandwidth) 400-900 Gbps
0.97-0.99
Mean error 5-15%

Linear model accuracy:

The α+size/β model works well for large messages but can have 20-50% error for small messages where: - Kernel launch overhead dominates - Ring algorithm startup costs are significant - NCCL chunking behavior causes non-linear scaling

\[\boxed{\alpha \approx 10\mu\text{s}, \beta \approx 500\text{ Gbps for NVLink; R}^2 > 0.95}\]
  1. Overlap Efficiency: Profile your DDP training with NVTX annotations. Calculate overlap efficiency. What percentage of communication is hidden behind computation?
Solution

NVTX-annotated overlap profiling:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import time

# Try to import NVTX
try:
    import nvtx
    HAS_NVTX = True
except ImportError:
    HAS_NVTX = False
    class nvtx:
        @staticmethod
        def range(name):
            return contextlib.nullcontext()

class OverlapProfiler:
    def __init__(self, model):
        self.model = model
        self.compute_times = []
        self.comm_times = []
        self.total_times = []

    def profile_step(self, input_data, num_trials=10):
        """Profile a training step with NVTX markers."""

        # Measure compute-only time (no DDP)
        self.model.module.zero_grad()
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(num_trials):
            with nvtx.range("forward"):
                output = self.model.module(input_data)
                loss = output.sum()
            with nvtx.range("backward_compute"):
                loss.backward()
            self.model.module.zero_grad()
        torch.cuda.synchronize()
        compute_only = (time.time() - start) / num_trials

        # Measure comm-only time
        # Simulate by doing AllReduce on gradient-sized tensors
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(num_trials):
            with nvtx.range("allreduce"):
                for p in self.model.parameters():
                    if p.requires_grad:
                        fake_grad = torch.ones_like(p)
                        dist.all_reduce(fake_grad)
        torch.cuda.synchronize()
        comm_only = (time.time() - start) / num_trials

        # Measure total time with DDP overlap
        self.model.zero_grad()
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(num_trials):
            with nvtx.range("ddp_step"):
                with nvtx.range("forward"):
                    output = self.model(input_data)
                    loss = output.sum()
                with nvtx.range("backward_with_comm"):
                    loss.backward()
            self.model.zero_grad()
        torch.cuda.synchronize()
        total_time = (time.time() - start) / num_trials

        return {
            'compute_only_ms': compute_only * 1000,
            'comm_only_ms': comm_only * 1000,
            'total_time_ms': total_time * 1000,
            'sequential_ms': (compute_only + comm_only) * 1000,
            'overlap_efficiency': self._calc_overlap_efficiency(
                compute_only, comm_only, total_time
            )
        }

    def _calc_overlap_efficiency(self, compute, comm, total):
        """
        Overlap efficiency = time saved / potential savings

        If total = max(compute, comm): perfect overlap (100%)
        If total = compute + comm: no overlap (0%)
        """
        sequential = compute + comm
        best_case = max(compute, comm)
        potential_savings = sequential - best_case
        actual_savings = sequential - total

        if potential_savings <= 0:
            return 1.0  # Already at best case

        return actual_savings / potential_savings

def run_overlap_profiling():
    dist.init_process_group('nccl')
    rank = dist.get_rank()

    # Create model
    model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(20)]).cuda()
    model = DDP(model)

    profiler = OverlapProfiler(model)
    input_data = torch.randn(32, 4096).cuda()

    results = profiler.profile_step(input_data)

    if rank == 0:
        print(f"\nOverlap Profiling Results:")
        print(f"  Compute only: {results['compute_only_ms']:.2f} ms")
        print(f"  Comm only: {results['comm_only_ms']:.2f} ms")
        print(f"  Sequential (no overlap): {results['sequential_ms']:.2f} ms")
        print(f"  Actual total: {results['total_time_ms']:.2f} ms")
        print(f"  Overlap efficiency: {results['overlap_efficiency']:.1%}")

        hidden_comm = results['comm_only_ms'] - (results['total_time_ms'] - results['compute_only_ms'])
        print(f"  Communication hidden: {hidden_comm:.2f} ms ({hidden_comm/results['comm_only_ms']:.1%})")

# run_overlap_profiling()

Expected results:

Metric Value
Compute only ~15 ms
Comm only ~20 ms
Sequential ~35 ms
Actual total ~24 ms
Overlap efficiency ~79%
Comm hidden ~11 ms (55%)

Interpretation:

  • 79% overlap efficiency means we save 79% of the potential savings
  • 55% of communication time is hidden behind computation
  • The remaining 45% is exposed communication (critical path)
\[\boxed{\text{Typical DDP overlap efficiency: 50-80\%; 40-70\% of comm hidden}}\]
  1. Straggler Detection: Run training on 8 GPUs. Artificially slow one GPU (insert sleep). Measure the impact on overall throughput. Implement automatic straggler detection.
Solution

Straggler detection implementation:

import torch
import torch.distributed as dist
import time
import numpy as np

class StragglerDetector:
    def __init__(self, window_size=100, threshold_std=2.0):
        self.window_size = window_size
        self.threshold_std = threshold_std
        self.step_times = []
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

    def record_step(self, step_time):
        """Record step time and check for stragglers."""
        self.step_times.append(step_time)
        if len(self.step_times) > self.window_size:
            self.step_times.pop(0)

    def gather_times(self, local_time):
        """Gather step times from all ranks."""
        times_tensor = torch.tensor([local_time], device='cuda')
        all_times = [torch.zeros(1, device='cuda') for _ in range(self.world_size)]
        dist.all_gather(all_times, times_tensor)
        return [t.item() for t in all_times]

    def detect_straggler(self, all_times):
        """Detect if any rank is a straggler."""
        times = np.array(all_times)
        mean_time = np.mean(times)
        std_time = np.std(times)

        if std_time < 1e-6:  # All times identical
            return None, {}

        stragglers = []
        for rank, t in enumerate(times):
            z_score = (t - mean_time) / std_time
            if z_score > self.threshold_std:
                stragglers.append({
                    'rank': rank,
                    'time': t,
                    'z_score': z_score,
                    'slowdown': t / mean_time
                })

        return stragglers, {
            'mean': mean_time,
            'std': std_time,
            'max': np.max(times),
            'min': np.min(times),
            'spread': np.max(times) / np.min(times)
        }

def simulate_straggler(model, straggler_rank=3, slowdown_ms=50):
    """Simulate a straggler by adding artificial delay."""
    rank = dist.get_rank()

    detector = StragglerDetector()
    input_data = torch.randn(32, 4096).cuda()

    throughputs = {'normal': [], 'straggler': []}

    for phase in ['normal', 'straggler']:
        for step in range(50):
            torch.cuda.synchronize()
            start = time.time()

            # Forward + backward
            output = model(input_data)
            loss = output.sum()
            loss.backward()

            # Simulate straggler
            if phase == 'straggler' and rank == straggler_rank:
                time.sleep(slowdown_ms / 1000)

            # AllReduce (synchronization point)
            for p in model.parameters():
                if p.grad is not None:
                    dist.all_reduce(p.grad)

            torch.cuda.synchronize()
            step_time = time.time() - start

            # Gather and analyze
            all_times = detector.gather_times(step_time)
            stragglers, stats = detector.detect_straggler(all_times)

            if rank == 0:
                throughputs[phase].append(1000 / stats['max'])  # samples/sec

                if stragglers and step % 10 == 0:
                    print(f"Step {step}: Straggler detected!")
                    for s in stragglers:
                        print(f"  Rank {s['rank']}: {s['time']*1000:.1f}ms "
                              f"(z={s['z_score']:.1f}, {s['slowdown']:.2f}x slower)")

            model.zero_grad()

    if rank == 0:
        normal_throughput = np.mean(throughputs['normal'])
        straggler_throughput = np.mean(throughputs['straggler'])
        impact = (normal_throughput - straggler_throughput) / normal_throughput

        print(f"\n=== Straggler Impact Analysis ===")
        print(f"Normal throughput: {normal_throughput:.1f} samples/sec")
        print(f"With straggler: {straggler_throughput:.1f} samples/sec")
        print(f"Throughput loss: {impact:.1%}")
        print(f"Slowdown factor: {normal_throughput/straggler_throughput:.2f}x")

def run_straggler_detection():
    dist.init_process_group('nccl')

    model = torch.nn.Sequential(
        *[torch.nn.Linear(4096, 4096) for _ in range(10)]
    ).cuda()

    simulate_straggler(model, straggler_rank=3, slowdown_ms=50)

# run_straggler_detection()

Expected results (8 GPUs, 50ms artificial delay on rank 3):

Metric Normal With Straggler
Mean step time 15 ms 65 ms
Throughput 66.7 samples/s 15.4 samples/s
Throughput loss - 77%

Key insight: A single slow GPU slows down the entire training because: - AllReduce requires all participants - Collective operations are synchronous - The slowest rank determines the step time

Mitigation strategies: 1. Load balancing across nodes 2. Excluding persistent stragglers 3. Asynchronous SGD (with convergence trade-offs)

\[\boxed{\text{50ms straggler on 1/8 GPUs causes } \sim77\% \text{ throughput loss}}\]
  1. Bucket Optimization: Profile DDP with different bucket sizes (1MB, 25MB, 100MB, 250MB). Measure step time for each. What's the optimal bucket size for your model?
Solution

Bucket size optimization:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import time

def benchmark_bucket_size(model_fn, bucket_sizes_mb, input_shape, num_warmup=10, num_measure=50):
    """Benchmark DDP with different bucket sizes."""

    results = []
    device = torch.device('cuda')

    for bucket_mb in bucket_sizes_mb:
        bucket_bytes = bucket_mb * 1024 * 1024

        # Create fresh model for each test
        model = model_fn().to(device)
        model = DDP(model, bucket_cap_mb=bucket_mb)

        input_data = torch.randn(*input_shape, device=device)

        # Warmup
        for _ in range(num_warmup):
            output = model(input_data)
            loss = output.sum()
            loss.backward()
            model.zero_grad()

        torch.cuda.synchronize()

        # Measure
        times = []
        for _ in range(num_measure):
            torch.cuda.synchronize()
            start = time.time()
            output = model(input_data)
            loss = output.sum()
            loss.backward()
            torch.cuda.synchronize()
            times.append(time.time() - start)
            model.zero_grad()

        median_time = torch.tensor(times).median().item()
        std_time = torch.tensor(times).std().item()

        results.append({
            'bucket_mb': bucket_mb,
            'median_ms': median_time * 1000,
            'std_ms': std_time * 1000
        })

        # Clean up
        del model
        torch.cuda.empty_cache()

    return results

def run_bucket_optimization():
    dist.init_process_group('nccl')
    rank = dist.get_rank()

    def create_model():
        return nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(20)])

    bucket_sizes = [1, 5, 10, 25, 50, 100, 250]
    input_shape = (32, 4096)

    results = benchmark_bucket_size(create_model, bucket_sizes, input_shape)

    if rank == 0:
        print("\nBucket Size Optimization Results:")
        print("-" * 45)
        print(f"{'Bucket Size':>12} {'Median Time':>12} {'Std Dev':>10}")
        print("-" * 45)

        best_result = min(results, key=lambda x: x['median_ms'])

        for r in results:
            marker = " *" if r['bucket_mb'] == best_result['bucket_mb'] else ""
            print(f"{r['bucket_mb']:>10} MB {r['median_ms']:>10.2f} ms {r['std_ms']:>8.2f} ms{marker}")

        print("-" * 45)
        print(f"Optimal bucket size: {best_result['bucket_mb']} MB")

# run_bucket_optimization()

Expected results (20-layer MLP, 8 GPUs):

Bucket Size Step Time Notes
1 MB 28.5 ms High latency overhead
5 MB 22.1 ms
10 MB 19.8 ms
25 MB 18.2 ms Optimal
50 MB 18.5 ms
100 MB 19.1 ms Less overlap
250 MB 21.3 ms Poor overlap

Analysis:

  • Too small (1-5 MB): High latency overhead per bucket
  • Optimal (10-50 MB): Good balance of latency amortization and overlap opportunity
  • Too large (100+ MB): Less granular overlap, buckets complete after compute

The trade-off: - Smaller buckets → more AllReduce calls → higher latency overhead - Larger buckets → less overlap with computation

\[\boxed{\text{Optimal bucket size typically 10-50 MB; default 25 MB is reasonable}}\]
  1. Memory Profiling: Track memory usage through a training step. Identify the peak allocation point. What consumes the most memory: parameters, gradients, activations, or optimizer state?
Solution

Memory profiling implementation:

import torch
import torch.nn as nn
from torch.cuda import memory_allocated, max_memory_allocated, reset_peak_memory_stats

class MemoryProfiler:
    def __init__(self):
        self.checkpoints = []

    def checkpoint(self, name):
        """Record current memory usage."""
        torch.cuda.synchronize()
        self.checkpoints.append({
            'name': name,
            'allocated_mb': memory_allocated() / 1e6,
            'peak_mb': max_memory_allocated() / 1e6
        })

    def reset(self):
        self.checkpoints = []
        reset_peak_memory_stats()

    def report(self):
        """Print memory usage report."""
        print("\n=== Memory Profile ===")
        print(f"{'Checkpoint':<30} {'Allocated':>12} {'Peak':>12} {'Delta':>12}")
        print("-" * 70)

        prev_alloc = 0
        for cp in self.checkpoints:
            delta = cp['allocated_mb'] - prev_alloc
            print(f"{cp['name']:<30} {cp['allocated_mb']:>10.1f} MB {cp['peak_mb']:>10.1f} MB {delta:>+10.1f} MB")
            prev_alloc = cp['allocated_mb']

        print("-" * 70)
        print(f"{'Peak memory':>30} {max(c['peak_mb'] for c in self.checkpoints):>10.1f} MB")

def profile_training_step(model, optimizer, input_data, target):
    """Profile memory through a complete training step."""
    profiler = MemoryProfiler()
    profiler.reset()

    profiler.checkpoint("Initial")

    # Model parameters
    model = model.cuda()
    profiler.checkpoint("After model.cuda()")

    # Optimizer state (allocates momentum, variance buffers on first step)
    # We'll trigger this by doing one dummy step
    dummy_input = torch.randn_like(input_data).cuda()
    dummy_output = model(dummy_input)
    dummy_output.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    profiler.checkpoint("After optimizer init")

    # Fresh start for actual measurement
    torch.cuda.empty_cache()
    profiler.checkpoint("After cache clear")

    # Forward pass - activations allocated
    input_data = input_data.cuda()
    profiler.checkpoint("Input on GPU")

    output = model(input_data)
    profiler.checkpoint("After forward (activations)")

    # Loss computation
    loss = output.sum()
    profiler.checkpoint("After loss")

    # Backward pass - gradients allocated
    loss.backward()
    profiler.checkpoint("After backward (gradients)")

    # Optimizer step
    optimizer.step()
    profiler.checkpoint("After optimizer step")

    # Zero gradients
    optimizer.zero_grad()
    profiler.checkpoint("After zero_grad")

    profiler.report()

    return profiler.checkpoints

def analyze_memory_breakdown(model_params, batch_size, seq_len, hidden_dim, num_layers):
    """Theoretical memory breakdown."""
    bytes_per_param = 4  # FP32

    # Parameters
    param_memory = model_params * bytes_per_param

    # Gradients (same size as parameters)
    gradient_memory = model_params * bytes_per_param

    # Optimizer state (Adam: 2 states per parameter)
    optimizer_memory = model_params * bytes_per_param * 2

    # Activations (rough estimate for transformer)
    # Per layer: ~34 * B * S * H bytes
    activation_memory = num_layers * 34 * batch_size * seq_len * hidden_dim

    total = param_memory + gradient_memory + optimizer_memory + activation_memory

    breakdown = {
        'parameters': param_memory / 1e9,
        'gradients': gradient_memory / 1e9,
        'optimizer': optimizer_memory / 1e9,
        'activations': activation_memory / 1e9,
        'total': total / 1e9
    }

    print("\n=== Theoretical Memory Breakdown ===")
    print(f"{'Component':<20} {'Memory (GB)':>12} {'Percentage':>12}")
    print("-" * 50)
    for name, mem in breakdown.items():
        if name != 'total':
            pct = mem / breakdown['total'] * 100
            print(f"{name:<20} {mem:>10.2f} GB {pct:>10.1f}%")
    print("-" * 50)
    print(f"{'TOTAL':<20} {breakdown['total']:>10.2f} GB")

    return breakdown

def run_memory_profiling():
    # 7B parameter model approximation
    model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(32)])

    optimizer = torch.optim.Adam(model.parameters())
    input_data = torch.randn(32, 4096)
    target = torch.randn(32, 4096)

    profile_training_step(model, optimizer, input_data, target)

    # Theoretical breakdown for comparison
    num_params = sum(p.numel() for p in model.parameters())
    analyze_memory_breakdown(
        model_params=num_params,
        batch_size=32,
        seq_len=2048,
        hidden_dim=4096,
        num_layers=32
    )

# run_memory_profiling()

Expected memory breakdown (7B model, batch=4, seq=2048):

Component Memory Percentage
Parameters 28 GB 16%
Gradients 28 GB 16%
Optimizer (Adam) 56 GB 32%
Activations 64 GB 36%
TOTAL 176 GB 100%

Peak occurs during backward pass when both: - All activations are still needed for gradient computation - Gradients are being allocated

Key insights: 1. Optimizer state dominates static memory (48% of non-activation) 2. Activations dominate dynamic memory and scale with batch size 3. Peak memory ≈ params + grads + optimizer + activations

\[\boxed{\text{Activations: 36\%, Optimizer: 32\%, Params+Grads: 32\%}}\]

Key Takeaways

  1. Measure, don't guess: Profiling transforms intuition into actionable data.

  2. Multiple tools for multiple purposes: Nsight for detailed traces, PyTorch Profiler for quick checks, NCCL debug for communication analysis.

  3. MFU is the ultimate metric: It captures how well you're using your hardware.

  4. The alpha-beta model predicts communication: Calibrate it for your cluster to understand bottlenecks.

  5. Overlap efficiency matters: The gap between compute+comm and max(compute, comm) is your opportunity.

  6. Stragglers kill scaling: One slow GPU affects all GPUs in synchronous training.

  7. Profile iteratively: Start broad, then zoom in on bottlenecks.