Section 5.5: Multi-Head Attention — Multiple Perspectives¶
Reading time: 18 minutes | Difficulty: ★★★☆☆
A single attention mechanism can only focus on one type of relationship at a time. Multi-head attention runs multiple attention operations in parallel, allowing the model to jointly attend to information from different representation subspaces.
The Limitation of Single-Head Attention¶
Consider processing "The cat sat on the mat because it was tired."
Different types of information are relevant:
- Syntactic: "sat" should attend to "cat" (subject-verb)
- Positional: "tired" should attend to nearby words
- Coreference: "it" should attend to "cat" (reference)
A single attention head must compress all these relationships into one set of weights. It can't simultaneously:
- Pay maximum attention to the subject AND
- Pay attention to nearby words AND
- Resolve references
The Multi-Head Solution¶
Instead of one attention mechanism with d-dimensional keys/queries, use h parallel attention heads, each with d/h dimensions:
\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\)
where \(\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)
Each head learns to focus on different aspects of the input.
How It Works¶
Step 1: Project to Multiple Heads¶
For each head i:
- Q_i = \(QW_i^Q\) ∈ \(ℝ^{n × d_k}\)
- K_i = \(KW_i^K\) ∈ \(ℝ^{n × d_k}\)
- V_i = \(VW_i^V\) ∈ \(ℝ^{n × d_v}\)
Where d_k = d_v = d/h typically.
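To make the shapes concrete, here is a minimal numpy sketch of one head's projection (the values n = 5, d = 8, h = 2 are illustrative, not from the text):

import numpy as np

n, d, h = 5, 8, 2                # sequence length, model dim, heads
d_k = d // h                     # per-head dimension: 4
X = np.random.randn(n, d)        # in self-attention, Q = K = V = X
W_Q = np.random.randn(d, d_k)    # one head's query projection
Q_i = X @ W_Q
print(Q_i.shape)                 # (5, 4): n × d_k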
Step 2: Parallel Attention¶
Each head computes scaled dot-product attention independently:
\(\text{head}_i = \text{softmax}\!\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i\)
Step 3: Concatenate and Project¶
Combine all heads and project back:
\(\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\)
Where \(W^O\) ∈ \(ℝ^{hd_v × d}\) projects back to model dimension.
Visual Representation¶
Input X [n × d]
│
├──────────────────────────────────────────────┐
│ │
▼ ▼
Head 1 ... Head h
│ │
├── Q₁ = XW¹_Q Q_h = XW^h_Q ─┤
├── K₁ = XW¹_K K_h = XW^h_K ─┤
├── V₁ = XW¹_V V_h = XW^h_V ─┤
│ │
▼ ▼
Attention(Q₁,K₁,V₁) ... Attention(Q_h,K_h,V_h)
│ │
▼ ▼
[n × d_v] ... [n × d_v]
│ │
└──────────────────┬───────────────────────────┘
│ Concat
▼
[n × hd_v]
│
▼ W^O
[n × d]
│
▼
Output
Worked Example¶
Let's trace through with h=2 heads, d=4, d_k=d_v=2.
Input¶
X = [[1, 0, 1, 0],
     [0, 1, 0, 1],
     [1, 1, 0, 0]]  # [3 × 4]: 3 tokens, d = 4
Head 1 Projections¶
W1_Q = [[1, 0], [0, 1], [0, 0], [0, 0]] # [4 × 2]
W1_K = [[0, 1], [1, 0], [0, 0], [0, 0]]
W1_V = [[1, 0], [0, 0], [1, 0], [0, 0]]
Q1 = X @ W1_Q = [[1, 0], [0, 1], [1, 1]] # [3 × 2]
K1 = X @ W1_K = [[0, 1], [1, 0], [1, 1]]
V1 = X @ W1_V = [[2, 0], [0, 0], [1, 0]]
Head 2 Projections¶
W2_Q = [[0, 0], [0, 0], [1, 0], [0, 1]] # Different subspace!
W2_K = [[0, 0], [0, 0], [0, 1], [1, 0]]
W2_V = [[0, 1], [0, 1], [0, 0], [0, 0]]
Q2 = X @ W2_Q = [[1, 0], [0, 1], [0, 0]] # [3 × 2]
K2 = X @ W2_K = [[0, 1], [1, 0], [0, 0]]
V2 = X @ W2_V = [[0, 1], [0, 1], [0, 2]]
Compute Attention for Each Head¶
Head 1: Focuses on first two dimensions of input
Attention scores (Q1 @ K1.T / sqrt(2)):
[[0.00, 0.71, 0.71],
 [0.71, 0.00, 0.71],
 [0.71, 0.71, 1.41]]
After softmax:
[[0.20, 0.40, 0.40],
 [0.40, 0.20, 0.40],
 [0.25, 0.25, 0.50]]
Output1 = attention @ V1 = [[0.80, 0.00],
                            [1.20, 0.00],
                            [1.00, 0.00]]
Head 2: Focuses on last two dimensions
Attention scores (Q2 @ K2.T / sqrt(2)):
[[0.00, 0.71, 0.00],
 [0.71, 0.00, 0.00],
 [0.00, 0.00, 0.00]]
After softmax:
[[0.25, 0.50, 0.25],
 [0.50, 0.25, 0.25],
 [0.33, 0.33, 0.33]]
Output2 = attention @ V2 = [[0.00, 1.25],
                            [0.00, 1.25],
                            [0.00, 1.33]]
Concatenate and Project¶
Stacking the two head outputs side by side gives
Concat = [[0.80, 0.00, 0.00, 1.25],
          [1.20, 0.00, 0.00, 1.25],
          [1.00, 0.00, 0.00, 1.33]]  # [3 × 4]
which \(W^O\) ∈ \(ℝ^{4 × 4}\) projects back to the model dimension. Notice that head 1 writes only to the first output dimension and head 2 only to the last: the two heads carry complementary information.
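The numbers above are easy to verify in a few lines of numpy. A minimal check (the text does not specify \(W^O\), so this sketch stops at the concatenated heads):

import numpy as np

X = np.array([[1., 0., 1., 0.],
              [0., 1., 0., 1.],
              [1., 1., 0., 0.]])

def head(X, W_Q, W_K, W_V):
    # One scaled dot-product attention head
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

W1 = [np.array(m, dtype=float) for m in (
    [[1, 0], [0, 1], [0, 0], [0, 0]],    # W1_Q
    [[0, 1], [1, 0], [0, 0], [0, 0]],    # W1_K
    [[1, 0], [0, 0], [1, 0], [0, 0]])]   # W1_V
W2 = [np.array(m, dtype=float) for m in (
    [[0, 0], [0, 0], [1, 0], [0, 1]],    # W2_Q
    [[0, 0], [0, 0], [0, 1], [1, 0]],    # W2_K
    [[0, 1], [0, 1], [0, 0], [0, 0]])]   # W2_V

concat = np.concatenate([head(X, *W1), head(X, *W2)], axis=-1)
print(np.round(concat, 2))  # [3 × 4]: head 1 in cols 0-1, head 2 in cols 2-3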
What Different Heads Learn¶
Research on trained transformers reveals specialized heads:
Syntactic Heads¶
"The cat that I saw yesterday sat"
Head focusing on subject-verb:
sat → cat: 0.65 (main subject)
sat → I: 0.10 (not the subject of "sat")
sat → saw: 0.05
Positional Heads¶
Some heads attend primarily by position rather than content, for example always to the previous token or to a small local window around the current position.
Copy/Induction Heads¶
Some heads detect repetition: given a sequence "... A B ... A", an induction head attends from the second "A" back to "B", letting the model copy patterns it has already seen.
Rare Word Heads¶
Some heads specialize in attending to rare/important tokens
like proper nouns, numbers, or technical terms.
Connection to Modern LLMs
In GPT-2 and similar models, researchers found:
- Induction heads (copy patterns): Emerge in layer 2+ and are crucial for in-context learning
- Previous token heads: Simple but important for local coherence
- Backup heads: Redundant heads that provide robustness
The model learns to allocate heads to different linguistic tasks automatically!
Parameter Analysis¶
For multi-head attention with:
- Model dimension: d
- Number of heads: h
- Head dimension: d_k = d_v = d/h
Per-head parameters:
- \(W_i^Q\): d × d_k = d × d/h = d²/h
- \(W_i^K\): d × d_k = d²/h
- \(W_i^V\): d × d_v = d²/h
Total for all heads: h × 3 × d²/h = 3d²
Output projection: \(W^O\): hd_v × d = d × d = d²
Grand total: 3d² + d² = 4d²
This is the same as having separate Q, K, V, O projections in single-head attention with dimension d. Multi-head attention is essentially a particular factorization of the same parameter budget.
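A quick sanity check of this accounting for Transformer-base sizes (d = 512, h = 8):

d, h = 512, 8
d_k = d_v = d // h                  # 64
per_head = 3 * d * d_k              # W_Q, W_K, W_V for one head
all_heads = h * per_head            # 786,432 = 3 * d^2
output_proj = (h * d_v) * d         # 262,144 = d^2
print(all_heads + output_proj)      # 1,048,576 = 4 * d^2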
Why Not Just Use More Parameters?¶
Question: Why h heads of dimension d/h instead of 1 head of dimension d?
Computational Efficiency¶
Same parameter count, but:
- Each head operates in lower dimension
- All heads compute in parallel
- Similar total compute
Representational Power¶
Different heads can learn:
- Orthogonal attention patterns
- Specialized roles
- Complementary information
A single head would have to encode everything in one pattern.
Empirical Evidence¶
Ablation on WMT translation:
| Heads | BLEU Score |
|---|---|
| 1 | 25.8 |
| 2 | 27.1 |
| 4 | 27.5 |
| 8 | 28.0 (original Transformer) |
| 16 | 28.0 (no improvement) |
More heads help up to a point, then saturate.
Implementation¶
import numpy as np

def multi_head_attention(X, W_Qs, W_Ks, W_Vs, W_O, mask=None):
    """
    Multi-head attention mechanism.

    Args:
        X: Input [n, d]
        W_Qs: List of query projections [h × (d, d_k)]
        W_Ks: List of key projections [h × (d, d_k)]
        W_Vs: List of value projections [h × (d, d_v)]
        W_O: Output projection [h*d_v, d]
        mask: Optional boolean mask [n, n]; False positions are blocked

    Returns:
        Output [n, d], list of attention weights [h × (n, n)]
    """
    h = len(W_Qs)
    heads = []
    attention_weights = []

    for i in range(h):
        # Project to this head's subspace
        Q_i = X @ W_Qs[i]
        K_i = X @ W_Ks[i]
        V_i = X @ W_Vs[i]

        d_k = Q_i.shape[-1]

        # Scaled dot-product attention
        scores = Q_i @ K_i.T / np.sqrt(d_k)

        # Block masked positions before the softmax
        if mask is not None:
            scores = np.where(mask, scores, -1e9)

        # Numerically stable softmax
        scores_max = scores.max(axis=-1, keepdims=True)
        exp_scores = np.exp(scores - scores_max)
        weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
        attention_weights.append(weights)

        # Weighted sum of values
        head_output = weights @ V_i
        heads.append(head_output)

    # Concatenate all heads
    concat = np.concatenate(heads, axis=-1)  # [n, h*d_v]

    # Final projection
    output = concat @ W_O  # [n, d]

    return output, attention_weights

class MultiHeadAttention:
    """Multi-head attention layer with learned parameters."""

    def __init__(self, d_model, n_heads):
        """
        Initialize multi-head attention.

        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
        """
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads

        # Initialize projection matrices (Xavier-style scaling)
        scale = np.sqrt(2.0 / (d_model + self.d_k))
        self.W_Qs = [np.random.randn(d_model, self.d_k) * scale
                     for _ in range(n_heads)]
        self.W_Ks = [np.random.randn(d_model, self.d_k) * scale
                     for _ in range(n_heads)]
        self.W_Vs = [np.random.randn(d_model, self.d_v) * scale
                     for _ in range(n_heads)]

        # Output projection
        self.W_O = np.random.randn(n_heads * self.d_v, d_model) * scale

    def forward(self, X, mask=None):
        """
        Forward pass.

        Args:
            X: Input [n, d_model]
            mask: Optional attention mask [n, n]

        Returns:
            Output [n, d_model], list of n_heads attention matrices [n, n]
        """
        return multi_head_attention(X, self.W_Qs, self.W_Ks, self.W_Vs,
                                    self.W_O, mask=mask)

    def parameters(self):
        """Return all parameters as a list."""
        return self.W_Qs + self.W_Ks + self.W_Vs + [self.W_O]
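A quick usage sketch for the class above (the projections are randomly initialized, so only shapes are meaningful; the causal mask shown is one common choice):

mha = MultiHeadAttention(d_model=8, n_heads=2)
X = np.random.randn(5, 8)            # 5 tokens, d_model = 8
out, attn = mha.forward(X)
print(out.shape)                     # (5, 8)
print(len(attn), attn[0].shape)      # 2 heads, each (5, 5)

# Causal mask: position i may attend only to positions <= i
causal = np.tril(np.ones((5, 5), dtype=bool))
out_causal, _ = mha.forward(X, mask=causal)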
Efficient Implementation: Batched Projections¶
In practice, we batch all head projections together:
class EfficientMultiHeadAttention:
    """Efficient multi-head attention using batched operations."""

    def __init__(self, d_model, n_heads):
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Single large projections instead of per-head
        scale = np.sqrt(2.0 / (2 * d_model))
        self.W_QKV = np.random.randn(d_model, 3 * d_model) * scale
        self.W_O = np.random.randn(d_model, d_model) * scale

    def forward(self, X):
        """Forward pass with efficient batched computation."""
        n = X.shape[0]

        # Single projection for all Q, K, V
        QKV = X @ self.W_QKV  # [n, 3*d_model]

        # Split into Q, K, V
        Q, K, V = np.split(QKV, 3, axis=-1)  # Each [n, d_model]

        # Reshape into heads: [n, d_model] -> [n, h, d_k] -> [h, n, d_k]
        Q = Q.reshape(n, self.n_heads, self.d_k).transpose(1, 0, 2)
        K = K.reshape(n, self.n_heads, self.d_k).transpose(1, 0, 2)
        V = V.reshape(n, self.n_heads, self.d_k).transpose(1, 0, 2)

        # Batched attention: [h, n, d_k] @ [h, d_k, n] -> [h, n, n]
        scores = np.einsum('hnd,hmd->hnm', Q, K) / np.sqrt(self.d_k)

        # Numerically stable softmax per head
        scores_max = scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores - scores_max)
        weights = weights / weights.sum(axis=-1, keepdims=True)

        # Apply attention: [h, n, n] @ [h, n, d_k] -> [h, n, d_k]
        heads = np.einsum('hnm,hmd->hnd', weights, V)

        # Reshape back: [h, n, d_k] -> [n, h, d_k] -> [n, d_model]
        concat = heads.transpose(1, 0, 2).reshape(n, self.d_model)

        # Output projection
        output = concat @ self.W_O

        return output, weights
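The batched version should satisfy the same invariants as the naive one. A small check (random weights, so only shapes and the softmax property are verified):

emha = EfficientMultiHeadAttention(d_model=8, n_heads=2)
X = np.random.randn(5, 8)
out, w = emha.forward(X)
print(out.shape)                         # (5, 8)
print(w.shape)                           # (2, 5, 5): one map per head
print(np.allclose(w.sum(axis=-1), 1.0))  # True: each row is a distribution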
Visualizing Multi-Head Attention¶
Sentence: "The cat sat on the mat"
Head 1 (syntactic): Head 2 (positional):
T c s o t m T c s o t m
T [ ░ █ ░ ░ ░ ░ ] T [ █ █ ░ ░ ░ ░ ]
c [ █ ░ ░ ░ ░ ░ ] c [ █ █ █ ░ ░ ░ ]
s [ ░ █ ░ ░ ░ ░ ] s [ ░ █ █ █ ░ ░ ]
o [ ░ ░ █ ░ ░ ░ ] o [ ░ ░ █ █ █ ░ ]
t [ ░ ░ ░ █ ░ ░ ] t [ ░ ░ ░ █ █ █ ]
m [ ░ ░ ░ ░ █ ░ ] m [ ░ ░ ░ ░ █ █ ]
Head 1 learns subject-verb Head 2 learns local context
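Heatmaps like these can be plotted directly from the weights returned by the implementations above. A sketch using matplotlib (the weights here are random stand-ins, not trained attention patterns):

import numpy as np
import matplotlib.pyplot as plt

tokens = ["The", "cat", "sat", "on", "the", "mat"]
# Stand-in for a real [h, n, n] weights array from forward()
weights = np.random.dirichlet(np.ones(6), size=(2, 6))

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for i, ax in enumerate(axes):
    ax.imshow(weights[i], cmap="Blues")   # darker = higher attention
    ax.set_title(f"Head {i + 1}")
    ax.set_xticks(range(6)); ax.set_xticklabels(tokens, rotation=45)
    ax.set_yticks(range(6)); ax.set_yticklabels(tokens)
plt.tight_layout()
plt.show()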
When to Use How Many Heads¶
| Model Size | Typical Heads | Head Dimension |
|---|---|---|
| Small (d=256) | 4 | 64 |
| Medium (d=512) | 8 | 64 |
| Large (d=1024) | 16 | 64 |
| XL (d=2048) | 32 | 64 |
The head dimension is often kept constant (64) while scaling the number of heads with model size.
Exercises¶
- Implement multi-head: Write both naive and efficient versions.
- Visualize heads: Train a small model and plot attention patterns for different heads.
- Head ablation: What happens if you zero out different heads? Which are important?
- Head pruning: After training, can you remove heads with minimal performance loss?
- Specialized heads: Can you design heads that must attend to specific patterns?
Summary¶
| Concept | Definition | Purpose |
|---|---|---|
| Multi-head attention | h parallel attention mechanisms | Multiple representation subspaces |
| Head dimension | d_k = d/h | Reduced per-head computation |
| Concatenation | [head_1; ...; head_h] | Combine all perspectives |
| Output projection | \(W^O\) | Mix head outputs |
Key takeaway: Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. Each head can specialize in different types of relationships (syntactic, positional, semantic), enabling richer modeling of language structure than a single attention mechanism could achieve.
→ Next: Section 5.6: Positional Encoding