22 Mixture of Experts
Scaling with Conditional Computation
GPT-4 reportedly has over a trillion parameters. Yet it runs at reasonable speed. The secret: most parameters don’t activate for any given token.
Mixture of Experts achieves massive model capacity with sublinear compute cost.
This chapter is a case study in sparsity—the fourth property from our Algebraic Framework.
In MoE (Shazeer et al., 2017), the activation pattern is sparse: only k of N experts activate for each token. This isn’t an approximation—it’s the architecture itself. By designing for sparsity, MoE achieves massive capacity with constant compute.
Most scaling techniques accept the computation and try to do it faster. MoE asks: what if we simply skip most of the computation?
22.1 The Scaling Dilemma
Dense transformers face a fundamental tradeoff:
Model Size vs. Inference Cost
Dense model: Every parameter activates for every token
- 70B parameters → 70B activations per token
- 175B parameters → 175B activations per token
Scaling law: Larger models are better, but cost grows linearly.
What if we could have the capacity of a trillion-parameter model with the compute cost of a 70B model?
22.2 The MoE Insight
Mixture of Experts (MoE) replaces dense feedforward layers with a collection of “expert” networks, activating only a subset for each token.
Dense FFN:                        MoE Layer:

Input ──→ [FFN] ──→ Output       Input ──→ [Router] ──→ Select top-k experts
                                                │
                                        ┌───────┼───────┐
                                        ↓       ↓       ↓
                                      [E1]    [E2] ... [En]
                                        │       │       │
                                        └───────┼───────┘
                                                ↓
                                             Combine
                                                ↓
                                             Output
Key parameters:
- N: Total number of experts (e.g., 8, 64, or 128)
- k: Experts activated per token (typically 1 or 2)
- Capacity factor: How many tokens each expert can handle
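To make the N-versus-k tradeoff concrete, here is a tiny calculator (a sketch; `moe_param_counts` is a hypothetical helper, and the expert/shared sizes are illustrative values back-solved from Mixtral's published totals):

```python
def moe_param_counts(num_experts, top_k, expert_params, shared_params):
    """Total vs. per-token active parameter counts for an MoE model.

    expert_params: parameters in one expert's FFNs (all layers combined)
    shared_params: attention, embeddings, norms -- always active
    """
    total = shared_params + num_experts * expert_params
    active = shared_params + top_k * expert_params
    return total, active

# Mixtral-like shape: 8 experts, top-2 routing,
# ~5.6B parameters per expert, ~1.6B shared (illustrative).
total, active = moe_param_counts(8, 2, 5.6e9, 1.6e9)
print(f"total {total / 1e9:.1f}B, active {active / 1e9:.1f}B")
```

Capacity grows with N, while per-token compute grows only with k.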
22.3 The Mathematics
22.3.1 Router (Gating Network)
The router decides which experts handle each token:
\[G(x) = \text{softmax}(\text{TopK}(W_g \cdot x, k))\]
For input \(x\) with dimension \(d\):
- \(W_g \in \mathbb{R}^{N \times d}\) produces a score for each expert
- TopK selects the \(k\) highest-scoring experts
- Softmax normalizes the selected scores into routing weights
import torch
import torch.nn.functional as F

def router(x, W_g, k=2):
    """
    Route tokens to top-k experts.
    x: [batch, seq, d_model]
    W_g: [num_experts, d_model]
    Returns: expert indices [batch, seq, k], weights [batch, seq, k]
    """
    # Compute routing scores
    scores = x @ W_g.T  # [batch, seq, num_experts]
    # Select top-k experts, then normalize their scores
    weights, indices = torch.topk(scores, k, dim=-1)
    weights = F.softmax(weights, dim=-1)
    return indices, weights

22.3.2 Expert Computation
Each expert is typically a standard FFN:
\[E_i(x) = W_2^{(i)} \cdot \text{activation}(W_1^{(i)} \cdot x)\]
The final output combines selected experts:
\[y = \sum_{i \in \text{TopK}} G(x)_i \cdot E_i(x)\]
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Router
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        # Experts (each is a standard FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch, seq, d = x.shape
        # Route tokens
        scores = self.gate(x)  # [batch, seq, num_experts]
        weights, indices = torch.topk(scores, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)
        # Compute expert outputs (naive loop - see optimization below)
        output = torch.zeros_like(x)
        for i in range(self.top_k):
            expert_idx = indices[:, :, i]          # [batch, seq]
            expert_weight = weights[:, :, i:i+1]   # [batch, seq, 1]
            for e in range(self.num_experts):
                mask = (expert_idx == e)
                if mask.any():
                    expert_output = self.experts[e](x[mask])
                    output[mask] += expert_weight[mask] * expert_output
        return output

22.4 Why MoE Works: The Capacity-Compute Tradeoff
Mixtral 8x7B:
- Total parameters: 46.7B (shared attention/embedding layers plus 8 expert FFN sets of ~5.6B each; the "8x7B" name refers to the Mistral-7B layer shape, not eight full 7B models)
- Active parameters: ~12.9B per token (2 experts active)
- Effective capacity: Similar to 70B+ dense model
- Compute cost: Similar to 12B dense model
Result: 70B quality at 12B speed
The magic: specialization. Different experts learn different skills:
- Some experts handle math
- Some handle code
- Some handle natural language
- The router learns to dispatch appropriately
22.5 The Load Balancing Problem
22.5.1 Why Balance Matters
Naive routing often collapses—a few experts handle most tokens while others idle:
Unbalanced routing:
Expert 0: ████████████████████████████████ 80%
Expert 1: ███ 5%
Expert 2: ██ 3%
Expert 3: █ 2%
Expert 4: ██ 3%
Expert 5: █ 2%
Expert 6: ██ 3%
Expert 7: █ 2%
Problem:
- Expert 0 is a bottleneck
- Experts 1-7 waste capacity
- Effectively a dense model with extra overhead
22.5.2 Auxiliary Load Balancing Loss
The standard solution: add a loss term encouraging balanced routing.
import torch

def load_balancing_loss(router_probs, expert_indices, num_experts):
    """
    Encourage uniform expert utilization.
    router_probs: [batch, seq, num_experts] - softmax probabilities
    expert_indices: [batch, seq, k] - selected expert indices
    """
    # Fraction of tokens routed to each expert (hard assignments)
    tokens_per_expert = torch.zeros(num_experts, device=router_probs.device)
    for e in range(num_experts):
        tokens_per_expert[e] = (expert_indices == e).float().sum()
    tokens_per_expert = tokens_per_expert / expert_indices.numel()
    # Average routing probability to each expert (soft probabilities)
    prob_per_expert = router_probs.mean(dim=[0, 1])
    # Load balancing loss: dot product of the two.
    # Minimized when both are uniform (1/num_experts).
    lb_loss = num_experts * (tokens_per_expert * prob_per_expert).sum()
    return lb_loss

# In training:
# total_loss = task_loss + 0.01 * load_balancing_loss(...)

22.5.3 Capacity Factor
Another approach: limit how many tokens each expert can process.
import torch

def route_with_capacity(x, router_weights, capacity_factor=1.25):
    """
    Route with expert capacity limits.
    capacity_factor: overflow allowance (1.0 = exact, 1.25 = 25% buffer)
    """
    batch, seq, num_experts = router_weights.shape
    # Max tokens per expert, across the whole batch
    capacity = int(capacity_factor * batch * seq / num_experts)
    # Track how many tokens are assigned to each expert
    expert_counts = torch.zeros(num_experts, dtype=torch.long)
    # Assign tokens (dropping overflow)
    assignments = torch.full((batch, seq), -1, dtype=torch.long)  # -1 = dropped
    for b in range(batch):
        for s in range(seq):
            # Try experts in order of routing preference
            prefs = router_weights[b, s].argsort(descending=True)
            for expert in prefs:
                if expert_counts[expert] < capacity:
                    assignments[b, s] = expert
                    expert_counts[expert] += 1
                    break
    return assignments

Dropped tokens are a problem—they contribute no gradient. Solutions:
- Higher capacity factor (wastes compute)
- Token dropping auxiliary loss
- Expert choice routing (experts select tokens, not vice versa)
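The last option is easy to sketch. In expert-choice routing (Zhou et al., 2022), each expert selects its own top-`capacity` tokens, so every expert is perfectly loaded by construction (`expert_choice_route` is an illustrative name, not a library function):

```python
import torch

def expert_choice_route(scores, capacity):
    """Each expert picks its top-`capacity` tokens by routing affinity.
    scores: [num_tokens, num_experts] router logits
    Returns: [num_experts, capacity] token indices and gating weights.
    """
    probs = torch.softmax(scores, dim=-1)      # [tokens, experts]
    # Transpose so each expert ranks all tokens by its own column
    weights, token_idx = torch.topk(probs.T, capacity, dim=-1)
    return token_idx, weights

scores = torch.randn(16, 4)                    # 16 tokens, 4 experts
token_idx, weights = expert_choice_route(scores, capacity=4)
print(token_idx.shape)                         # torch.Size([4, 4])
```

Load is balanced by design, but the dropped-token problem moves to the token side: some tokens may be selected by zero experts.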
22.6 Expert Parallelism
MoE introduces a new parallelism dimension: distributing experts across GPUs.
22.6.1 The All-to-All Pattern
Before All-to-All: After All-to-All:
GPU 0: [E0, E1] has tokens GPU 0: [E0, E1] has ALL tokens
for all experts destined for E0, E1
GPU 1: [E2, E3] has tokens GPU 1: [E2, E3] has ALL tokens
for all experts destined for E2, E3
GPU 2: [E4, E5] has tokens GPU 2: [E4, E5] has ALL tokens
for all experts destined for E4, E5
GPU 3: [E6, E7] has tokens GPU 3: [E6, E7] has ALL tokens
for all experts destined for E6, E7
Communication pattern: Each GPU sends tokens to all other GPUs based on routing decisions.
def expert_parallel_forward(x, router, experts, world_size, rank):
    """
    MoE forward with expert parallelism (pseudocode sketch).
    Each GPU holds len(experts) local experts; all_to_all and
    combine_expert_outputs stand in for NCCL-backed helpers.
    """
    experts_per_gpu = len(experts)
    # 1. Compute routing on each GPU
    indices, weights = router(x)  # Each GPU routes its own tokens
    # 2. All-to-all: send tokens to the GPUs holding their experts
    #    tokens_to_send[dst_rank] = tokens that need experts on dst_rank
    tokens_to_send = [[] for _ in range(world_size)]
    flat_x = x.flatten(0, 1)
    for token_idx, expert_idx in enumerate(indices.flatten()):
        dst_rank = expert_idx // experts_per_gpu
        local_idx = expert_idx % experts_per_gpu
        tokens_to_send[dst_rank].append((token_idx, local_idx, flat_x[token_idx]))
    # Execute all-to-all (NCCL)
    received_tokens = all_to_all(tokens_to_send)
    # 3. Process tokens through local experts
    local_outputs = []
    for token_idx, local_idx, token in received_tokens:
        local_outputs.append((token_idx, experts[local_idx](token)))
    # 4. All-to-all: return outputs to their original GPUs
    returned_outputs = all_to_all(local_outputs)
    # 5. Combine expert outputs with routing weights
    output = combine_expert_outputs(returned_outputs, weights)
    return output

22.6.2 Communication Cost Analysis
All-to-all communication scales poorly:
N GPUs, B tokens per GPU, D dimensions:
All-to-all sends: Each GPU sends B*D/N to each other GPU
Total bandwidth: B * D * (N-1) / N ≈ B * D per GPU
For 8 GPUs, 4096 tokens, 4096 dimensions, FP16:
= 4096 * 4096 * 2 bytes = 32 MB per GPU
At 200 GB/s interconnect: 0.16 ms latency
But with 64 GPUs: still 32 MB but more coordination overhead
Key insight: All-to-all doesn’t scale as badly as all-reduce (which exchanges gradients for the full model), but coordination overhead grows with the number of GPUs.
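The numbers above can be reproduced with a short back-of-the-envelope helper (the function names here are mine, for illustration):

```python
def all_to_all_bytes(tokens, d_model, bytes_per_elem=2):
    """Approximate per-GPU all-to-all payload: each GPU exchanges
    roughly its full activation block with the group."""
    return tokens * d_model * bytes_per_elem

def transfer_ms(num_bytes, bandwidth_gb_s=200):
    """Ideal transfer time at a given interconnect bandwidth."""
    return num_bytes / (bandwidth_gb_s * 1e9) * 1e3

payload = all_to_all_bytes(4096, 4096)   # FP16
print(payload // 2**20, "MiB")           # 32 MiB
print(f"{transfer_ms(payload):.2f} ms")  # 0.17 ms
```

This is pure wire time; real all-to-all adds kernel-launch and synchronization overhead, which is the part that grows with GPU count.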
22.7 Megablocks: Efficient MoE Kernels
The naive MoE implementation is slow due to: 1. Irregular memory access (tokens scattered to different experts) 2. Small, variable-sized expert batches 3. Multiple kernel launches
Megablocks solution: Block-sparse operations.
# Conceptual Megablocks approach:
def megablocks_moe(x, experts, routing):
    """
    Efficient MoE using block-sparse operations (conceptual).
    """
    # 1. Sort tokens by expert assignment so that tokens for the
    #    same expert are contiguous in memory
    sorted_tokens, sort_indices = sort_by_expert(x, routing)
    # 2. Compute expert boundaries:
    #    boundaries[i] = start index for expert i
    boundaries = compute_boundaries(routing)
    # 3. Block-sparse matmul: a single kernel handles all experts
    #    with variable batch sizes
    output = block_sparse_ffn(sorted_tokens, experts, boundaries)
    # 4. Unsort to restore the original token order
    output = unsort(output, sort_indices)
    return output

Performance: 2-4× faster than the naive implementation through better memory access patterns.
22.8 MoE Architectures in Practice
22.8.1 Mixtral 8x7B
Architecture:
- 8 experts per MoE layer
- Top-2 routing (2 experts per token)
- 32 MoE layers (every FFN is MoE)
- Experts share attention parameters
- Each expert: ~5.6B FFN parameters (the "8x7B" name refers to the Mistral-7B layer shape)
Total: 46.7B parameters
Active: 12.9B parameters per token
22.8.2 DeepSeek-MoE
Innovations:
- Fine-grained experts (more, smaller experts)
- Shared experts (some experts always active)
- 64 routed experts + 2 shared experts
- Top-6 routing from routed experts
Result: Better specialization with shared common knowledge
22.8.3 Switch Transformer
Simplification:
- Top-1 routing (only 1 expert per token)
- Simpler, faster, but potentially lower quality
- Works well at very large scale
22.9 MoE for Inference
Training MoE is one challenge; serving is another.
22.9.1 The Memory Problem
All experts must be in memory, even though only k activate:
Mixtral 8x7B inference:
- Must load: 46.7B parameters
- Actually use: 12.9B per forward pass
Memory: Same as dense 46.7B model
Compute: Same as dense 12.9B model
Problem: Memory-bound inference gains less from MoE
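A quick calculation shows the asymmetry (a sketch; `weight_memory_gb` is a hypothetical helper, and 2 bytes/parameter assumes FP16/BF16 weights):

```python
def weight_memory_gb(params, bytes_per_param=2):
    """Weight memory footprint at a given precision."""
    return params * bytes_per_param / 1e9

print(weight_memory_gb(46.7e9))  # 93.4 GB: exceeds a single 80 GB GPU
print(weight_memory_gb(12.9e9))  # 25.8 GB of weights actually touched per token
```

The full 46.7B parameters must be resident even though each token touches only the 12.9B active subset.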
22.9.2 Expert Offloading
For memory-constrained serving, offload inactive experts:
class OffloadedMoE:
    def __init__(self, router, experts, device='cuda'):
        # Keep the router on GPU
        self.router = router.to(device)
        # Keep experts on CPU, move on demand
        self.experts = [e.to('cpu') for e in experts]
        self.device = device
        # LRU cache for recently used experts
        self.expert_cache = LRUCache(max_size=4)

    def forward(self, x):
        # Route
        indices, weights = self.router(x)
        needed_experts = indices.unique().tolist()
        # Load needed experts onto the GPU
        for e in needed_experts:
            if e not in self.expert_cache:
                self.experts[e] = self.experts[e].to(self.device)
                self.expert_cache.put(e, self.experts[e])
                # Maybe evict an old expert back to CPU
                evicted = self.expert_cache.evict_if_needed()
                if evicted is not None:
                    self.experts[evicted] = self.experts[evicted].to('cpu')
        # Compute
        return moe_forward(x, self.experts, indices, weights)

Latency impact: PCIe transfer adds ~1-2 ms per expert load. Predictable routing helps (cache warming).
22.9.3 Expert Parallelism for Inference
Unlike training, inference often prioritizes low latency over high throughput:
Strategy 1: Replicate all experts
- Each GPU has all experts
- No all-to-all needed
- Memory: Full model per GPU
Strategy 2: Expert parallel + tensor parallel
- Experts sharded across GPUs
- Attention tensor-parallel within GPU groups
- Communication: All-to-all for MoE, all-reduce for attention
Strategy 3: Expert caching with speculation
- Predict likely experts
- Pre-fetch to GPU before needed
- Works when routing is predictable
22.10 Training MoE Models
22.10.1 Stability Challenges
MoE training is less stable than dense models:
- Router collapse: Router learns to always select same experts
- Expert imbalance: Some experts undertrained
- Gradient variance: Sparse activation = high gradient variance
Solutions:
# 1. Load balancing loss (mandatory)
loss = task_loss + 0.01 * lb_loss

# 2. Router z-loss (stabilizes softmax)
z_loss = torch.logsumexp(router_logits, dim=-1).pow(2).mean()
loss += 0.001 * z_loss

# 3. Jitter/noise in routing (exploration)
if training:
    router_logits += torch.randn_like(router_logits) * 0.1

22.10.2 Distributed Training Setup
Typical large MoE training uses multiple parallelism dimensions:
Example: Training Mixtral-scale model on 128 GPUs
Expert Parallel: 8-way (8 GPUs share experts)
Data Parallel: 16-way (16 replicas)
8 × 16 = 128 GPUs
Communication:
- All-to-all within expert parallel groups
- All-reduce gradients across data parallel groups
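The 8 × 16 grid above can be expressed as a simple rank mapping (one common convention, with expert-parallel as the fastest-varying dimension; `parallel_coords` is an illustrative helper, not a library API):

```python
def parallel_coords(rank, expert_parallel_size):
    """Map a global rank to (expert-parallel rank, data-parallel rank)."""
    ep_rank = rank % expert_parallel_size
    dp_rank = rank // expert_parallel_size
    return ep_rank, dp_rank

# 128 GPUs: 8-way expert parallel x 16-way data parallel
print(parallel_coords(0, 8))    # (0, 0)   first GPU of first replica
print(parallel_coords(11, 8))   # (3, 1)   4th EP slot of 2nd replica
print(parallel_coords(127, 8))  # (7, 15)  last GPU
```

All-to-all runs among ranks sharing a dp_rank; gradient all-reduce runs among ranks sharing an ep_rank.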
22.11 When to Use MoE
22.11.1 MoE Advantages
- More capacity per FLOP: 3-4× more parameters at same compute
- Specialization: Experts develop distinct capabilities
- Graceful scaling: Add experts without proportional compute increase
22.11.2 MoE Disadvantages
- Memory: All parameters must be accessible (can’t shard by layer easily)
- Communication: All-to-all overhead in distributed setting
- Complexity: Routing, load balancing, stability
- Batch sensitivity: Small batches underutilize experts
22.11.3 Decision Framework
Use MoE when:
✓ Model capacity matters more than inference memory
✓ Training compute is the bottleneck (not serving)
✓ Batch sizes are large enough for load balancing
✓ Task benefits from specialization
Use Dense when:
✓ Serving many small requests (memory-bound)
✓ Simple deployment is priority
✓ Small-scale training
✓ Latency-critical applications
22.12 Connections
Skipping/Sparsity: MoE is structured sparsity—most parameters are zero for any given input.
Parallelism: Expert parallelism is a new dimension beyond data/tensor/pipeline parallel.
Distributed: All-to-all communication is MoE’s unique distributed primitive.
Inference: MoE inference has unique challenges around expert caching and memory.
22.13 Key Takeaways
MoE decouples capacity from compute: More parameters without proportional FLOP increase.
Load balancing is critical: Without it, MoE collapses to dense model plus overhead.
All-to-all communication: The unique distributed pattern for MoE.
Inference is memory-bound: MoE saves training compute but not inference memory.
Specialization emerges: Experts naturally learn different capabilities.
Frontier models are MoE: GPT-4, Mixtral, DeepSeek—understanding MoE is essential.
The accompanying notebook will let you:
- Implement a simple MoE layer from scratch
- Visualize expert routing and specialization
- Experiment with load balancing losses
- Measure the capacity-compute tradeoff
Notebook support for this chapter is in progress. For now, prototype these experiments locally and compare routing/load-balancing behavior on your setup.
22.14 Further Reading
- Shazeer et al. (2017). “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”
- Fedus et al. (2022). “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”
- Jiang et al. (2024). “Mixtral of Experts” - Mistral AI technical report
- Gale et al. (2023). “Megablocks: Efficient Sparse Training with Mixture-of-Experts”
- DeepSeek-AI (2024). “DeepSeek-MoE: Towards Ultimate Expert Specialization”