Architecture-Aware Efficiency
The most impactful efficiency gains often come not from systems tricks but from architectural choices that reduce communication, memory, or compute at their source. This chapter covers four architectural innovations that change the distributed training landscape — understanding them here prepares us for the case studies in Part VIII.
The Question: Standard Multi-Head Attention stores separate K, V projections per head. With 128 heads and 128K sequence length, KV cache alone can exceed 100 GB. Can we redesign the attention mechanism itself to reduce this by 10–50×, without losing quality?
Chapter Map
Prerequisites: Chapter 15 (tensor parallelism), Chapter 17 (sequence parallelism), Chapter 16 (pipeline schedules)
Key insight: Architectural choices (attention pattern, KV sharing, pipeline scheduling) are the highest-leverage efficiency knobs because they reduce the fundamental work, not just how it's distributed.
Grouped-Query Attention (GQA)
The KV Cache Problem
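In standard Multi-Head Attention (MHA), each of \(A\) attention heads has its own Key and Value projections:
\[\text{KV cache per token per layer} = 2 \times A \times d_h \times s = 2Hs\]
where \(d_h = H/A\) is the per-head dimension and \(s\) is bytes per element. For a 70B model with \(A=64\), \(H=8192\), BF16:
\[2 \times 8192 \times 2 = 32{,}768 \text{ bytes} = 32 \text{ KB per token per layer}\]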
For a 128K context: \(32 \text{ KB} \times 128{,}000 \approx 4.2 \text{ GB per layer}\). With 80 layers: about 336 GB, far more than the memory of a single GPU.
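As a quick check of the arithmetic above, the following minimal sketch evaluates the formula \(2 \times A \times d_h \times s\) for the 70B example. The function name `kv_cache_bytes` and its parameters are illustrative and not taken from any library; it reports exact decimal byte counts, so rounded figures quoted in prose may differ slightly.

```python
def kv_cache_bytes(num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2,
                   seq_len: int = 1, num_layers: int = 1) -> int:
    """Bytes needed to cache K and V for `seq_len` tokens across `num_layers` layers."""
    # The leading 2 accounts for storing both Keys and Values.
    return 2 * num_kv_heads * head_dim * bytes_per_elem * seq_len * num_layers


# Numbers from the text: A = 64 heads, H = 8192, BF16 (2 bytes per element).
A, H = 64, 8192
d_h = H // A  # per-head dimension (128)

per_token = kv_cache_bytes(A, d_h)
per_layer = kv_cache_bytes(A, d_h, seq_len=128_000)
total = kv_cache_bytes(A, d_h, seq_len=128_000, num_layers=80)

print(f"per token, per layer:    {per_token} bytes")
print(f"128K context, per layer: {per_layer / 1e9:.2f} GB")
print(f"128K context, 80 layers: {total / 1e9:.1f} GB")
```

Passing a smaller `num_kv_heads` to the same function gives the GQA and MQA cache sizes discussed next.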
GQA: Sharing KV Heads
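Grouped-Query Attention (Ainslie et al., 2023) groups \(A\) query heads into \(g\) groups, each sharing a single KV head:
\[\text{KV cache per token per layer} = 2 \times g \times d_h \times s\]
The reduction factor vs MHA is \(A/g\):
\[\frac{2 A d_h s}{2 g d_h s} = \frac{A}{g}\]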
| Variant | KV Heads (\(g\)) | Reduction vs MHA | Quality Impact |
|---|---|---|---|
| MHA | \(g = A\) (e.g., 64) | 1× (baseline) | — |
| GQA-8 | \(g = 8\) | 8× | Minimal |
| GQA-1 (MQA) | \(g = 1\) | 64× | Slight degradation |
LLaMA 2 70B and LLaMA 3 use GQA-8 (\(g = 8\)). This reduces KV cache by 8× with negligible quality loss.
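To make the head sharing concrete, here is a small sketch of how query heads could map onto shared KV heads. It assumes contiguous grouping (query heads 0..7 share KV head 0, and so on), which is one common convention but not the only possible one; `kv_head_for_query` and `kv_reduction_vs_mha` are hypothetical helper names.

```python
def kv_head_for_query(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head shared by query head `q_head` (contiguous grouping assumed)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    if not 0 <= q_head < num_q_heads:
        raise ValueError("q_head out of range")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


def kv_reduction_vs_mha(num_q_heads: int, num_kv_heads: int) -> float:
    """KV cache reduction factor A / g relative to MHA."""
    return num_q_heads / num_kv_heads


A = 64
for name, g in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    mapping = [kv_head_for_query(q, A, g) for q in range(A)]
    print(f"{name}: {len(set(mapping))} distinct KV heads, "
          f"{kv_reduction_vs_mha(A, g):.0f}x smaller KV cache than MHA")
```

With \(g = A\) the mapping is the identity (plain MHA), and with \(g = 1\) every query head reads the same KV head (MQA), so the three rows of the table are the same mechanism at different group sizes.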
Distributed Training Implications
| Aspect | MHA | GQA-8 |
|---|---|---|
| KV cache memory | \(2 A d_h s = 2Hs\) per token | \(2 g d_h s = 2gHs/A\) per token |
| TP communication (KV) | AllGather \(A\) KV heads | AllGather \(g\) KV heads (8× less) |
| Sequence parallelism memory | \(O(S \cdot A \cdot d_h)\) | \(O(S \cdot g \cdot d_h)\) |
| Inference batch capacity | Limited by KV cache | 8× more sequences per GPU |
Practice
GQA is now the default in most major open-weight LLMs. If designing a new model, start with GQA-8 unless you have a specific reason for full MHA.
Sliding Window Attention (SWA)
From Global to Local Attention
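Standard attention computes all-pairs interactions: \(O(S^2)\) per layer. Sliding Window Attention restricts each token to attend only to the \(w\) nearest tokens (for a causal model, the \(w\) most recent positions, including itself):
\[\text{Attend}(i) = \{\, j : i - w < j \le i \,\}\]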
Memory: \(O(S \times w)\) instead of \(O(S^2)\)
Effective context: With \(L\) layers, information can propagate up to \(L \times w\) tokens, so a 32-layer model with \(w=4096\) has a theoretical effective context of about 131K tokens.
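The window restriction can be sketched as a boolean mask. This is a minimal illustration assuming a causal window that includes the current token; the helper names `sliding_window_mask` and `effective_context` are made up for this example, and a production kernel would not materialize the full \(S \times S\) mask.

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Assumes causal attention restricted to the `window` most recent positions,
    counting position i itself.
    """
    return [[(j <= i) and (i - j < window) for j in range(seq_len)]
            for i in range(seq_len)]


def effective_context(num_layers: int, window: int) -> int:
    """Upper bound on how far information can propagate through stacked layers."""
    return num_layers * window


mask = sliding_window_mask(seq_len=6, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))

print("allowed pairs:", sum(sum(row) for row in mask))
print("effective context (L=32, w=4096):", effective_context(32, 4096))
```

Each row has at most `window` allowed entries, which is where the \(O(S \times w)\) cost comes from.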
Combined with GQA
SWA and GQA compose multiplicatively:
\[\text{KV cache per layer} = 2 \times g \times d_h \times w \times s\]
This is independent of sequence length \(S\) — the cache is bounded by the window size.
| Configuration | KV Cache per Layer (BF16, \(H=4096\), \(S=32K\)) |
|---|---|
| MHA, global attention | \(2 \times 32 \times 128 \times 32K \times 2 = 524\) MB |
| GQA-8, global | \(2 \times 8 \times 128 \times 32K \times 2 = 131\) MB |
| GQA-8, SWA (\(w=4096\)) | \(2 \times 8 \times 128 \times 4096 \times 2 = 16.8\) MB |
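Combining both techniques gives a 31× reduction relative to the MHA baseline: 4× from GQA (32 → 8 KV heads) and about 7.8× from SWA (32,000 → 4,096 cached tokens).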
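The table entries can be reproduced with a few lines of arithmetic. The sketch below assumes \(S = 32{,}000\) tokens (the reading of "32K" that yields 524 MB) and \(d_h = 128\) with 32 query heads; `kv_cache_per_layer` is an illustrative helper, not a library function.

```python
def kv_cache_per_layer(num_kv_heads: int, head_dim: int, cached_tokens: int,
                       bytes_per_elem: int = 2) -> int:
    """Bytes of K and V cached in one layer for `cached_tokens` tokens."""
    return 2 * num_kv_heads * head_dim * cached_tokens * bytes_per_elem


S, w, d_h = 32_000, 4096, 128  # assumed: "32K" read as 32,000 tokens

mha_global = kv_cache_per_layer(32, d_h, S)
gqa_global = kv_cache_per_layer(8, d_h, S)
# With a sliding window, at most w tokens ever need to be cached.
gqa_swa = kv_cache_per_layer(8, d_h, min(S, w))

for name, size in [("MHA, global", mha_global),
                   ("GQA-8, global", gqa_global),
                   ("GQA-8, SWA", gqa_swa)]:
    print(f"{name:14s} {size / 1e6:7.1f} MB")

print(f"combined reduction: {mha_global / gqa_swa:.2f}x")
```

The `min(S, w)` term makes the bound explicit: once the sequence is longer than the window, the cache stops growing.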
Distributed Training Implications
- Sequence parallelism: SWA makes Ring Attention cheaper — each chunk only needs \(w\) tokens of context from its neighbor, not the full preceding sequence
- Pipeline parallelism: Smaller activation tensors at stage boundaries
- Memory: Can increase batch size with the freed memory, improving MFU
Mistral 7B demonstrated that a model using SWA + GQA can match models roughly 2× its size; see Chapter 36 for the full analysis.
Multi-Head Latent Attention (MLA)
Beyond GQA: Compressing the Latent Space
GQA reduces KV heads. MLA (DeepSeek-V2/V3) goes further by compressing the KV representation into a low-dimensional latent space:
\[c_t = W_{\text{DKV}}\, h_t, \qquad K_t = W_{\text{UK}}\, c_t, \qquad V_t = W_{\text{UV}}\, c_t\]
where:
- \(h_t \in \mathbb{R}^H\) is the hidden state
- \(c_t \in \mathbb{R}^{d_c}\) is the compressed latent (\(d_c \ll H\))
- \(W_{\text{DKV}} \in \mathbb{R}^{d_c \times H}\) compresses, \(W_{\text{UK}}, W_{\text{UV}}\) decompress
The KV cache stores only \(c_t\), not the full \(K\), \(V\):
\[\text{KV cache per token per layer} = d_c \times s\]
With \(d_c = 512\) and \(H = 7168\) (DeepSeek-V3): the latent dimension is \(14\times\) smaller than the hidden dimension. The actual KV cache compression vs MHA is even larger, \(64\times\) for \(A = 128\), \(d_h = 128\), because MHA stores \(2 \times A \times d_h = 32{,}768\) values per token while MLA stores only \(d_c = 512\) (see comparison table below).
Comparison
| Method | KV Cache per Token per Layer (BF16) | Relative to MHA |
|---|---|---|
| MHA (\(A=128\), \(d_h=128\)) | \(2 \times 128 \times 128 \times 2 = 65.5\) KB | 1× |
| GQA-8 | \(2 \times 8 \times 128 \times 2 = 4.1\) KB | 0.063× |
| MLA (\(d_c = 512\)) | \(512 \times 2 = 1.0\) KB | 0.016× |
MLA achieves 64× compression vs MHA in this comparison. This lets DeepSeek-V3, a 671B-parameter model, be served with practical KV cache sizes.
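The comparison can be checked by counting cached values per token. This sketch follows the chapter's simplification that MLA caches only the latent \(c_t\); any additional per-token state a real implementation might cache is ignored here, and `per_token_cache_bytes` is a hypothetical helper name.

```python
def per_token_cache_bytes(values_per_token: int, bytes_per_elem: int = 2) -> int:
    """Bytes cached per token per layer, given the number of stored values."""
    return values_per_token * bytes_per_elem


A, d_h, g, d_c = 128, 128, 8, 512  # values from the comparison table

mha = per_token_cache_bytes(2 * A * d_h)  # full K and V for every head
gqa = per_token_cache_bytes(2 * g * d_h)  # K and V for g shared heads
mla = per_token_cache_bytes(d_c)          # only the compressed latent

for name, size in [("MHA", mha), ("GQA-8", gqa), ("MLA", mla)]:
    print(f"{name:6s} {size:6d} bytes  ({size / mha:.4f}x of MHA)")

print(f"MLA vs MHA: {mha // mla}x smaller, MLA vs GQA-8: {gqa // mla}x smaller")
```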
Training Implications
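The decompression \(K_t = W_{\text{UK}} c_t\) can be absorbed into the query projection: since \(q_t^\top (W_{\text{UK}} c_j) = (W_{\text{UK}}^\top q_t)^\top c_j\), attention scores can be computed directly against the latents without materializing \(K\). (DeepSeek-V2 presents this absorption primarily as an inference optimization; how much of it carries over to training depends on the implementation.) This means: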
- Forward pass: Same FLOPs as standard attention (decompression is fused)
- Memory: Only \(c_t\) stored for backward pass → dramatic activation memory savings
- Communication: Smaller activation tensors for TP and SP
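The absorption trick rests on a simple associativity identity, which can be verified numerically. The sketch below is a single-head toy with random weights and tiny, made-up dimensions; it ignores positional encodings and softmax, and only checks that the attention score is the same whether the key is decompressed first or the up-projection is folded into the query.

```python
import random


def matvec(M: list[list[float]], v: list[float]) -> list[float]:
    """Matrix-vector product M @ v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def transpose(M: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*M)]


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


random.seed(0)
d_c, d_h = 4, 8  # toy sizes: latent dimension and per-head dimension

# W_UK maps a latent (d_c) up to a key (d_h): shape d_h x d_c.
W_UK = [[random.gauss(0, 1) for _ in range(d_c)] for _ in range(d_h)]
c = [random.gauss(0, 1) for _ in range(d_c)]  # cached latent for one token
q = [random.gauss(0, 1) for _ in range(d_h)]  # query for one head

# Path 1: decompress the key explicitly, then take the dot product.
score_explicit = dot(q, matvec(W_UK, c))

# Path 2: fold W_UK into the query and score directly against the latent.
q_absorbed = matvec(transpose(W_UK), q)  # lives in the d_c-dimensional latent space
score_absorbed = dot(q_absorbed, c)

print(f"explicit: {score_explicit:.6f}  absorbed: {score_absorbed:.6f}")
```

In path 2 the full key is never formed, which is why only \(c_t\) needs to be kept around.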
See Chapter 35 for DeepSeek-V3's full use of MLA.
DualPipe: Bidirectional Pipeline Scheduling
The Bubble Problem Revisited
Standard 1F1B pipeline parallelism has bubble fraction \((p-1)/(m+p-1)\). For \(p=16\), \(m=32\): bubble = 32%. Zero-Bubble schedules address this: ZB-H1 shrinks the bubble within the 1F1B memory budget, while ZB-H2 approaches zero bubble but adds memory pressure.
DualPipe: Two Pipelines, One Pass
DualPipe (DeepSeek-V3) splits micro-batches into two streams flowing in opposite directions through the pipeline:
Stream A: Stage 0 → Stage 1 → ... → Stage P-1 (forward direction)
Stream B: Stage P-1 → Stage P-2 → ... → Stage 0 (reverse direction)
Each stage alternates between processing micro-batches from both streams. While one stream is in the communication phase (sending activations), the other is in the compute phase.
Bubble Reduction
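With the two streams interleaved, the bubble fraction is approximately:
\[\text{Bubble}_{\text{DualPipe}} \approx \frac{p-1}{2m}\]
Compare to 1F1B: \((p-1)/(m+p-1)\). For \(p=16\), \(m=32\):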
- 1F1B: \(15/47 = 31.9\%\)
- DualPipe: \(15/64 = 23.4\%\)
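The relative advantage grows with \(m\): the ratio of the two bubble fractions is \(2m/(m+p-1)\), which is about 1.36 at \(p=16\), \(m=32\) and approaches 2 as \(m \to \infty\). Both bubbles vanish in that limit, but DualPipe's is asymptotically half as large.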
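The two bubble fractions are easy to compare numerically. The sketch below uses the 1F1B formula \((p-1)/(m+p-1)\) and this chapter's approximation \((p-1)/(2m)\) for DualPipe; the latter is a simplified model rather than an exact expression for any particular implementation, and the function names are illustrative.

```python
from fractions import Fraction


def bubble_1f1b(p: int, m: int) -> Fraction:
    """1F1B bubble fraction for p stages and m micro-batches."""
    return Fraction(p - 1, m + p - 1)


def bubble_dualpipe_approx(p: int, m: int) -> Fraction:
    """Approximate DualPipe bubble fraction used in this chapter."""
    return Fraction(p - 1, 2 * m)


def microbatches_for_1f1b_to_match(p: int, m: int) -> int:
    """Micro-batches 1F1B needs to match DualPipe's bubble at m micro-batches.

    Solves (p-1)/(m'+p-1) = (p-1)/(2m), giving m' = 2m - (p-1).
    """
    return 2 * m - (p - 1)


p, m = 16, 32
b_1f1b = bubble_1f1b(p, m)
b_dual = bubble_dualpipe_approx(p, m)

print(f"1F1B:     {b_1f1b} = {float(b_1f1b):.1%}")
print(f"DualPipe: {b_dual} = {float(b_dual):.1%}")
print(f"ratio:    {float(b_1f1b / b_dual):.2f}x")
print(f"1F1B needs m' = {microbatches_for_1f1b_to_match(p, m)} to match")
```

Using `Fraction` keeps the results exact, so the printed ratios can be compared directly with the hand-derived values.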
Key Insight: Overlapping Communication and Compute Across Streams
The real power of DualPipe is that while stream A waits for an activation transfer, stream B can compute — and vice versa. This converts pipeline communication time into productive computation, effectively hiding the inter-stage latency.
Constraint: Requires sufficient memory to hold activations for both streams simultaneously. Memory overhead is approximately \(2\times\) compared to standard 1F1B.
See Chapter 35 for DualPipe's role in DeepSeek-V3's training.
Summary: Architecture as Efficiency
| Innovation | Primary Saving | Reduction Factor | Adopted By |
|---|---|---|---|
| GQA | KV cache memory | 4–64× | LLaMA 2/3, Mistral, Gemma |
| SWA | Attention memory/compute | \(S/w\) × | Mistral, Mixtral |
| MLA | KV cache + activation memory | ~64× vs MHA (\(H/d_c\) = 14× latent compression) | DeepSeek-V2/V3 |
| DualPipe | Pipeline bubble time | up to ~2× | DeepSeek-V3 |
These are not systems optimizations — they are architectural choices that change the fundamental resource requirements. Understanding them is essential for the case studies that follow.
Key Takeaways
- GQA is the new default: Sharing KV heads across query groups reduces the cache by \(A/g\) (8× for GQA-8 with 64 query heads) with negligible quality loss.
- SWA bounds the KV cache independent of sequence length: Each layer caches at most \(w\) tokens, and attention cost becomes \(O(S \times w)\) instead of \(O(S^2)\).
- MLA compresses the KV bottleneck further: Low-rank latent representations achieve about 64× KV cache reduction vs MHA.
- DualPipe shrinks the bubble fraction by up to 2×: Bidirectional micro-batch streams hide inter-stage communication.
- Architecture co-designs with distribution: The best efficiency gains come from reducing the work at its source, not just distributing it better.
Exercises
- GQA memory savings: A model has \(A=64\) attention heads with \(d_h=128\). Compare the KV cache size per token per layer (in BF16) for: (a) standard MHA, (b) GQA with 8 KV groups, (c) GQA with 1 KV group (MQA). What are the reduction factors?
Solution
- MHA: \(2 \times 64 \times 128 \times 2 = 32{,}768\) bytes = 32 KB
- GQA-8: \(2 \times 8 \times 128 \times 2 = 4{,}096\) bytes = 4 KB → 8× reduction
- MQA (GQA-1): \(2 \times 1 \times 128 \times 2 = 512\) bytes → 64× reduction
The reduction factor equals \(A/g\) where \(g\) is the number of KV groups.
- SWA effective context: A model has \(L=32\) layers with sliding window size \(w=4{,}096\). (a) What is the maximum effective context? (b) If input length is \(S=128{,}000\), how many layers can "see" the first token from position \(S\)?
Solution
(a) Maximum effective context \(= L \times w = 32 \times 4{,}096 = 131{,}072\) tokens.
(b) The first token is reachable from position \(S\) only if \(S \leq L \times w\). Since \(128{,}000 < 131{,}072\), the first token is reachable. The number of layers needed to propagate information from position 1 to position \(S\) is \(\lceil S/w \rceil = \lceil 128{,}000/4{,}096 \rceil = \lceil 31.25 \rceil = 32\) layers. So only the final layer can "see" the first token, and only barely.
- DualPipe vs 1F1B: Compare the bubble fraction for a pipeline with \(P=16\) stages and \(m=32\) micro-batches under: (a) standard 1F1B, (b) DualPipe. How many additional micro-batches would 1F1B need to match DualPipe's bubble fraction?
Solution
(a) 1F1B bubble: \(\frac{P-1}{m+P-1} = \frac{15}{32+15} = \frac{15}{47} \approx 31.9\%\)
(b) DualPipe bubble: \(\approx \frac{P-1}{2m} = \frac{15}{64} \approx 23.4\%\)
To match DualPipe with 1F1B: \(\frac{15}{m'+15} = \frac{15}{64}\), so \(m'+15 = 64\), giving \(m' = 49\). 1F1B would need 49 micro-batches (vs 32 with DualPipe), that is, 17 additional micro-batches or a 53% increase, to achieve the same bubble fraction.