22  Mixture of Experts

Scaling with Conditional Computation

GPT-4 reportedly has over a trillion parameters, yet it runs at reasonable speed. The reported explanation: most parameters don’t activate for any given token.

Mixture of Experts achieves massive model capacity with sublinear compute cost.

Note: Property Spotlight: Sparsity

This chapter is a case study in sparsity—the fourth property from our Algebraic Framework.

In MoE [1], the activation pattern is sparse: only k of N experts activate for each token. This isn’t an approximation—it’s the architecture itself. By designing for sparsity, MoE achieves massive capacity with constant compute.

Most scaling techniques accept the computation and try to do it faster. MoE asks: what if we simply skip most of the computation?

22.1 The Scaling Dilemma

Dense transformers face a fundamental tradeoff:

Model Size vs. Inference Cost

Dense model: Every parameter participates in every token's forward pass
  - 70B parameters → ~140 GFLOPs per token (about 2 FLOPs per parameter)
  - 175B parameters → ~350 GFLOPs per token

Scaling laws: Larger models are better, but per-token compute grows linearly with parameter count.

What if we could have the capacity of a trillion-parameter model with the compute cost of a 70B model?

22.2 The MoE Insight

Mixture of Experts (MoE) replaces dense feedforward layers with a collection of “expert” networks, activating only a subset for each token.

Dense FFN:              MoE Layer:

Input ──→ [FFN] ──→     Input ──→ [Router] ──→ Select top-k experts
                                      │
                              ┌───────┼───────┐
                              ↓       ↓       ↓
                           [E1]    [E2]    [E3] ... [En]
                              │       │       │
                              └───────┼───────┘
                                      ↓
                                   Combine
                                      ↓
                                   Output

Key parameters:

  • N: Total number of experts (e.g., 8, 64, or 128)
  • k: Experts activated per token (typically 1 or 2)
  • Capacity factor: How many tokens each expert can handle, relative to a perfectly even split

22.3 The Mathematics

22.3.1 Router (Gating Network)

The router decides which experts handle each token:

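\[G(x) = \text{softmax}(\text{TopK}(W_g \cdot x, k))\]

For input \(x\) with dimension \(d\):

  • \(W_g \in \mathbb{R}^{N \times d}\) produces a score (logit) for each expert
  • TopK keeps the \(k\) highest-scoring experts and discards the rest
  • Softmax over the \(k\) kept scores yields routing weights that sum to 1

Some models (e.g., Switch Transformer) instead apply softmax over all \(N\) experts first and then select the top \(k\); the code in this chapter uses the top-k-then-softmax order.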

import torch
import torch.nn.functional as F

def router(x, W_g, k=2):
    """
    Route tokens to top-k experts.

    x: [batch, seq, d_model]
    W_g: [num_experts, d_model]
    Returns: expert_indices [batch, seq, k], weights [batch, seq, k]
    """
    # Compute routing scores
    scores = x @ W_g.T  # [batch, seq, num_experts]

    # Select top-k experts
    weights, indices = torch.topk(scores, k, dim=-1)
    weights = F.softmax(weights, dim=-1)  # Normalize selected weights

    return indices, weights
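To make the top-k-then-softmax step concrete, here is a dependency-free sketch for a single token. The scores are made-up illustrative values, and `route_one_token` is a hypothetical helper written for this example, not part of any library.

```python
import math

def route_one_token(scores, k=2):
    """Pick the k largest scores and softmax over just those k."""
    # Indices of the k highest-scoring experts, best first
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Softmax restricted to the selected scores (subtract max for stability)
    m = max(scores[i] for i in top)
    exps = [math.exp(scores[i] - m) for i in top]
    total = sum(exps)
    return top, [e / total for e in exps]

# Illustrative router scores for one token over 4 experts
indices, weights = route_one_token([1.0, 3.0, 0.5, 2.0], k=2)
print(indices, [round(w, 3) for w in weights])
```

Experts 1 and 3 are selected, and their two weights sum to 1 because the softmax only sees the kept scores.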

22.3.2 Expert Computation

Each expert is typically a standard FFN:

\[E_i(x) = W_2^{(i)} \cdot \text{activation}(W_1^{(i)} \cdot x)\]

The final output combines selected experts:

\[y = \sum_{i \in \text{TopK}} G(x)_i \cdot E_i(x)\]

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router
        self.gate = nn.Linear(d_model, num_experts, bias=False)

        # Experts (each is a standard FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch, seq, d = x.shape

        # Route tokens
        scores = self.gate(x)  # [batch, seq, num_experts]
        weights, indices = torch.topk(scores, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)

        # Compute expert outputs (naive loop - see optimization below)
        output = torch.zeros_like(x)
        for i in range(self.top_k):
            expert_idx = indices[:, :, i]  # [batch, seq]
            expert_weight = weights[:, :, i:i+1]  # [batch, seq, 1]

            for e in range(self.num_experts):
                mask = (expert_idx == e)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += expert_weight[mask] * expert_output

        return output

22.4 Why MoE Works: The Capacity-Compute Tradeoff

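Mixtral 8x7B:
  - Total parameters: 46.7B (not 8 × 7B = 56B: only the FFN experts are replicated; attention and embeddings are shared)
  - Active parameters: ~12.9B per token (2 experts active)
  - Benchmark quality: Reported as comparable to or better than Llama 2 70B on most benchmarks
  - Compute cost: Similar to a ~13B dense model

Result: roughly 70B-class quality at ~13B-class compute per token

The intuition is specialization: different experts learn different functions, and the router learns to dispatch each token to the ones that fit. It is tempting to picture one expert for math, one for code, and one for natural language, but the Mixtral paper's routing analysis found no obvious topic-level specialization; expert choice followed syntactic and token-level patterns instead.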

22.5 The Load Balancing Problem

22.5.1 Why Balance Matters

Naive routing often collapses—a few experts handle most tokens while others idle:

Unbalanced routing:

Expert 0: ████████████████████████████████ 80%
Expert 1: ███ 5%
Expert 2: ██ 3%
Expert 3: █ 2%
Expert 4: ██ 3%
Expert 5: █ 2%
Expert 6: ██ 3%
Expert 7: █ 2%

Problem:
- Expert 0 is a bottleneck
- Experts 1-7 waste capacity
- Effectively a dense model with extra overhead

22.5.2 Auxiliary Load Balancing Loss

The standard solution: add a loss term encouraging balanced routing.

import torch

def load_balancing_loss(router_probs, expert_indices, num_experts):
    """
    Encourage uniform expert utilization.

    router_probs: [batch, seq, num_experts] - softmax probabilities
    expert_indices: [batch, seq, k] - selected expert indices
    """
    # Fraction of routing slots assigned to each expert
    # (based on hard assignments; not differentiable)
    counts = torch.bincount(expert_indices.flatten(), minlength=num_experts)
    tokens_per_expert = counts.float() / expert_indices.numel()

    # Average routing probability to each expert
    # (based on soft probabilities; this is where gradients flow)
    prob_per_expert = router_probs.mean(dim=[0, 1])

    # Load balancing loss: scaled dot product of the two
    # Equals 1.0 when both are uniform (1/num_experts); grows with imbalance
    lb_loss = num_experts * (tokens_per_expert * prob_per_expert).sum()

    return lb_loss

# In training:
# total_loss = task_loss + 0.01 * load_balancing_loss(...)
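A quick numeric check shows what the loss rewards. The sketch below applies the same formula to a uniform split and to the collapsed split from the bar chart in 22.5.1. It assumes, for simplicity, that the router's average probabilities equal its hard assignment fractions; `lb_loss` is a hypothetical pure-Python stand-in for the function above.

```python
def lb_loss(token_fraction, mean_prob):
    """num_experts * sum_i f_i * P_i, the same formula as load_balancing_loss."""
    n = len(token_fraction)
    return n * sum(f * p for f, p in zip(token_fraction, mean_prob))

uniform = [1 / 8] * 8
collapsed = [0.80, 0.05, 0.03, 0.02, 0.03, 0.02, 0.03, 0.02]

# Assumption: soft probabilities match the hard assignment fractions
loss_uniform = lb_loss(uniform, uniform)
loss_collapsed = lb_loss(collapsed, collapsed)
print(round(loss_uniform, 4), round(loss_collapsed, 4))
```

Under that assumption the uniform router scores 1.0 and the collapsed router scores about 5.17, so the auxiliary term pushes the router away from collapse.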

22.5.3 Capacity Factor

Another approach: limit how many tokens each expert can process.

import math
import torch

def route_with_capacity(x, router_weights, capacity_factor=1.25):
    """
    Top-1 routing with expert capacity limits.

    x: [batch, seq, d_model] (not needed for the assignment itself)
    router_weights: [batch, seq, num_experts] routing scores
    capacity_factor: How much overflow to allow (1.0 = exact, 1.25 = 25% buffer)
    Returns: assignments [batch, seq], expert index or -1 if dropped
    """
    batch, seq, num_experts = router_weights.shape
    num_tokens = batch * seq

    # Max tokens per expert, counted over the whole batch
    capacity = math.ceil(capacity_factor * num_tokens / num_experts)

    # Each token's preferred (top-1) expert
    preferred = router_weights.argmax(dim=-1).flatten().tolist()

    # Track how many tokens are assigned to each expert
    expert_counts = [0] * num_experts

    # Assign tokens in order; a token whose expert is full is dropped
    assignments = torch.full((num_tokens,), -1, dtype=torch.long)  # -1 = dropped

    for t, expert in enumerate(preferred):
        if expert_counts[expert] < capacity:
            assignments[t] = expert
            expert_counts[expert] += 1

    return assignments.view(batch, seq)

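Dropped tokens are a problem: they skip the expert computation entirely (only the residual connection carries them forward), so they contribute no gradient to any expert. Solutions:

  • Higher capacity factor (wastes compute)
  • Token dropping auxiliary loss
  • Expert choice routing (experts select tokens, not vice versa)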
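To get a feel for how much a capacity limit costs under imbalance, the following sketch counts dropped tokens for top-1 routing. It is a simplified model: `drop_fraction` is a hypothetical helper, tokens are processed in order, and a token is dropped as soon as its preferred expert is full.

```python
import math

def drop_fraction(preferred_experts, num_experts, capacity_factor=1.25):
    """Fraction of tokens dropped under top-1 routing with a per-expert cap."""
    capacity = math.ceil(capacity_factor * len(preferred_experts) / num_experts)
    counts = [0] * num_experts
    dropped = 0
    for e in preferred_experts:
        if counts[e] < capacity:
            counts[e] += 1
        else:
            dropped += 1  # expert full: token skips the MoE layer
    return dropped / len(preferred_experts)

# 100 tokens following the unbalanced split from 22.5.1 (80% prefer expert 0)
split = [80, 5, 3, 2, 3, 2, 3, 2]
prefs = [e for e, n in enumerate(split) for _ in range(n)]
frac_skewed = drop_fraction(prefs, num_experts=8)
frac_balanced = drop_fraction([t % 8 for t in range(100)], num_experts=8)
print(frac_skewed, frac_balanced)
```

With a capacity of 16 tokens per expert, the skewed router drops 64 of the 80 tokens that prefer expert 0, while the balanced router drops none. This is why capacity limits are normally paired with a balancing loss rather than used alone.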

22.6 Expert Parallelism

MoE introduces a new parallelism dimension: distributing experts across GPUs.

22.6.1 The All-to-All Pattern

Before All-to-All:                After All-to-All:

GPU 0: [E0, E1] has tokens        GPU 0: [E0, E1] has ALL tokens
       for all experts                   destined for E0, E1

GPU 1: [E2, E3] has tokens        GPU 1: [E2, E3] has ALL tokens
       for all experts                   destined for E2, E3

GPU 2: [E4, E5] has tokens        GPU 2: [E4, E5] has ALL tokens
       for all experts                   destined for E4, E5

GPU 3: [E6, E7] has tokens        GPU 3: [E6, E7] has ALL tokens
       for all experts                   destined for E6, E7

Communication pattern: Each GPU sends tokens to all other GPUs based on routing decisions.

def expert_parallel_forward(x, router, experts, world_size, rank):
    """
    MoE forward with expert parallelism (pseudocode).

    Each GPU holds a subset of experts; `experts` is this rank's local list.
    `all_to_all` and `combine_expert_outputs` are placeholder helpers.
    """
    experts_per_gpu = len(experts)
    tokens = x.flatten(0, 1)  # [batch * seq, d_model]

    # 1. Compute routing on each GPU
    indices, weights = router(x)  # Each GPU routes its tokens
    k = indices.shape[-1]

    # 2. All-to-all: send tokens to GPUs holding their experts
    # tokens_to_send[dst_rank] = tokens that need experts on dst_rank
    tokens_to_send = [[] for _ in range(world_size)]

    for slot, expert_idx in enumerate(indices.flatten().tolist()):
        token_idx = slot // k  # each token occupies k consecutive slots
        dst_rank = expert_idx // experts_per_gpu
        expert_local_idx = expert_idx % experts_per_gpu
        tokens_to_send[dst_rank].append((slot, expert_local_idx, tokens[token_idx]))

    # Execute all-to-all (NCCL)
    received_tokens = all_to_all(tokens_to_send)

    # 3. Process tokens through local experts
    local_outputs = []
    for slot, expert_local_idx, token in received_tokens:
        output = experts[expert_local_idx](token)
        local_outputs.append((slot, output))

    # 4. All-to-all: return outputs to original GPUs
    returned_outputs = all_to_all(local_outputs)

    # 5. Combine with routing weights
    output = combine_expert_outputs(returned_outputs, weights)

    return output
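Before any real all-to-all can run, each rank has to know how many items it sends to every other rank. The sketch below computes those per-destination counts from routing decisions. The routing list and `send_counts` are illustrative assumptions; in PyTorch, counts like these would typically feed the `input_split_sizes` argument of `torch.distributed.all_to_all_single`.

```python
def send_counts(expert_indices, experts_per_gpu, world_size):
    """Count how many (token, expert) pairs this rank sends to each rank."""
    counts = [0] * world_size
    for expert_idx in expert_indices:
        counts[expert_idx // experts_per_gpu] += 1
    return counts

# Hypothetical routing: 6 tokens, top-2 (12 expert slots), 8 experts on 4 GPUs
routing = [0, 1, 0, 5, 2, 3, 7, 6, 1, 4, 0, 2]
counts = send_counts(routing, experts_per_gpu=2, world_size=4)
print(counts)
```

Uneven counts like these are the distributed face of load imbalance: the rank that receives the most tokens sets the pace for the whole layer.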

22.6.2 Communication Cost Analysis

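All-to-all volume per GPU stays roughly flat as GPUs are added, but the number of peers each GPU must exchange with grows:

N GPUs, B tokens per GPU, D dimensions, uniform top-1 routing:

All-to-all sends: Each GPU sends B*D/N elements to each other GPU
Total sent per GPU: B * D * (N-1) / N ≈ B * D elements

For 8 GPUs, 4096 tokens, 4096 dimensions, FP16:
  ≈ 4096 * 4096 * 2 bytes = 32 MiB per GPU (28 MiB after the (N-1)/N factor)

At 200 GB/s interconnect: ≈ 0.17 ms transfer time per all-to-all
With 64 GPUs: still ≈ 32 MiB per GPU, but split into 63 smaller messages,
often over slower inter-node links, with more coordination overhead

Each MoE layer needs two all-to-alls (dispatch and combine), and top-k routing multiplies the volume by k.

Key insight: All-to-all volume scales with the number of routed tokens, not with model size, unlike a gradient all-reduce, whose volume scales with the parameter count. The cost that grows with GPU count is coordination: more, smaller messages and more cross-node traffic.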
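The arithmetic above is easy to turn into a small calculator. This is a bandwidth-only sketch under the stated assumptions (uniform routing, no latency or synchronization cost); both function names are hypothetical.

```python
def all_to_all_bytes_per_gpu(tokens, d_model, num_gpus, bytes_per_elem=2, top_k=1):
    """Bytes one GPU sends in one all-to-all, assuming uniform routing."""
    payload = tokens * top_k * d_model * bytes_per_elem
    # A 1/N share of the payload targets local experts and never leaves the GPU
    return payload * (num_gpus - 1) / num_gpus

def transfer_ms(num_bytes, bandwidth_gb_per_s):
    """Bandwidth-only transfer time in milliseconds (ignores latency and sync)."""
    return num_bytes / (bandwidth_gb_per_s * 1e9) * 1e3

sent_8 = all_to_all_bytes_per_gpu(4096, 4096, num_gpus=8)
sent_64 = all_to_all_bytes_per_gpu(4096, 4096, num_gpus=64)
print(sent_8 / 2**20, sent_64 / 2**20)    # MiB sent per GPU
print(round(transfer_ms(sent_8, 200), 3))  # ms at 200 GB/s
```

Going from 8 to 64 GPUs moves the per-GPU volume only from 28 MiB to 31.5 MiB, which is why the scaling concern is message count and link quality rather than raw bytes.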

22.7 MegaBlocks: Efficient MoE Kernels

The naive MoE implementation is slow due to:

  1. Irregular memory access (tokens scattered to different experts)
  2. Small, variable-sized expert batches
  3. Multiple kernel launches

MegaBlocks solution: Block-sparse operations.

# Conceptual MegaBlocks approach:

def megablocks_moe(x, experts, routing):
    """
    Efficient MoE using block-sparse operations.
    """
    # 1. Sort tokens by expert assignment
    # Now tokens for same expert are contiguous
    sorted_tokens, sort_indices = sort_by_expert(x, routing)

    # 2. Compute expert boundaries
    # expert_boundaries[i] = start index for expert i
    boundaries = compute_boundaries(routing)

    # 3. Block-sparse matmul
    # Single kernel handles all experts with variable batch sizes
    # Uses block-sparse format for efficiency
    output = block_sparse_ffn(sorted_tokens, experts, boundaries)

    # 4. Unsort to restore original order
    output = unsort(output, sort_indices)

    return output

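Performance: The MegaBlocks paper reports end-to-end training speedups of up to 40% over the Tutel MoE library and up to 2.4× over dense models trained with Megatron-LM, achieved by avoiding both token dropping and padding.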
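The grouping steps (sort, boundaries, unsort) can be sketched without any GPU code. The example below covers only that bookkeeping for top-1 routing, not the block-sparse kernel itself; `group_by_expert` and `unsort` are hypothetical helpers that mirror the roles of the functions named in the pseudocode.

```python
def group_by_expert(expert_ids, num_experts):
    """Return the token order sorted by expert, plus per-expert start offsets."""
    # Stable sort keeps tokens of the same expert in their original order
    order = sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    # boundaries[e] .. boundaries[e + 1] is the slice of `order` owned by expert e
    boundaries = [0]
    for c in counts:
        boundaries.append(boundaries[-1] + c)
    return order, boundaries

def unsort(values, order):
    """Scatter values computed in sorted order back to original positions."""
    out = [None] * len(values)
    for pos, t in enumerate(order):
        out[t] = values[pos]
    return out

expert_ids = [2, 0, 1, 0, 2, 2]  # top-1 expert per token (illustrative)
order, boundaries = group_by_expert(expert_ids, num_experts=3)
sorted_ids = [expert_ids[t] for t in order]
restored = unsort(sorted_ids, order)
print(order, boundaries, restored == expert_ids)
```

After sorting, each expert's tokens form one contiguous slice, so a single kernel can walk the boundaries instead of launching one small matmul per expert.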

22.8 MoE Architectures in Practice

22.8.1 Mixtral 8x7B

Architecture:
- 8 experts per MoE layer
- Top-2 routing (2 experts per token)
- 32 MoE layers (every FFN is MoE)
- Attention, embeddings, and norms are dense and shared by all tokens
- Each expert: ~176M parameters per layer (~5.6B across all 32 layers)

Total: 46.7B parameters
Active: 12.9B parameters per token
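The total and active counts can be reproduced from the layer shapes. The sketch below assumes the publicly released Mixtral 8x7B configuration (hidden size 4096, FFN size 14336, 32 attention heads with 8 key/value heads, 32000-token vocabulary, untied embeddings). None of these values appear in this chapter, so treat them as assumptions; `mixtral_like_param_counts` is a hypothetical helper.

```python
def mixtral_like_param_counts(d_model=4096, d_ff=14336, n_layers=32,
                              n_heads=32, n_kv_heads=8, vocab=32000,
                              num_experts=8, top_k=2):
    """Return (total, active-per-token) parameter counts for a Mixtral-style model."""
    head_dim = d_model // n_heads
    # Grouped-query attention: full-size Q and O, smaller K and V projections
    attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim
    expert = 3 * d_model * d_ff  # gated FFN: gate, up, and down matrices
    router = d_model * num_experts
    norms = 2 * d_model
    shared_per_layer = attn + router + norms
    embed = 2 * vocab * d_model + d_model  # input + output embeddings, final norm
    total = n_layers * (shared_per_layer + num_experts * expert) + embed
    active = n_layers * (shared_per_layer + top_k * expert) + embed
    return total, active

total, active = mixtral_like_param_counts()
print(f"total = {total / 1e9:.1f}B, active = {active / 1e9:.1f}B")
```

Nearly all of the parameters sit in the experts, which is why activating 2 of 8 experts cuts the active count to roughly a quarter of the total rather than leaving it close to the full size.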

22.8.2 DeepSeek-MoE

Innovations:
- Fine-grained experts (more, smaller experts)
- Shared experts (some experts always active)
- 64 routed experts + 2 shared experts
- Top-6 routing from routed experts

Result: Better specialization with shared common knowledge
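One rough way to see why finer-grained experts leave more room for specialization is to count the distinct expert combinations a router can pick. This is a back-of-the-envelope illustration only, and it says nothing about model quality by itself.

```python
import math

coarse = math.comb(8, 2)   # top-2 of 8 experts (Mixtral-style layout)
fine = math.comb(64, 6)    # top-6 of 64 routed experts (layout listed above)
print(coarse, fine)
```

The coarse layout offers 28 possible expert pairs per token, while the fine-grained layout offers about 75 million possible combinations.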

22.8.3 Switch Transformer

Simplification:
- Top-1 routing (only 1 expert per token)
- Simpler, faster, but potentially lower quality
- Works well at very large scale

22.9 MoE for Inference

Training MoE is one challenge; serving is another.

22.9.1 The Memory Problem

All experts must be in memory, even though only k activate:

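Mixtral 8x7B inference:
  - Must load: 46.7B parameters
  - Actually use: 12.9B per token

Memory: Same as dense 46.7B model
Compute: Same as dense 12.9B model

Problem: Decoding is usually memory-bandwidth-bound, so it gains less from MoE.
With batching, different tokens pick different experts, and most experts
end up being read on every step.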
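Two small calculations make the memory problem concrete: the weight footprint, and how quickly a batch ends up touching every expert. The second function assumes uniform, independent routing across tokens, which is an idealization; both helper names are hypothetical.

```python
def weight_memory_gb(params, bytes_per_param=2):
    """Memory for weights alone (no KV cache or activations), in GB."""
    return params * bytes_per_param / 1e9

def expected_experts_touched(num_experts, top_k, batch_tokens):
    """Expected distinct experts hit in one layer, assuming uniform,
    independent routing (an idealization)."""
    p_unused = (1 - top_k / num_experts) ** batch_tokens
    return num_experts * (1 - p_unused)

mem_gb = weight_memory_gb(46.7e9)  # FP16 weights
single = expected_experts_touched(8, 2, 1)
batch8 = expected_experts_touched(8, 2, 8)
print(round(mem_gb, 1), single, round(batch8, 2))
```

A single token reads 2 of 8 experts, but a batch of just 8 tokens is expected to touch about 7.2 of them, so batched decoding reads nearly the full 93 GB of FP16 weights on every step.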

22.9.2 Expert Offloading

For memory-constrained serving, offload inactive experts:

from collections import OrderedDict

class OffloadedMoE:
    def __init__(self, router, experts, device='cuda', max_cached=4):
        # Keep router on GPU
        self.router = router.to(device)

        # Keep experts on CPU, move on-demand
        self.experts = [e.to('cpu') for e in experts]
        self.device = device

        # LRU cache of expert ids currently resident on the GPU
        self.max_cached = max_cached
        self.expert_cache = OrderedDict()

    def _load(self, e):
        if e in self.expert_cache:
            self.expert_cache.move_to_end(e)  # Mark as most recently used
            return
        # Evict the least recently used expert to make room
        if len(self.expert_cache) >= self.max_cached:
            evicted, _ = self.expert_cache.popitem(last=False)
            self.experts[evicted] = self.experts[evicted].to('cpu')
        # Move to GPU
        self.experts[e] = self.experts[e].to(self.device)
        self.expert_cache[e] = True

    def forward(self, x):
        # Route
        indices, weights = self.router(x)
        needed_experts = indices.unique().tolist()

        # Load needed experts (assumes len(needed_experts) <= max_cached;
        # otherwise process the experts in chunks)
        for e in needed_experts:
            self._load(e)

        # Compute (moe_forward: placeholder for the dispatch/combine logic)
        return moe_forward(x, self.experts, indices, weights)

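Latency impact: Each expert load costs roughly (expert size) / (PCIe bandwidth). That is a few milliseconds for small or quantized experts, but more than 10 ms for a Mixtral-sized FP16 expert (~350 MB for one layer) over PCIe 4.0 x16. Predictable routing helps (cache warming).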
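The cost of an expert load depends on expert size, precision, and link bandwidth, so it is worth estimating for your own setup. The sketch below is bandwidth-only and rests on assumptions: a Mixtral-style expert of 3 × 4096 × 14336 parameters per layer, and an effective host-to-GPU rate of 25 GB/s. `expert_load_ms` is a hypothetical helper.

```python
def expert_load_ms(params, bytes_per_param, bandwidth_gb_per_s):
    """Bandwidth-only host-to-GPU copy time in milliseconds."""
    return params * bytes_per_param / (bandwidth_gb_per_s * 1e9) * 1e3

# Assumed size: one Mixtral-style expert in one layer
expert_params = 3 * 4096 * 14336
fp16_ms = expert_load_ms(expert_params, 2, bandwidth_gb_per_s=25)
int4_ms = expert_load_ms(expert_params, 0.5, bandwidth_gb_per_s=25)
print(round(fp16_ms, 1), round(int4_ms, 1))
```

Under these assumptions an FP16 expert takes about 14 ms to load and a 4-bit expert about 3.5 ms, which is why offloading setups usually combine quantization with caching or prefetching.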

22.9.3 Expert Parallelism for Inference

Unlike training, inference often needs low latency over high throughput:

Strategy 1: Replicate all experts
  - Each GPU has all experts
  - No all-to-all needed
  - Memory: Full model per GPU

Strategy 2: Expert parallel + tensor parallel
  - Experts sharded across GPUs
  - Attention tensor-parallel within GPU groups
  - Communication: All-to-all for MoE, all-reduce for attention

Strategy 3: Expert caching with speculation
  - Predict likely experts
  - Pre-fetch to GPU before needed
  - Works when routing is predictable

22.10 Training MoE Models

22.10.1 Stability Challenges

MoE training is less stable than dense models:

  1. Router collapse: Router learns to always select same experts
  2. Expert imbalance: Some experts undertrained
  3. Gradient variance: Sparse activation = high gradient variance

Solutions:

# 1. Load balancing loss (standard; the 0.01 weight is a tunable hyperparameter)
loss = task_loss + 0.01 * lb_loss

# 2. Router z-loss (penalizes large router logits, stabilizing the softmax)
z_loss = torch.logsumexp(router_logits, dim=-1).pow(2).mean()
loss += 0.001 * z_loss

# 3. Jitter/noise in routing (exploration); apply before top-k selection
if training:
    router_logits = router_logits + torch.randn_like(router_logits) * 0.1
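To see what the z-loss penalizes, it helps to evaluate it on small and large logits. This pure-Python sketch implements the same formula as the snippet above, the mean of squared log-sum-exp; the logit values are made up for illustration.

```python
import math

def z_loss(logit_rows):
    """Mean over tokens of logsumexp(logits)^2."""
    total = 0.0
    for row in logit_rows:
        m = max(row)  # subtract the max for numerical stability
        lse = m + math.log(sum(math.exp(v - m) for v in row))
        total += lse ** 2
    return total / len(logit_rows)

small = z_loss([[0.1, -0.2, 0.0, 0.3]])
large = z_loss([[10.0, -20.0, 0.0, 30.0]])
print(round(small, 3), round(large, 3))
```

The large logits cost roughly 900 against about 2.1 for the small ones, so the term pulls router logits toward a range where the softmax is numerically well behaved.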

22.10.2 Distributed Training Setup

Typical large MoE training uses multiple parallelism dimensions:

Example: Training Mixtral-scale model on 128 GPUs

Expert Parallel: 8-way (each layer's experts are sharded across 8 GPUs)
Data Parallel: 16-way (16 replicas)
8 × 16 = 128 GPUs

Communication:
- All-to-all within expert parallel groups
- All-reduce gradients across data parallel groups
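The 8 × 16 layout can be expressed as a mapping from global rank to process groups. The sketch below assumes that consecutive ranks form one expert-parallel group, a common choice because it keeps the all-to-all inside a single 8-GPU node; real frameworks may order ranks differently, and all three helper names are hypothetical.

```python
def parallel_coords(rank, expert_parallel_size=8):
    """Map a global rank to (data-parallel replica, expert-parallel rank)."""
    return rank // expert_parallel_size, rank % expert_parallel_size

def expert_parallel_group(rank, expert_parallel_size=8):
    """Ranks that exchange tokens with `rank` via all-to-all."""
    start = (rank // expert_parallel_size) * expert_parallel_size
    return list(range(start, start + expert_parallel_size))

def data_parallel_group(rank, world_size=128, expert_parallel_size=8):
    """Ranks holding the same experts as `rank`; their gradients are all-reduced."""
    return list(range(rank % expert_parallel_size, world_size, expert_parallel_size))

coords = parallel_coords(19)
ep_group = expert_parallel_group(19)
dp_group = data_parallel_group(19)
print(coords, ep_group, dp_group[:4])
```

Rank 19 is expert-parallel rank 3 in replica 2: it exchanges tokens with ranks 16 to 23 and all-reduces expert gradients with the 16 ranks that hold the same expert shard (3, 11, 19, 27, and so on).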

22.11 When to Use MoE

22.11.1 MoE Advantages

  1. More capacity per FLOP: e.g., Mixtral holds ~3.6× more parameters than it activates per token (46.7B vs. 12.9B)
  2. Specialization: Experts develop distinct capabilities
  3. Graceful scaling: Add experts without proportional compute increase

22.11.2 MoE Disadvantages

  1. Memory: All experts must stay resident, even though each token uses only k of them
  2. Communication: All-to-all overhead in distributed setting
  3. Complexity: Routing, load balancing, stability
  4. Batch sensitivity: Small batches underutilize experts

22.11.3 Decision Framework

Use MoE when:
  ✓ Model capacity matters more than inference memory
  ✓ Training compute is the bottleneck (not serving)
  ✓ Batch sizes are large enough for load balancing
  ✓ Task benefits from specialization

Use Dense when:
  ✓ Serving many small requests (memory-bound)
  ✓ Simple deployment is priority
  ✓ Small-scale training
  ✓ Latency-critical applications

22.12 Connections

Skipping/Sparsity: MoE is structured sparsity. For any given token, most parameters are simply skipped; they are not zero, just unused.

Parallelism: Expert parallelism is a new dimension beyond data/tensor/pipeline parallel.

Distributed: All-to-all communication is MoE’s characteristic distributed primitive.

Inference: MoE inference has unique challenges around expert caching and memory.

22.13 Key Takeaways

  1. MoE decouples capacity from compute: More parameters without proportional FLOP increase.

  2. Load balancing is critical: Without it, MoE collapses to a dense model plus overhead.

  3. All-to-all communication: The characteristic distributed pattern for MoE.

  4. Inference is memory-bound: MoE cuts per-token compute in both training and inference, but every expert must still fit in memory.

  5. Specialization emerges: Experts learn different functions, though often at the token or syntax level rather than by topic.

  6. Frontier models are MoE: Mixtral, DeepSeek, and reportedly GPT-4. Understanding MoE is essential.

Note: Try It Yourself

The accompanying notebook lets you:

  • Implement a simple MoE layer from scratch
  • Visualize expert routing and specialization
  • Experiment with load balancing losses
  • Measure the capacity-compute tradeoff

Notebook support for this chapter is in progress. For now, prototype these experiments locally and compare routing/load-balancing behavior on your setup.

22.14 Further Reading

  • Shazeer et al. (2017). “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”
  • Fedus et al. (2022). “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”
  • Jiang et al. (2024). “Mixtral of Experts” - Mistral AI technical report
  • Gale et al. (2023). “MegaBlocks: Efficient Sparse Training with Mixture-of-Experts”
  • Dai et al. (2024). “DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models”
[1] N. Shazeer et al., “Outrageously large neural networks: The sparsely-gated mixture-of-experts layer,” arXiv preprint arXiv:1701.06538, 2017.