Tensor Parallelism from Linearity
Matrix multiplication is linear: \(f(aX) = af(X)\) and \(f(X + Y) = f(X) + f(Y)\). This single property enables tensor parallelism. Non-linear operations like GELU break this property, but element-wise ops can still be local; only operations that couple sharded dimensions force synchronization. Understanding linearity reveals which operations can be parallelized and which force communication.
The Question: We want to shard a linear layer \(Y = XW\) across 8 GPUs. Can we split \(W\) column-wise? Row-wise? What about the bias? What about LayerNorm? What about GELU?
Chapter Map
Prerequisites: Chapter 11 (AllReduce, AllGather), Chapter 14 (data parallelism fundamentals)
Key insight: Linear operations (matrix multiplications) can be partitioned across GPUs with a final AllReduce. Non-linear operations (GELU, LayerNorm, Softmax) may force synchronization depending on which dimension is sharded—understanding which operations couple shards determines your communication pattern.
The Linearity Property¶
Connection to The Algebra of Speed
Linearity is one of the core algebraic properties explored in the companion book The Algebra of Speed, where it enables single-machine optimizations like BLAS-level tiling and vectorization. Here the same property enables a fundamentally different optimization: splitting a single operation across multiple devices. The algebra is the same; the payoff is distributed.
Definition¶
A function \(f: V \to W\) between vector spaces is linear if for all vectors \(X, Y \in V\) and scalars \(a, b\):

\[ f(aX + bY) = af(X) + bf(Y) \]

This is equivalent to two conditions:
- Additivity: \(f(X + Y) = f(X) + f(Y)\)
- Homogeneity: \(f(aX) = af(X)\)
Why Linearity Enables Parallelism¶
If \(f\) is linear and we partition the input \(X = X_1 + X_2\):

\[ f(X) = f(X_1 + X_2) = f(X_1) + f(X_2) \]

We can compute \(f(X_1)\) and \(f(X_2)\) independently, then sum the results.

For matrix multiplication \(f(X) = XW\):

\[ (X_1 + X_2)W = X_1 W + X_2 W \]
This decomposition is the foundation of tensor parallelism.
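A quick NumPy sketch (shapes chosen arbitrarily, NumPy standing in for per-GPU compute) confirms the decomposition:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 3))

# Partition the input as a sum: X = X1 + X2
X1, X2 = 0.3 * X, 0.7 * X

# Linearity: f(X1 + X2) = f(X1) + f(X2) for f(X) = XW
assert np.allclose((X1 + X2) @ W, X1 @ W + X2 @ W)
```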
The Classification of Operations¶
| Operation | Linear? | Parallelizable? |
|---|---|---|
| Matrix multiply | Yes | Yes (with structure) |
| Bias addition | Affine (linear + translation) | Special handling |
| ReLU | No | Element-wise only |
| GELU | No | Element-wise only |
| Softmax | No | Requires full tensor along the softmax axis |
| LayerNorm | No | Requires statistics |
| Dropout | No (stochastic) | Element-wise with care |
Column-Parallel Linear Layers¶
The Idea¶
For a linear layer \(Y = XW + b\) with \(W \in \mathbb{R}^{d_{in} \times d_{out}}\):
Split \(W\) along columns (output dimension):

\[ W = [W_1 \mid W_2 \mid \cdots \mid W_P] \]

where each \(W_i \in \mathbb{R}^{d_{in} \times (d_{out}/P)}\).
The Computation¶
Each GPU \(i\) holds:
- Full input \(X\) (replicated)
- Shard \(W_i\) of weights
- Shard \(b_i\) of bias
Computes:

\[ Y_i = XW_i + b_i \]

The results are column-partitioned:

\[ Y = [Y_1 \mid Y_2 \mid \cdots \mid Y_P] \]
Why It Works¶
Matrix multiplication distributes over column concatenation:

\[ X[W_1 \mid W_2] = [XW_1 \mid XW_2] \]
Proof:
Let \(X \in \mathbb{R}^{m \times n}\), \(W_1 \in \mathbb{R}^{n \times k_1}\), \(W_2 \in \mathbb{R}^{n \times k_2}\).
The \((i, j)\) entry of \(XW_1\) for \(j \leq k_1\):

\[ (XW_1)_{ij} = \sum_{l=1}^{n} X_{il} (W_1)_{lj} \]

The \((i, j)\) entry of \(X[W_1 \mid W_2]\) for \(j \leq k_1\):

\[ (X[W_1 \mid W_2])_{ij} = \sum_{l=1}^{n} X_{il} ([W_1 \mid W_2])_{lj} = \sum_{l=1}^{n} X_{il} (W_1)_{lj} \]
Identical. \(\square\)
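The identity is easy to check numerically; a two-shard NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))
W1 = rng.standard_normal((6, 2))  # column shard held by GPU 0
W2 = rng.standard_normal((6, 2))  # column shard held by GPU 1

# Each shard's output is computed independently...
Y1, Y2 = X @ W1, X @ W2
# ...and concatenating them equals the unsharded product
W = np.concatenate([W1, W2], axis=1)
assert np.allclose(X @ W, np.concatenate([Y1, Y2], axis=1))
```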
Communication¶
Forward pass: No communication needed. Each GPU computes independently.
Backward pass: Gradient w.r.t. \(X\) requires AllReduce:

\[ \frac{\partial L}{\partial X} = \sum_{i=1}^{P} \frac{\partial L}{\partial Y_i} W_i^T \]
Diagram¶
The following diagram shows how the weight matrix is split column-wise across GPUs:
```mermaid
flowchart TB
    subgraph input["Input (Replicated)"]
        X["X<br/>[batch × seq × d_in]"]
    end
    subgraph weights["Weight Matrix W split by columns"]
        direction LR
        W0["W₀<br/>[d_in × d/4]"]
        W1["W₁<br/>[d_in × d/4]"]
        W2["W₂<br/>[d_in × d/4]"]
        W3["W₃<br/>[d_in × d/4]"]
    end
    subgraph outputs["Output (Column-partitioned)"]
        direction LR
        Y0["Y₀ = XW₀"]
        Y1["Y₁ = XW₁"]
        Y2["Y₂ = XW₂"]
        Y3["Y₃ = XW₃"]
    end
    X --> W0 & W1 & W2 & W3
    W0 --> Y0
    W1 --> Y1
    W2 --> Y2
    W3 --> Y3
    style W0 fill:#3498db,stroke:#2980b9,color:white
    style W1 fill:#e74c3c,stroke:#c0392b,color:white
    style W2 fill:#2ecc71,stroke:#27ae60,color:white
    style W3 fill:#f39c12,stroke:#d68910,color:white
    style Y0 fill:#3498db,stroke:#2980b9,color:white
    style Y1 fill:#e74c3c,stroke:#c0392b,color:white
    style Y2 fill:#2ecc71,stroke:#27ae60,color:white
    style Y3 fill:#f39c12,stroke:#d68910,color:white
```
No communication in forward pass — each GPU computes its output shard independently.
Row-Parallel Linear Layers¶
The Idea¶
Split \(W\) along rows (input dimension):

\[ W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_P \end{bmatrix} \]

where each \(W_i \in \mathbb{R}^{(d_{in}/P) \times d_{out}}\).
The Computation¶
Requires input split along columns:

\[ X = [X_1 \mid X_2 \mid \cdots \mid X_P] \]
Each GPU \(i\) holds:
- Shard \(X_i\) of input
- Shard \(W_i\) of weights
Computes partial result:

\[ Z_i = X_i W_i \]
Note: \(X_i \in \mathbb{R}^{m \times (d_{in}/P)}\) and \(W_i \in \mathbb{R}^{(d_{in}/P) \times d_{out}}\), so \(Z_i \in \mathbb{R}^{m \times d_{out}}\).
Why AllReduce Is Needed¶
The full result requires summing:

\[ Y = XW = \sum_{i=1}^{P} X_i W_i = \sum_{i=1}^{P} Z_i \]
Each GPU has a partial sum \(Z_i\). AllReduce computes \(\sum_i Z_i\) on all GPUs.
Proof:
For partitioned multiplication:

\[ XW = [X_1 \mid \cdots \mid X_P] \begin{bmatrix} W_1 \\ \vdots \\ W_P \end{bmatrix} = \sum_{i=1}^{P} X_i W_i \]

The \((i, j)\) entry:

\[ (XW)_{ij} = \sum_{l=1}^{d_{in}} X_{il} W_{lj} = \sum_{p=1}^{P} \sum_{l \in \text{shard } p} X_{il} W_{lj} \]
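A NumPy sketch of the partial-sum structure (two shards standing in for two GPUs):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 3))

# P = 2: shard X by columns and W by rows
X0, X1 = X[:, :3], X[:, 3:]
W0, W1 = W[:3, :], W[3:, :]

# Each GPU produces a full-shaped partial result...
Z0, Z1 = X0 @ W0, X1 @ W1
# ...and summing them (the AllReduce) recovers the unsharded product
assert np.allclose(X @ W, Z0 + Z1)
```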
Diagram¶
The following diagram shows row-parallel computation with AllReduce:
```mermaid
flowchart TB
    subgraph input["Input X (Column-partitioned)"]
        direction LR
        X0["X₀"]
        X1["X₁"]
        X2["X₂"]
        X3["X₃"]
    end
    subgraph weights["Weight Matrix W split by rows"]
        direction LR
        W0["W₀<br/>[d/4 × d_out]"]
        W1["W₁<br/>[d/4 × d_out]"]
        W2["W₂<br/>[d/4 × d_out]"]
        W3["W₃<br/>[d/4 × d_out]"]
    end
    subgraph partial["Partial Results"]
        direction LR
        Z0["Z₀ = X₀W₀"]
        Z1["Z₁ = X₁W₁"]
        Z2["Z₂ = X₂W₂"]
        Z3["Z₃ = X₃W₃"]
    end
    AR["AllReduce<br/>(Sum)"]
    subgraph output["Output Y (Replicated)"]
        Y["Y = Z₀ + Z₁ + Z₂ + Z₃"]
    end
    X0 --> W0 --> Z0
    X1 --> W1 --> Z1
    X2 --> W2 --> Z2
    X3 --> W3 --> Z3
    Z0 & Z1 & Z2 & Z3 --> AR --> Y
    style W0 fill:#3498db,stroke:#2980b9,color:white
    style W1 fill:#e74c3c,stroke:#c0392b,color:white
    style W2 fill:#2ecc71,stroke:#27ae60,color:white
    style W3 fill:#f39c12,stroke:#d68910,color:white
    style X0 fill:#3498db,stroke:#2980b9,color:white
    style X1 fill:#e74c3c,stroke:#c0392b,color:white
    style X2 fill:#2ecc71,stroke:#27ae60,color:white
    style X3 fill:#f39c12,stroke:#d68910,color:white
    style AR fill:#9b59b6,stroke:#8e44ad,color:white
```
AllReduce required — partial results must be summed across all GPUs.
Bias Handling¶
For row-parallel with bias \(Y = XW + b\):
- Add bias after AllReduce (on the full result)
- Or: add \(b/P\) on each GPU before AllReduce (the \(P\) scaled copies sum to exactly \(b\))
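Both options agree, as a small NumPy sketch shows (the partial products \(Z_i\) here are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(2)
P = 4
partials = [rng.standard_normal((4, 3)) for _ in range(P)]  # Z_i on each GPU
b = rng.standard_normal(3)

# Option 1: AllReduce the partials, then add b once
y1 = sum(partials) + b
# Option 2: each GPU adds b/P before the AllReduce; the P copies sum to b
y2 = sum(z + b / P for z in partials)
assert np.allclose(y1, y2)
```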
The Megatron-LM Pattern¶
Shoeybi et al. (2019) introduced an elegant pattern combining column and row parallelism.
MLP Block¶
Standard Transformer MLP:

\[ Z = \text{GeLU}(XW_1)W_2 \]

With Megatron parallelism:
```mermaid
flowchart TB
    X["Input X<br/>(replicated)"]
    subgraph colpar["Column-Parallel (no comm)"]
        W1["W₁ sharded<br/>by columns"]
    end
    Y["Y = XW₁<br/>(column-sharded)"]
    GELU["GeLU(Y)<br/>(local, element-wise)"]
    subgraph rowpar["Row-Parallel"]
        W2["W₂ sharded<br/>by rows"]
    end
    AR["AllReduce"]
    Z["Output Z<br/>(replicated)"]
    X --> colpar --> Y --> GELU --> rowpar --> AR --> Z
    style colpar fill:#2ecc71,stroke:#27ae60,color:white
    style rowpar fill:#3498db,stroke:#2980b9,color:white
    style GELU fill:#f39c12,stroke:#d68910,color:white
    style AR fill:#9b59b6,stroke:#8e44ad,color:white
```
Key insight: Column-parallel produces column-sharded output, which is exactly what row-parallel needs as input!
Why GELU Doesn't Break This¶
GELU is applied element-wise. Each element of \(Y\) is computed by one GPU.
Even though GeLU is not linear:

\[ \text{GeLU}(X + Y) \neq \text{GeLU}(X) + \text{GeLU}(Y) \]

we're not splitting elements; we're splitting the tensor along the hidden dimension. Each GPU computes GeLU on its complete slice:

\[ \text{GeLU}(Y) = [\text{GeLU}(Y_1) \mid \text{GeLU}(Y_2) \mid \cdots \mid \text{GeLU}(Y_P)] \]

This is valid because GeLU is applied independently to each element.
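A NumPy sketch using the tanh approximation of GeLU (an approximation, chosen here only to keep the example dependency-free):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (element-wise)
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(3)
Y = rng.standard_normal((4, 8))
Y0, Y1 = Y[:, :4], Y[:, 4:]  # column shards on two GPUs

# GeLU applied shard-by-shard matches GeLU on the full tensor
sharded = np.concatenate([gelu(Y0), gelu(Y1)], axis=1)
assert np.allclose(sharded, gelu(Y))
```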
Attention Block¶
Transformer attention:

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

where \(Q = XW^Q\), \(K = XW^K\), \(V = XW^V\).
Multi-head attention is naturally parallelizable:

\[ \text{MultiHead}(X) = [\text{head}_1 \mid \text{head}_2 \mid \cdots \mid \text{head}_h]\, W^O \]

Each head is independent. With \(h\) heads and \(P\) GPUs (where \(P\) divides \(h\)):
Each GPU computes \(h/P\) heads.
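Because heads never mix before the output projection, computing them on separate devices is exact, not an approximation. A NumPy sketch with 4 heads split across two "GPUs":

```python
import numpy as np

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    # q, k, v: [heads, seq, head_dim]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(4)
h, s, dh = 4, 5, 8
q, k, v = (rng.standard_normal((h, s, dh)) for _ in range(3))

full = attention(q, k, v)  # all heads at once
# GPU 0 takes heads 0-1, GPU 1 takes heads 2-3; no cross-talk between heads
per_gpu = [attention(q[i:i + 2], k[i:i + 2], v[i:i + 2]) for i in (0, 2)]
assert np.allclose(full, np.concatenate(per_gpu, axis=0))
```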
```mermaid
flowchart TB
    X["Input X (replicated)"]
    subgraph heads["Head Distribution (P=4 GPUs, h=32 heads)"]
        direction LR
        G0["GPU 0<br/>heads 0-7"]
        G1["GPU 1<br/>heads 8-15"]
        G2["GPU 2<br/>heads 16-23"]
        G3["GPU 3<br/>heads 24-31"]
    end
    subgraph qkv["Q, K, V Projections (column-parallel, no comm)"]
        direction TB
        QKV["Each GPU: W_Q, W_K, W_V for local heads"]
    end
    subgraph attn["Attention (local per head)"]
        direction TB
        ATT["softmax(QK^T / sqrt(d_k)) V"]
    end
    subgraph out["Output Projection (row-parallel)"]
        direction TB
        WO["W_O sharded by rows"]
    end
    AR["AllReduce"]
    Y["Output (replicated)"]
    X --> heads
    G0 & G1 & G2 & G3 --> qkv --> attn --> out --> AR --> Y
    style G0 fill:#3498db,stroke:#2980b9,color:white
    style G1 fill:#e74c3c,stroke:#c0392b,color:white
    style G2 fill:#2ecc71,stroke:#27ae60,color:white
    style G3 fill:#f39c12,stroke:#d68910,color:white
    style AR fill:#9b59b6,stroke:#8e44ad,color:white
```
Communication Count¶
Per Transformer layer:
| Component | Communication |
|---|---|
| Attention Q, K, V projections | None (column-parallel) |
| Attention computation | None (head-parallel) |
| Attention output projection | 1 AllReduce |
| MLP up-projection | None (column-parallel) |
| GeLU | None (local) |
| MLP down-projection | 1 AllReduce |
Total: 2 AllReduce operations in the forward pass (4 total including backward)
Communication Analysis¶
Volume per Layer¶
For a Transformer with:
- Hidden dimension \(d\)
- Tensor parallel degree \(P\)
- Sequence length \(s\)
- Batch size \(b\)
Each AllReduce synchronizes the activation tensor of shape \([b, s, d]\):

\[ n = 2bsd \ \text{bytes (FP16)} \]

Two AllReduces per layer, each moving \(2\frac{P-1}{P} \cdot n\) bytes per GPU (ring algorithm):

\[ V = 2 \cdot 2\,\frac{P-1}{P} \cdot 2bsd = 8\,\frac{P-1}{P}\, bsd \ \text{bytes} \]

For FP16 and large \(P\):

\[ V \approx 8bsd \ \text{bytes} \]
Time per Layer¶
Using the α-β model for a ring AllReduce on \(n\) bytes:

\[ T_{\text{AllReduce}} = 2(P-1)\,\alpha + 2\,\frac{P-1}{P}\,\frac{n}{\beta} \]

For large tensors (bandwidth-dominated):

\[ T \approx 2\,\frac{P-1}{P}\,\frac{n}{\beta} \approx \frac{2n}{\beta} \]
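The model can be wrapped in a small estimator; the default \(\alpha\) and \(\beta\) below are illustrative assumptions (microsecond-scale latency, NVLink-class bandwidth), not measured values:

```python
def ring_allreduce_time(n_bytes, P, alpha=5e-6, beta=900e9):
    """Alpha-beta estimate for a ring AllReduce on n_bytes.

    alpha: per-step latency (s); beta: per-link bandwidth (bytes/s).
    A ring uses 2(P-1) steps and moves 2(P-1)/P of the buffer per GPU.
    """
    return 2 * (P - 1) * alpha + 2 * (P - 1) / P * n_bytes / beta

# One 64 MB activation tensor across 8 GPUs: roughly 200 microseconds
t = ring_allreduce_time(64 * 2**20, 8)
```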
Compute-Communication Ratio¶
Compute per layer (forward only) is dominated by the matrix multiplications:

\[ C = \Theta(bsd^2) + \Theta(bs^2 d) \ \text{FLOPs} \]

For \(s \ll d\) (typical for LLMs):

\[ C = \Theta(bsd^2) \]

Ratio:

\[ R = \frac{T_{\text{compute}}}{T_{\text{comm}}} = \frac{C/(PF)}{V/\beta} \propto \frac{d\beta}{PF} \]
For an H100 (\(F = 989 \times 10^{12}\) FLOP/s dense BF16) with NVLink (\(\beta = 900\) GB/s) and \(d = 8192\), at \(P = 8\) this gives \(R \approx 0.23\) — communication-bound!
This is why tensor parallelism is typically limited to within a node.
Layer Normalization¶
The Challenge¶
LayerNorm:

\[ \text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]

where:

\[ \mu = \frac{1}{d}\sum_{j=1}^{d} x_j, \qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2 \]
Computing \(\mu\) and \(\sigma\) requires the full hidden dimension—can't be done on sharded activations.
Solutions¶
Option 1: Pre-LayerNorm on Replicated Activations
Apply LayerNorm before entering the TP region, while the activations are still replicated, so each GPU computes the full statistics locally.
This is the Megatron-LM approach.
Option 2: Parallel LayerNorm
Compute partial statistics on each shard, AllReduce to get global statistics.
GPU \(i\) computes partial statistics over its shard:

\[ s_i = \sum_{j \in \text{shard } i} x_j, \qquad q_i = \sum_{j \in \text{shard } i} x_j^2 \]

AllReduce to get:

\[ \mu = \frac{1}{d}\sum_{i} s_i, \qquad \sigma^2 = \frac{1}{d}\sum_{i} q_i - \mu^2 \]

Then normalize locally with global statistics.
Cost: Additional AllReduce per LayerNorm. Usually avoided.
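The partial-statistics trick is easy to verify in NumPy (shard count arbitrary):

```python
import numpy as np

rng = np.random.default_rng(5)
d, P = 12, 4
x = rng.standard_normal(d)
shards = np.split(x, P)  # each GPU holds d/P elements

# Each GPU contributes (sum, sum of squares); the AllReduce adds them
S = sum(s.sum() for s in shards)
Q = sum((s**2).sum() for s in shards)
mu = S / d
var = Q / d - mu**2

assert np.isclose(mu, x.mean())
assert np.isclose(var, x.var())
```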
Dropout in Tensor Parallelism¶
The Challenge¶
Dropout applies a random mask:

\[ Y = \frac{X \odot M}{1 - p} \]
where \(M\) is a binary mask with \(P(M_i = 1) = 1 - p\).
For reproducibility, the mask must be the same across GPUs for replicated activations.
Solution: Synchronized RNG¶
```python
class TPDropout(nn.Module):
    """Dropout with RNG state synchronized across the TP group."""

    def __init__(self, p, tp_group):
        super().__init__()
        self.p = p
        self.tp_group = tp_group

    def forward(self, x):
        if self.training:
            # Synchronize RNG state across TP group
            seed = torch.randint(0, 2**32, (1,), device=x.device)
            dist.broadcast(seed, src=0, group=self.tp_group)
            # Generate an identical mask on all GPUs
            gen = torch.Generator(device=x.device).manual_seed(seed.item())
            mask = torch.bernoulli(
                torch.full_like(x, 1 - self.p), generator=gen
            )
            return x * mask / (1 - self.p)
        return x
```
Implementation¶
Column-Parallel Linear¶
```python
import torch
import torch.nn as nn
import torch.distributed as dist


class ColumnParallelLinear(nn.Module):
    """Linear layer with column-wise weight partitioning."""

    def __init__(self, in_features, out_features, tp_group, bias=True):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)
        # Each GPU holds out_features / tp_size columns
        self.out_features_per_gpu = out_features // self.tp_size
        # Local weight shard
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_gpu, in_features)
        )
        if bias:
            self.bias = nn.Parameter(
                torch.empty(self.out_features_per_gpu)
            )
        else:
            self.bias = None
        self._init_weights()

    def _init_weights(self):
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        # x: [batch, seq, in_features] - replicated
        # out: [batch, seq, out_features_per_gpu] - column-sharded
        out = torch.matmul(x, self.weight.t())
        if self.bias is not None:
            out = out + self.bias
        return out
```
Row-Parallel Linear¶
```python
class RowParallelLinear(nn.Module):
    """Linear layer with row-wise weight partitioning."""

    def __init__(self, in_features, out_features, tp_group, bias=True):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)
        # Each GPU holds in_features / tp_size rows
        self.in_features_per_gpu = in_features // self.tp_size
        # Local weight shard
        self.weight = nn.Parameter(
            torch.empty(out_features, self.in_features_per_gpu)
        )
        if bias:
            # Bias is replicated; add after AllReduce on all ranks
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.bias = None
        self._init_weights()

    def _init_weights(self):
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        # x: [batch, seq, in_features_per_gpu] - column-sharded
        # Compute partial result
        partial = torch.matmul(x, self.weight.t())
        # AllReduce to sum partial results
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=self.tp_group)
        # Add bias (same on all GPUs)
        if self.bias is not None:
            partial = partial + self.bias
        return partial
```
Megatron MLP Block¶
```python
class TensorParallelMLP(nn.Module):
    """MLP block with Megatron-style tensor parallelism."""

    def __init__(self, hidden_size, ffn_hidden_size, tp_group):
        super().__init__()
        self.tp_group = tp_group
        # Up projection: column-parallel (output is sharded)
        self.up_proj = ColumnParallelLinear(
            hidden_size, ffn_hidden_size, tp_group, bias=True
        )
        # Down projection: row-parallel (input is sharded, output replicated)
        self.down_proj = RowParallelLinear(
            ffn_hidden_size, hidden_size, tp_group, bias=True
        )
        self.activation = nn.GELU()

    def forward(self, x):
        # x: [batch, seq, hidden] - replicated
        x = self.up_proj(x)      # [batch, seq, ffn/TP] - sharded
        x = self.activation(x)   # Local GELU
        x = self.down_proj(x)    # [batch, seq, hidden] - replicated
        return x
```
Tensor Parallel Attention¶
```python
class TensorParallelAttention(nn.Module):
    """Multi-head attention with tensor parallelism."""

    def __init__(self, hidden_size, num_heads, tp_group):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_heads_per_gpu = num_heads // self.tp_size
        self.head_dim = hidden_size // num_heads
        # Q, K, V projections: column-parallel
        self.qkv_proj = ColumnParallelLinear(
            hidden_size, 3 * hidden_size, tp_group, bias=True
        )
        # Output projection: row-parallel
        self.out_proj = RowParallelLinear(
            hidden_size, hidden_size, tp_group, bias=True
        )

    def forward(self, x, mask=None):
        batch, seq, _ = x.shape
        # QKV projection (column-parallel, no comm)
        qkv = self.qkv_proj(x)  # [batch, seq, 3 * hidden / TP]
        # Reshape to separate Q, K, V for local heads.
        # Note: this assumes the QKV weight columns are laid out so that each
        # shard holds [Q|K|V] for its local heads (as Megatron-LM arranges
        # them), not a plain contiguous slice of the full [Q|K|V] matrix.
        qkv = qkv.view(batch, seq, 3, self.num_heads_per_gpu, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        # Attention computation (local)
        q = q.transpose(1, 2)  # [batch, heads_per_gpu, seq, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = attn_weights / (self.head_dim ** 0.5)
        if mask is not None:
            attn_weights = attn_weights + mask
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).reshape(
            batch, seq, self.num_heads_per_gpu * self.head_dim
        )
        # Output projection (row-parallel, AllReduce)
        output = self.out_proj(attn_output)
        return output
```
Backward Pass Analysis¶
Gradient Flow¶
For column-parallel \(Y = XW\) where \(W\) is column-sharded:
Forward: no communication.

Backward: given \(\frac{\partial L}{\partial Y}\) (sharded the same way as \(Y\)), the weight gradient is

\[ \frac{\partial L}{\partial W_i} = X^T \frac{\partial L}{\partial Y_i} \]

Local computation: \(X\) is replicated, \(\frac{\partial L}{\partial Y_i}\) is local.

The input gradient is

\[ \frac{\partial L}{\partial X} = \sum_{i=1}^{P} \frac{\partial L}{\partial Y_i} W_i^T \]

Requires AllReduce to sum contributions from all shards.
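A NumPy gradient check over two column shards makes the asymmetry concrete:

```python
import numpy as np

rng = np.random.default_rng(6)
X = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 8))
W0, W1 = W[:, :4], W[:, 4:]       # column shards

dY = rng.standard_normal((4, 8))  # upstream gradient for Y = XW
dY0, dY1 = dY[:, :4], dY[:, 4:]   # sharded the same way as Y

# dL/dW_i is purely local to GPU i...
dW0 = X.T @ dY0
assert np.allclose(dW0, (X.T @ dY)[:, :4])

# ...while dL/dX needs the cross-shard sum (the AllReduce)
dX = dY0 @ W0.T + dY1 @ W1.T
assert np.allclose(dX, dY @ W.T)
```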
Backward Communication¶
For one Megatron-style layer:
| Forward | Backward |
|---|---|
| Column-parallel MLP: 0 comm | AllReduce (for \(\partial L/\partial X\)) |
| Row-parallel MLP: AllReduce | 0 comm (input grad is local) |
| Column-parallel Attention: 0 comm | AllReduce |
| Row-parallel Attention output: AllReduce | 0 comm |
Total per layer: 4 AllReduce (2 forward + 2 backward)
Scaling Limits¶
Maximum Tensor Parallel Degree¶
The tensor parallel degree \(P\) is limited by:
- Head divisibility: \(P\) must divide number of attention heads
- Hidden dimension divisibility: \(P\) must divide hidden dimension
- Communication overhead: NVLink bandwidth limits
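The divisibility constraints can be encoded in a hypothetical helper (`valid_tp_degrees` is our name for illustration, not a library API):

```python
def valid_tp_degrees(num_heads, hidden_size, gpus_per_node=8):
    """TP degrees that divide both the head count and the hidden
    dimension, capped at one NVLink-connected node."""
    return [p for p in range(1, gpus_per_node + 1)
            if num_heads % p == 0 and hidden_size % p == 0]

# A 32-head, 4096-hidden model on an 8-GPU node
degrees = valid_tp_degrees(32, 4096)  # [1, 2, 4, 8]
```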
Practical limits:
| Node Type | Max TP | Reason |
|---|---|---|
| 8× A100 NVLink | 8 | Full node, high bandwidth |
| 8× H100 NVLink | 8 | Full node, higher bandwidth |
| 2 nodes (16 GPUs) | 16 (discouraged) | Inter-node bandwidth makes AllReduce expensive |
When to Use Tensor Parallelism¶
Use TP when:
- Model doesn't fit in single GPU memory
- Within a single node (fast NVLink)
- Need to reduce per-GPU memory for activations
Don't use TP when:
- Model fits comfortably (DP is simpler)
- Crossing node boundaries (use PP instead)
- Very small batch sizes (latency-dominated)
Exercises¶
- Bias partitioning: In column-parallel linear \(Y = XW + b\), show that partitioning \(b\) along with \(W\) columns gives correct results. In row-parallel, why must bias be added after AllReduce?
Solution
Part 1: Column-parallel bias partitioning
For column-parallel: \(Y = XW + b\) where \(W = [W_1 \mid W_2 \mid \cdots \mid W_P]\)

The output \(Y = [Y_1 \mid Y_2 \mid \cdots \mid Y_P]\) is column-partitioned.

Partitioning bias similarly: \(b = [b_1 \mid b_2 \mid \cdots \mid b_P]\)

On GPU \(i\):

\[ Y_i = XW_i + b_i \]

Why this is correct:

The full computation is:

\[ Y = XW + b = [XW_1 \mid \cdots \mid XW_P] + [b_1 \mid \cdots \mid b_P] \]

Since bias addition is element-wise and the column partitioning aligns:

\[ Y = [XW_1 + b_1 \mid \cdots \mid XW_P + b_P] = [Y_1 \mid \cdots \mid Y_P] \]
Part 2: Row-parallel bias handling
For row-parallel: \(Y = XW + b\) where \(W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_P \end{bmatrix}\)
On GPU \(i\):

\[ Z_i = X_i W_i \]

The full result requires: \(Y = \sum_{i=1}^{P} Z_i + b\)

Why bias must come after AllReduce:

If each GPU added \(b\) before AllReduce:

\[ Z_i' = X_i W_i + b \]

After AllReduce:

\[ \sum_{i=1}^{P} Z_i' = XW + Pb \]

This is wrong — the bias is multiplied by \(P\)!
Correct approach options:
| Approach | Implementation |
|---|---|
| Add after AllReduce | \(Y = \text{AllReduce}(\sum_i Z_i) + b\) |
| Add scaled bias | Each GPU adds \(b/P\), then AllReduce |
| Add on one rank | Only rank 0 adds \(b\), then AllReduce |
The first option is simplest and most common.
- Communication derivation: For a Transformer with \(d = 4096\), \(s = 2048\), \(b = 4\), TP = 8, calculate the exact bytes transferred per layer in forward pass. Use FP16.
Solution
Given:
- Hidden dimension: \(d = 4096\)
- Sequence length: \(s = 2048\)
- Batch size: \(b = 4\)
- Tensor parallel degree: \(P = 8\)
- Data type: FP16 (2 bytes per element)
AllReduce volume formula (ring algorithm), per GPU, for a tensor of \(n\) bytes:

\[ V = 2\,\frac{P-1}{P}\, n \]

Each AllReduce synchronizes the activation tensor:

\[ n = b \cdot s \cdot d \cdot 2 \ \text{bytes} = 4 \cdot 2048 \cdot 4096 \cdot 2 = 67{,}108{,}864 \ \text{bytes} = 64 \ \text{MB} \]

Per AllReduce volume:

\[ V = 2 \cdot \frac{7}{8} \cdot 64 \ \text{MB} = 112 \ \text{MB} \]
Forward pass communication per layer:
| Component | AllReduces | Volume |
|---|---|---|
| Attention output projection | 1 | 112 MB |
| MLP down projection | 1 | 112 MB |
| Total | 2 | 224 MB |
Per-GPU bandwidth calculation:

In ring AllReduce, each GPU sends and receives:

\[ 2\,\frac{P-1}{P} \times 64 \ \text{MB} = 112 \ \text{MB per AllReduce} \]

With 2 AllReduces: \(224\) MB per GPU per layer.

For the full model (e.g., 80 layers):

\[ 224 \ \text{MB} \times 80 \approx 17.9 \ \text{GB per GPU per forward pass} \]
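The 112 MB / 224 MB figures can be reproduced with a few lines (using the ring per-GPU traffic model above):

```python
def ring_allreduce_bytes(b, s, d, P, bytes_per_elem=2):
    """Per-GPU traffic for one ring AllReduce of a [b, s, d] FP16 tensor."""
    tensor_bytes = b * s * d * bytes_per_elem
    return 2 * (P - 1) / P * tensor_bytes

per_allreduce = ring_allreduce_bytes(4, 2048, 4096, 8)  # 112 MB
per_layer = 2 * per_allreduce                           # 224 MB
```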
- Compute-communication ratio: For the same model, compute \(R\) assuming H100 with 989 TFLOP/s (dense BF16) and NVLink at 900 GB/s. Is the layer compute-bound or communication-bound?
Solution
Given (from previous problem):
- \(d = 4096\), \(s = 2048\), \(b = 4\), \(P = 8\)
- H100: \(F = 989 \times 10^{12}\) FLOP/s (dense BF16 Tensor Cores)
- NVLink: \(\beta = 900 \times 10^9\) bytes/s
Compute per layer (forward pass):
For a Transformer layer (counting only the dominant matmul terms used here):

\[ C = 4\,bsd^2 + 2\,bs^2d \]

With \(d = 4096\), \(s = 2048\), \(b = 4\):

- First term: \(4 \times 4 \times 2048 \times 4096^2 = 5.50 \times 10^{11}\) FLOPs
- Second term: \(2 \times 4 \times 2048^2 \times 4096 = 1.37 \times 10^{11}\) FLOPs
Compute per GPU (with TP=8):

\[ \frac{C}{P} = \frac{6.87 \times 10^{11}}{8} = 8.59 \times 10^{10} \ \text{FLOPs} \]
Compute time per GPU:
Communication time:
Volume per layer: 224 MB = \(2.24 \times 10^8\) bytes
Using the bandwidth-dominated model (2 AllReduces):

\[ T_{\text{comm}} = \frac{2.24 \times 10^8}{9 \times 10^{11}} \approx 249\ \mu s \]

Compute-communication ratio:

\[ R = \frac{T_{\text{compute}}}{T_{\text{comm}}} = \frac{42.9\ \mu s}{249\ \mu s} \approx 0.17 \]
Conclusion:
The layer spends roughly \(6\times\) more time communicating than computing!
| Metric | Value |
|---|---|
| Compute time | 42.9 μs |
| Communication time | 249 μs |
| Ratio \(R\) | 0.17 |
| Regime | Communication-bound |
Implications:
- TP=8 is too aggressive for this configuration
- Consider TP=4 (doubles \(R\) to ~0.34) or TP=2 (\(R\) ~0.68)
- Or use sequence parallelism to overlap communication
- LayerNorm parallelism: Derive the formulas for computing global mean and variance from partial statistics on sharded tensors. What are the AllReduce volumes needed?
Solution
LayerNorm on sharded hidden dimension:
Given activation \(X\) with hidden dimension \(d\) sharded across \(P\) GPUs.
GPU \(i\) holds \(X_i\) of size \(d/P\).
Global mean:

Each GPU computes local sum:

\[ s_i = \sum_{j \in \text{shard } i} x_j \]

AllReduce to get global sum:

\[ S = \sum_{i=1}^{P} s_i \]

Global mean:

\[ \mu = \frac{S}{d} \]

Global variance:

Each GPU computes local sum of squares:

\[ q_i = \sum_{j \in \text{shard } i} x_j^2 \]

AllReduce to get global sum of squares:

\[ Q = \sum_{i=1}^{P} q_i \]

Global variance:

\[ \sigma^2 = \frac{Q}{d} - \mu^2 \]
Complete algorithm:
```python
def parallel_layernorm(x_shard, tp_group, d, gamma_shard, beta_shard):
    # x_shard: [..., d/P] slice of the hidden dimension
    # Local statistics
    local_sum = x_shard.sum(dim=-1, keepdim=True)
    local_sq_sum = (x_shard ** 2).sum(dim=-1, keepdim=True)
    # AllReduce both statistics in a single collective
    stats = torch.cat([local_sum, local_sq_sum], dim=-1)
    dist.all_reduce(stats, group=tp_group)
    global_sum, global_sq_sum = stats.split(1, dim=-1)
    # Global mean and variance
    mu = global_sum / d
    var = global_sq_sum / d - mu ** 2
    std = torch.sqrt(var + 1e-6)
    # Normalize locally
    x_norm = (x_shard - mu) / std
    # Apply this GPU's gamma/beta shards
    return x_norm * gamma_shard + beta_shard
```
AllReduce volumes:
Per LayerNorm, we AllReduce 2 scalars per token position:
| Statistic | Shape | Size (FP32) |
|---|---|---|
| Sum | \([b, s, 1]\) | \(4bs\) bytes |
| Sum of squares | \([b, s, 1]\) | \(4bs\) bytes |
| Total | \([b, s, 2]\) | \(8bs\) bytes |
For \(b=4\), \(s=2048\):

\[ 8 \times 4 \times 2048 = 65{,}536 \ \text{bytes} = 64 \ \text{KB} \]
Comparison to activation AllReduce:
| AllReduce Type | Volume |
|---|---|
| Activation (row-parallel) | \(4bsd \approx 128\) MB |
| LayerNorm statistics | \(8bs = 64\) KB |
| Ratio | 2000× |
LayerNorm AllReduce is negligible compared to activation AllReduce. However, Megatron-LM still avoids it by applying LayerNorm to replicated activations before tensor parallelism begins.
- GeLU placement: Why must GeLU come between column-parallel and row-parallel layers, not before or after both? What would go wrong if GeLU came after row-parallel?
Solution
Correct placement: Column-parallel → GeLU → Row-parallel
Why GeLU between column-parallel and row-parallel works:
After column-parallel:

- Output \(Y\) is column-sharded: \(Y = [Y_0 | Y_1 | \cdots | Y_{P-1}]\)
- Each \(Y_i\) contains complete elements (not partial sums)
GeLU is element-wise:

\[ \text{GeLU}(Y)_{ij} = \text{GeLU}(Y_{ij}) \]

Each GPU applies GeLU to its local shard independently. No communication needed!
What if GeLU came BEFORE column-parallel?
This actually works fine—GeLU on replicated input is just local computation. But it doesn't match the standard MLP structure: \(\text{GeLU}(XW_1)W_2 \neq \text{GeLU}(X)W_1W_2\).
What if GeLU came AFTER row-parallel?
This changes the mathematical function:

\[ \text{GeLU}(XW_1 W_2) \neq \text{GeLU}(XW_1)\,W_2 \]

These are not equivalent! The non-linearity must be between the two linear transforms.
What goes wrong mathematically:
The MLP's expressive power comes from the non-linearity between layers. Without it:

\[ XW_1 W_2 = X(W_1 W_2) \]

This collapses to a single linear layer. The GeLU must break this composition.
Summary:
| GeLU Placement | Valid? | Issue |
|---|---|---|
| Between ColPar and RowPar | ✓ | Correct |
| Before ColPar | ✗ | Wrong function |
| After RowPar | ✗ | Wrong function |
| Both before and after | ✗ | Wrong function |
Key insight: GeLU placement is constrained by the mathematical structure of the MLP, not by tensor parallelism. TP just happens to work perfectly with the correct placement because column-parallel output is element-complete.
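A NumPy sketch of the "wrong function" failure mode, using the tanh approximation of GeLU for convenience:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(8)
X = rng.standard_normal((3, 4))
W1 = rng.standard_normal((4, 6))
W2 = rng.standard_normal((6, 4))

correct = gelu(X @ W1) @ W2  # non-linearity between the two matmuls
wrong = gelu(X @ W1 @ W2)    # GeLU applied after the row-parallel output
assert not np.allclose(correct, wrong)
```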
- Attention head constraints: With 32 attention heads and TP = 6, what goes wrong? How would you handle this case?
Solution
The problem:
With \(h = 32\) heads and \(P = 6\) GPUs:

\[ \frac{h}{P} = \frac{32}{6} = 5.33\ldots \]

Heads don't divide evenly! Each GPU can't have the same number of heads.
What goes wrong in practice:
- Unequal memory: Some GPUs would have 5 heads, others 6
- Unequal compute: Load imbalance across GPUs
- Synchronization: AllReduce with different tensor sizes is problematic
- Implementation complexity: Padding or special handling needed
Solutions:
Option 1: Choose compatible TP degree
Use \(P \in \{1, 2, 4, 8, 16, 32\}\) (divisors of 32):
| TP Degree | Heads per GPU | Compatible? |
|---|---|---|
| 2 | 16 | ✓ |
| 4 | 8 | ✓ |
| 6 | 5.33 | ✗ |
| 8 | 4 | ✓ |
Recommendation: Use TP=4 or TP=8 instead of TP=6.
Option 2: Pad attention heads
Add dummy heads to make the count divisible:

\[ 32 \to 36 = 6 \times 6 \]

Add 4 dummy heads (set to zero or mask out).

Downsides:

- Wasted compute (12.5% overhead)
- Memory overhead for dummy heads
- Complexity in implementation
Option 3: Grouped Query Attention (GQA)
Modern architectures like Llama 2 use GQA with fewer KV heads:

- Query heads: 32
- KV heads: 8 (shared across groups)
With \(h_{kv} = 8\), TP=8 works: 1 KV head per GPU.
Option 4: Hybrid parallelism
Use TP within partial groups:

- GPUs 0-3: TP=4 on heads 0-15
- GPUs 4-5: TP=2 on heads 16-31 (16 heads, 8 each)
This is complex and rarely used.
Best practice: pick the TP degree from the divisors of the head count. Common configurations:
| Model | Heads | Valid TP |
|---|---|---|
| GPT-3 | 96 | 1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48 |
| Llama 70B | 64 | 1, 2, 4, 8, 16, 32, 64 |
| Llama 7B | 32 | 1, 2, 4, 8, 16, 32 |
- Backward analysis: For a column-parallel linear layer \(Y = XW\), derive the gradient formulas and show that \(\nabla_X L\) requires AllReduce while \(\nabla_W L\) does not.
Solution
Setup:
Column-parallel linear: \(Y = XW\) where:

- \(X \in \mathbb{R}^{m \times d_{in}}\) (replicated across all GPUs)
- \(W = [W_0 | W_1 | \cdots | W_{P-1}]\) with \(W_i \in \mathbb{R}^{d_{in} \times (d_{out}/P)}\)
- \(Y = [Y_0 | Y_1 | \cdots | Y_{P-1}]\) with \(Y_i = XW_i\)
Given: \(\frac{\partial L}{\partial Y}\) (sharded same as \(Y\))
GPU \(i\) has: \(\frac{\partial L}{\partial Y_i} \in \mathbb{R}^{m \times (d_{out}/P)}\)
Gradient with respect to \(W_i\) (local weight shard):
Using the chain rule with \(Y_i = XW_i\):

\[ \frac{\partial L}{\partial W_i} = X^T \frac{\partial L}{\partial Y_i} \]

Analysis:

- \(X\) is replicated (all GPUs have it)
- \(\frac{\partial L}{\partial Y_i}\) is local to GPU \(i\)
- Result: \(\frac{\partial L}{\partial W_i} \in \mathbb{R}^{d_{in} \times (d_{out}/P)}\)
Each GPU computes its local gradient independently.
Gradient with respect to \(X\):
Using the chain rule with \(Y = XW\):

\[ \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^T \]

Expanding with column partitioning:

\[ \frac{\partial L}{\partial X} = \sum_{i=0}^{P-1} \frac{\partial L}{\partial Y_i} W_i^T \]
Analysis:
GPU \(i\) can compute: \(\frac{\partial L}{\partial Y_i} W_i^T \in \mathbb{R}^{m \times d_{in}}\)
But the full gradient is the sum over all GPUs:

\[ \frac{\partial L}{\partial X} = \sum_{i=0}^{P-1} \frac{\partial L}{\partial Y_i} W_i^T \quad \Rightarrow \quad \text{AllReduce (sum)} \]
Summary:
| Gradient | Formula | Communication |
|---|---|---|
| \(\nabla_{W_i} L\) | \(X^T \frac{\partial L}{\partial Y_i}\) | None (local) |
| \(\nabla_X L\) | \(\sum_i \frac{\partial L}{\partial Y_i} W_i^T\) | AllReduce (sum) |
Communication pattern:
```
Forward:  X (replicated) → [Y₀|Y₁|...|Yₚ₋₁] (sharded) — no comm
Backward: [∂L/∂Y₀|...|∂L/∂Yₚ₋₁] (sharded)
              ↓
          GPU i: ∂L/∂Yᵢ · Wᵢᵀ  (local partial)
              ↓
          AllReduce(sum) → ∂L/∂X (replicated)
```
This is why Megatron pairs column-parallel (no forward comm) with row-parallel (no backward comm for input grad), achieving balanced communication.
Knobs and Trade-offs¶
| Knob | Primary Effect | Cost |
|---|---|---|
| Tensor-parallel degree (T) | Splits large matmuls | More activation AllReduce/AllGather |
| Intra-node placement | Lower latency and higher BW | Constrains topology and scaling |
| Shard layout (row/col) | Changes comm phase | Requires matching layer ordering |
| Fusion (e.g., QKV) | Fewer collectives | More complex kernels and scheduling |
Key Takeaways¶
- Linearity enables parallelism: \(f(X_1 + X_2) = f(X_1) + f(X_2)\) allows independent computation.
- Column-parallel is communication-free forward: Split output dimension, no AllReduce needed.
- Row-parallel requires AllReduce: Split input dimension, must sum partial results.
- Megatron pattern chains them: Column-parallel → non-linearity → row-parallel → AllReduce.
- 2 AllReduce per layer: One for attention output, one for MLP output.
- Non-linear ops need care: GeLU works on sharded tensors; LayerNorm doesn't.
- Tensor parallelism is communication-intensive: Best within NVLink-connected nodes.
- Backward doubles communication: 4 AllReduce per layer total (forward + backward).