16 Investigation: FlashAttention
Deriving the Algorithm That Changed Transformers
Don’t explain FlashAttention. Derive it.
The insight isn’t the algorithm—it’s the question that leads to it.
Note (Property Spotlight): Associativity
This chapter is a case study in associativity—the first property from our Algebraic Framework.
Softmax-weighted sums can be computed incrementally: \((a \oplus b) \oplus c = a \oplus (b \oplus c)\). This associative structure is what makes chunking possible. Without it, we’d be forced to materialize the full n×n attention matrix.
The derivation that follows shows how recognizing associativity leads directly to the FlashAttention algorithm.
16.1 The Problem
Attention is the heart of the transformer. But it has a memory problem.
Given queries Q, keys K, and values V (each n × d):
def attention(Q, K, V):
    S = Q @ K.T     # n×n attention scores
    P = softmax(S)  # n×n attention weights
    O = P @ V       # n×d output
    return O

The matrices S and P are each n × n. For sequence length 32K and batch size 1:
- S: 32,768² × 4 bytes = 4 GB
- P: 32,768² × 4 bytes = 4 GB
- Total intermediate: 8 GB
Per layer. Per head. Per batch element.
An 80GB A100 runs out of memory quickly.
Tip (Interactive): FlashAttention Memory Explorer
Explore how FlashAttention reduces memory usage compared to standard attention. Adjust the sequence length and tile size to see how memory requirements scale.
Key insight: FlashAttention processes small tiles that fit in SRAM (shared memory), avoiding the O(n²) HBM memory footprint. The total memory is dominated by the O(n·d) output and statistics, not the attention matrix.
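For reference, a minimal sketch of the memory model behind the explorer (the function name and the FP16 default are mine; per-layer, per-head, and per-batch factors are ignored):

def attention_memory_bytes(n, d, tile, bytes_per_element=2):
    # Standard attention: the n×n scores and weights dominate
    standard = 2 * n * n * bytes_per_element
    # FlashAttention: one tile of scores plus the O(n·d) output and O(n) statistics
    flash = tile * tile * bytes_per_element + n * d * bytes_per_element + 2 * n * 4
    return standard, flash

std, fl = attention_memory_bytes(n=32768, d=64, tile=128)
print(f"standard ≈ {std / 2**30:.1f} GiB, flash ≈ {fl / 2**20:.1f} MiB")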
16.2 The Question
Here’s the question that leads to FlashAttention:
Must we materialize the full n×n matrix?
The output O is only n × d. We produce n² intermediate values to compute n × d outputs. That seems wasteful.
Let’s trace what we actually need.
16.3 Understanding the Computation
For a single output row \(O_i\) (the output for query \(i\)):
\[O_i = \sum_j P_{ij} V_j\]
where:
\[P_{ij} = \frac{e^{S_{ij}}}{\sum_k e^{S_{ik}}}\]
and:
\[S_{ij} = Q_i \cdot K_j\]
Expanded:
\[O_i = \frac{\sum_j e^{Q_i \cdot K_j} V_j}{\sum_j e^{Q_i \cdot K_j}}\]
This is a weighted sum of V vectors, weighted by softmax of attention scores.
Observation: For row \(i\), we need all K and all V, but we don’t need the other Q rows.
This suggests: process one query row at a time.
16.4 Attempt 1: Row-by-Row Computation
import torch

def attention_row_by_row(Q, K, V):
    n, d = Q.shape
    O = torch.zeros(n, d)
    for i in range(n):
        # Compute scores for row i
        scores = Q[i] @ K.T  # Shape: (n,)
        # Softmax (with stability shift)
        scores_max = scores.max()
        exp_scores = torch.exp(scores - scores_max)
        softmax_denom = exp_scores.sum()
        weights = exp_scores / softmax_denom
        # Output for row i
        O[i] = weights @ V
    return O

Memory: O(n) working space per row, and no n×n intermediate at all. We’ve eliminated the n² memory.
But wait—this is slow. We’re doing n sequential iterations, each reading all of K and V from memory.
Memory traffic: O(n²d) reads, because all of K and V are streamed from HBM for every one of the n queries.
Standard attention reads K and V only once: its large matrix multiplies reuse each loaded tile many times, which is exactly what makes them efficient.
Problem: We’ve traded memory for memory bandwidth. That’s not a good trade.
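A rough back-of-envelope makes the cost concrete (assuming FP16 keys/values and roughly 2 TB/s of HBM bandwidth, in the ballpark of an A100):

n, d = 32768, 64
kv_bytes = 2 * n * d * 2              # all of K and V in FP16
total_reads = n * kv_bytes            # K and V re-streamed for each of the n queries
print(f"{total_reads / 1e9:.0f} GB of HBM reads")       # ≈ 275 GB
print(f"{total_reads / 2e12 * 1e3:.0f} ms at ~2 TB/s")  # ≈ 137 ms, per layer, per head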
16.5 The Roads Not Taken
Before finding the right path, let’s examine approaches that seem promising but fall short. Understanding why they fail clarifies what we actually need.
16.5.1 Attempt 2: Gradient Checkpointing
A standard trick for memory-constrained training: don’t store intermediate activations; recompute them during backprop.
def attention_checkpointed(Q, K, V):
    # Forward: compute and store only O
    S = Q @ K.T
    P = softmax(S)
    O = P @ V
    # Don't save S or P; recompute them during backward
    return O

Why it fails: Checkpointing reduces activation memory during training, but not working memory during the forward pass. We still materialize the full n² matrix at some point. The memory spike remains.
Gradient checkpointing is orthogonal to FlashAttention—you can use both together. But it doesn’t solve the forward pass memory problem.
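For concreteness, this is roughly what the checkpointed version looks like with torch.utils.checkpoint (a sketch; the shapes are arbitrary):

import torch
from torch.utils.checkpoint import checkpoint

def attention_fwd(Q, K, V):
    S = Q @ K.transpose(-2, -1)
    P = torch.softmax(S, dim=-1)   # the full n×n matrix still exists here
    return P @ V

Q = torch.randn(1024, 64, requires_grad=True)
K = torch.randn(1024, 64, requires_grad=True)
V = torch.randn(1024, 64, requires_grad=True)

# S and P are not saved for backward, but they are still materialized in forward
O = checkpoint(attention_fwd, Q, K, V, use_reentrant=False)
O.sum().backward()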
16.5.2 Attempt 3: Sparse Attention
If n² is too expensive, why not make attention sparse?
def sparse_attention(Q, K, V, pattern='local'):
    if pattern == 'local':
        # Only attend to nearby positions
        window = 256
        # ... compute attention only within the window
    elif pattern == 'strided':
        # Attend to every k-th position
        stride = 64
        # ...

Approaches like Longformer, BigBird, and Sparse Transformers use fixed sparsity patterns to achieve O(n) or O(n√n) complexity.
Why it fails for us: Sparse attention is an approximation. It changes what the model computes, not just how it computes. For many tasks, full attention significantly outperforms sparse variants.
We want the exact same output as standard attention—just computed more efficiently.
16.5.3 Attempt 4: Low-Rank Approximation
Maybe we can approximate the attention matrix with something lower-rank?
def linear_attention(Q, K, V):
    # Approximate: instead of softmax(Q K^T) V,
    # use phi(Q) @ (phi(K)^T @ V), where phi is a feature map
    Q_feat = feature_map(Q)         # n×r
    K_feat = feature_map(K)         # n×r
    return Q_feat @ (K_feat.T @ V)  # (n×r) @ ((r×n) @ (n×d)) → O(nrd)

Performers, Linear Transformers, and similar approaches use this trick.
Why it fails for us: Again, this is an approximation. The feature map \(\phi\) doesn’t perfectly reproduce softmax behavior. For some tasks it works well; for others, quality degrades significantly.
We want exact attention.
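To see the gap concretely, here is a toy comparison of exact softmax attention against linear attention with the elu(x) + 1 feature map (a sketch with arbitrary small shapes; feature_map here is one stand-in for the undefined helper above):

import numpy as np

def feature_map(x):
    # One common choice: elu(x) + 1, as in Linear Transformers
    return np.where(x > 0, x + 1.0, np.exp(x))

np.random.seed(0)
Q, K, V = np.random.randn(3, 16, 8)

# Exact softmax attention
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O_exact = P @ V

# Linear attention with the feature map
Qf, Kf = feature_map(Q), feature_map(K)
O_linear = (Qf @ (Kf.T @ V)) / (Qf @ Kf.sum(axis=0))[:, None]

print("max deviation from exact:", np.abs(O_exact - O_linear).max())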
16.5.4 Attempt 5: Sampling-Based Approximation
What if we sample a subset of key-value pairs?
def sampled_attention(Q, K, V, sample_size=256):
    # Randomly sample key/value positions
    n = K.shape[0]
    idx = np.random.choice(n, sample_size, replace=False)
    K_sampled = K[idx]
    V_sampled = V[idx]
    return softmax(Q @ K_sampled.T) @ V_sampled

Why it fails: Sampling introduces variance. Critical tokens might be missed. Quality is unpredictable.
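The variance is easy to demonstrate (a quick sketch; the softmax helper is defined inline since the snippet above assumes one):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d, sample_size = 1024, 64, 64
Q, K, V = rng.standard_normal((3, n, d))

outputs = []
for _ in range(10):
    idx = rng.choice(n, sample_size, replace=False)
    outputs.append(softmax(Q @ K[idx].T) @ V[idx])

# Different random samples give noticeably different outputs
print("std across samples:", np.stack(outputs).std(axis=0).mean())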
16.5.5 What We Actually Need
Looking at these failures, the requirements become clear:
- Exact computation: Same output as standard attention (bit-for-bit, modulo floating point)
- Bounded memory: Don’t materialize the full n² matrix
- Efficient execution: Can’t just be correct; must be fast
The first attempt (row-by-row) achieved #1 and #2 but failed #3. The approximation methods achieved #2 and #3 but failed #1.
Is there a solution that achieves all three?
16.6 The Key Insight: Tiles, Not Rows
The row-by-row approach processes one query at a time. But GPU operations are efficient with larger tiles.
New question: Can we process a tile of queries against a tile of keys/values, and combine tiles without recomputation?
This is where Chapter 4’s associativity insight becomes crucial.
16.7 The Softmax Challenge
Standard softmax needs the global maximum for numerical stability:
def softmax(x):
    x_max = x.max()
    exp_x = torch.exp(x - x_max)
    return exp_x / exp_x.sum()

If we process K/V in tiles, we don’t know the global max when processing the first tile.
The blocking issue: Softmax seems to require the full row before producing output.
Let’s solve this.
16.8 Investigation: Online Softmax
What if the maximum changes as we see more data?
Consider two chunks of scores: \([s_1, s_2, s_3]\) and \([s_4, s_5, s_6]\).
Chunk 1: \(m_1 = \max(s_1, s_2, s_3)\), \(d_1 = \sum_{i=1}^{3} e^{s_i - m_1}\)
Chunk 2: \(m_2 = \max(s_4, s_5, s_6)\), \(d_2 = \sum_{i=4}^{6} e^{s_i - m_2}\)
Global: \(m = \max(m_1, m_2)\), \(d = ?\)
The denominator in chunk 1 was computed relative to \(m_1\). To combine with chunk 2, we need to rescale:
\[d_1' = d_1 \cdot e^{m_1 - m}\] \[d_2' = d_2 \cdot e^{m_2 - m}\] \[d = d_1' + d_2'\]
The correction factor: When the max changes, multiply old sums by \(e^{\text{old\_max} - \text{new\_max}}\).
Let’s verify:
import numpy as np

# Full computation
scores = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 3.5])
m_full = scores.max()
d_full = np.exp(scores - m_full).sum()
print(f"Full: max={m_full}, denom={d_full:.4f}")

# Chunked computation
chunk1 = scores[:3]
chunk2 = scores[3:]

m1 = chunk1.max()
d1 = np.exp(chunk1 - m1).sum()
print(f"Chunk 1: max={m1}, denom={d1:.4f}")

m2 = chunk2.max()
d2 = np.exp(chunk2 - m2).sum()
print(f"Chunk 2: max={m2}, denom={d2:.4f}")

# Combine
m = max(m1, m2)
d = d1 * np.exp(m1 - m) + d2 * np.exp(m2 - m)
print(f"Combined: max={m}, denom={d:.4f}")

# Output:
# Full: max=5.0, denom=1.7944
# Chunk 1: max=3.0, denom=1.5032
# Chunk 2: max=5.0, denom=1.5910
# Combined: max=5.0, denom=1.7944 ✓ Matches!

The combination works. The softmax denominator has associative structure with state (max, scaled_sum).
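This is exactly the associativity from the Property Spotlight, made explicit as a binary operator on (max, denominator) pairs (a sketch; the summarize/combine names are mine):

import numpy as np

def summarize(chunk):
    # Map a chunk of scores to its (max, scaled denominator) state
    m = chunk.max()
    return (m, np.exp(chunk - m).sum())

def combine(a, b):
    # The ⊕ operator: rescale both denominators to the shared max, then add
    (ma, da), (mb, db) = a, b
    m = max(ma, mb)
    return (m, da * np.exp(ma - m) + db * np.exp(mb - m))

x, y, z = np.array([1.0, 3.0]), np.array([2.0, 5.0]), np.array([4.0, 3.5])
a, b, c = summarize(x), summarize(y), summarize(z)

print(combine(combine(a, b), c))  # (a ⊕ b) ⊕ c
print(combine(a, combine(b, c)))  # a ⊕ (b ⊕ c): identical up to floating point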
16.9 Extending to Output
The denominator is only half the story. We also need the output:
\[O_i = \frac{\sum_j e^{S_{ij}} V_j}{\sum_j e^{S_{ij}}}\]
The numerator is also a sum of exponentials, but weighted by \(V_j\).
State: (max, sum, numerator) = (m, d, o)
When max changes:
def update_state(old_state, new_scores, new_V):
    m_old, d_old, o_old = old_state

    m_new = new_scores.max()
    exp_new = np.exp(new_scores - m_new)
    d_new = exp_new.sum()
    o_new = exp_new @ new_V

    m = max(m_old, m_new)
    # Rescale both contributions to the shared max
    scale_old = np.exp(m_old - m)
    scale_new = np.exp(m_new - m)

    d = d_old * scale_old + d_new * scale_new
    o = o_old * scale_old + o_new * scale_new
    return (m, d, o)

# Final output: o / d

Let’s verify this works:
# Test data
np.random.seed(42)
Q = np.random.randn(4, 8)  # 4 queries, dim 8
K = np.random.randn(6, 8)  # 6 keys
V = np.random.randn(6, 8)  # 6 values

# Standard attention (for reference)
def standard_attention(Q, K, V):
    S = Q @ K.T
    S_max = S.max(axis=1, keepdims=True)
    exp_S = np.exp(S - S_max)
    P = exp_S / exp_S.sum(axis=1, keepdims=True)
    return P @ V

O_standard = standard_attention(Q, K, V)

# Chunked attention
def chunked_attention(Q, K, V, chunk_size=2):
    n = Q.shape[0]
    n_kv = K.shape[0]
    d = V.shape[1]

    # Initialize state for each query
    m = np.full(n, -np.inf)
    s = np.zeros(n)
    o = np.zeros((n, d))

    for j in range(0, n_kv, chunk_size):
        K_chunk = K[j:j+chunk_size]
        V_chunk = V[j:j+chunk_size]

        # Scores for this chunk
        scores = Q @ K_chunk.T  # (n, chunk_size)

        for i in range(n):
            row_scores = scores[i]
            row_max = row_scores.max()

            if row_max > m[i]:
                # Rescale old state to the new max
                scale = np.exp(m[i] - row_max)
                s[i] = s[i] * scale
                o[i] = o[i] * scale
                m[i] = row_max

            # Add new contribution
            exp_scores = np.exp(row_scores - m[i])
            s[i] += exp_scores.sum()
            o[i] += exp_scores @ V_chunk

    return o / s[:, None]

O_chunked = chunked_attention(Q, K, V, chunk_size=2)

# Compare
print("Max difference:", np.abs(O_standard - O_chunked).max())
# Output: Max difference: 1.11e-15 (floating point precision)

It works. We can compute attention in chunks without materializing the full n×n matrix.
16.10 The Full Algorithm
Now we tile both Q and K/V:
def flash_attention(Q, K, V, tile_size=64):
    """
    FlashAttention: tiled attention with O(n) memory.
    Tiles Q (rows) and K/V (columns) to fit in SRAM.
    """
    n, d = Q.shape

    # Output and running statistics
    O = np.zeros((n, d))
    M = np.full(n, -np.inf)  # Running max per row
    L = np.zeros(n)          # Running sum of exponentials per row (softmax denominator)

    # Process K/V in tiles
    for j in range(0, n, tile_size):
        j_end = min(j + tile_size, n)
        K_tile = K[j:j_end]
        V_tile = V[j:j_end]

        # Process Q in tiles (for parallelism)
        for i in range(0, n, tile_size):
            i_end = min(i + tile_size, n)
            Q_tile = Q[i:i_end]

            # Compute attention scores for this tile pair
            S_tile = Q_tile @ K_tile.T  # (tile, tile) - fits in SRAM!

            # Row-wise online-softmax update
            for row in range(i_end - i):
                global_row = i + row
                row_scores = S_tile[row]
                row_max = row_scores.max()

                # New max for this row
                new_max = max(M[global_row], row_max)

                # Rescale old and new contributions to the shared max
                scale_old = np.exp(M[global_row] - new_max)
                scale_new = np.exp(row_max - new_max)
                exp_scores = np.exp(row_scores - row_max)

                # Update running state
                L[global_row] = L[global_row] * scale_old + exp_scores.sum() * scale_new
                O[global_row] = O[global_row] * scale_old + (exp_scores @ V_tile) * scale_new
                M[global_row] = new_max

    # Normalize
    return O / L[:, None]

16.10.1 Memory Analysis
Standard attention:
- S = Q @ K.T: O(n²) memory
- P = softmax(S): O(n²) memory
- O = P @ V: O(nd) memory

FlashAttention:
- Q_tile: O(tile × d) memory
- K_tile: O(tile × d) memory
- V_tile: O(tile × d) memory
- S_tile: O(tile²) memory
- O, M, L: O(nd) memory
- Total: O(tile² + nd)

For tile = 128, d = 64, n = 32,768:
- Standard: 32,768² × 4 bytes = 4 GB
- Flash: 128² × 4 + 32,768 × 64 × 4 bytes = 65 KB + 8 MB ≈ 8 MB
- Reduction: ~500×
16.10.2 Speed Analysis
Surprisingly, FlashAttention is often faster despite doing more arithmetic.
Why? IO is the bottleneck.
Standard attention:

1. Read Q, K from HBM → compute S → write S to HBM
2. Read S from HBM → softmax → write P to HBM
3. Read P, V from HBM → compute O → write O to HBM

Total HBM accesses: O(n² + n² + n²) = O(3n²)

FlashAttention:

1. Load Q_tile, K_tile, V_tile from HBM into SRAM
2. Compute everything in SRAM
3. Write O_tile to HBM

Total HBM accesses: O(nd), for loading Q, K, V and storing O roughly once each

The reduction in memory traffic often exceeds the extra compute cost.
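Plugging in the running example (n = 32,768, d = 64, FP16) gives a rough sense of scale; constant factors from tiling are ignored:

n, d = 32768, 64
standard_bytes = 3 * n * n * 2   # S written and read, P written and read: ≈ 3n² values
flash_bytes = 4 * n * d * 2      # Q, K, V read once, O written once
print(f"standard ≈ {standard_bytes / 1e9:.1f} GB of traffic")
print(f"flash    ≈ {flash_bytes / 1e6:.1f} MB of traffic")
print(f"ratio    ≈ {standard_bytes / flash_bytes:.0f}×")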
16.11 The Tiling Strategy
How do we choose tile sizes?
Constraint: Tiles must fit in SRAM.
GPU SRAM (shared memory): ~100-200 KB per SM
A100 shared memory: 164 KB
For FP16:
- Q_tile: tile × d × 2 bytes
- K_tile: tile × d × 2 bytes
- V_tile: tile × d × 2 bytes
- S_tile: tile × tile × 2 bytes
- Accumulators: tile × d × 4 bytes (FP32 for precision)

For tile = 128, d = 64:
- Q_tile: 128 × 64 × 2 = 16 KB
- K_tile: 128 × 64 × 2 = 16 KB
- V_tile: 128 × 64 × 2 = 16 KB
- S_tile: 128 × 128 × 2 = 32 KB
- Accumulators: 128 × 64 × 4 = 32 KB
- Total: ~112 KB ✓ Fits!
The actual FlashAttention implementation tunes tile sizes per GPU architecture.
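The budget check is easy to express as a helper (a sketch using the FP16 sizes above and the 164 KB shared-memory figure; real kernels also reserve space for double buffering):

def sram_bytes(tile, d, elem_bytes=2, acc_bytes=4):
    qkv = 3 * tile * d * elem_bytes    # Q, K, V tiles
    scores = tile * tile * elem_bytes  # S tile
    acc = tile * d * acc_bytes         # FP32 output accumulator
    return qkv + scores + acc

for tile in (64, 128, 256):
    kb = sram_bytes(tile, d=64) / 1024
    print(f"tile={tile:3d}: {kb:6.1f} KB {'fits' if kb <= 164 else 'too big'}")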
16.12 The Backward Pass
Training requires gradients. Can we backpropagate through FlashAttention?
Standard backprop would require storing all intermediate activations—defeating the memory savings.
FlashAttention’s solution: recompute instead of store.
During backward:

1. Reload Q, K, V tiles from HBM
2. Recompute S_tile and P_tile in SRAM
3. Compute gradients using the recomputed values
4. Write gradients to HBM
This trades compute for memory. For large n, the trade is favorable—memory is the bottleneck.
def flash_attention_backward(dO, Q, K, V, O, M, L, tile_size=64):
    """
    Backward pass, recomputing forward values.
    dO: gradient of loss w.r.t. output
    O, M, L: saved from forward pass (no n×n storage)
    """
    n, d = Q.shape
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)

    for j in range(0, n, tile_size):
        K_tile = K[j:j+tile_size]
        V_tile = V[j:j+tile_size]

        for i in range(0, n, tile_size):
            Q_tile = Q[i:i+tile_size]
            dO_tile = dO[i:i+tile_size]

            # Recompute forward: rebuild P from the saved max/denominator statistics
            S_tile = Q_tile @ K_tile.T
            P_tile = softmax_with_stats(S_tile, M[i:i+tile_size], L[i:i+tile_size])

            # Backward through O = P @ V
            dV[j:j+tile_size] += P_tile.T @ dO_tile
            dP_tile = dO_tile @ V_tile.T

            # Backward through softmax (apply the softmax Jacobian row-wise)
            dS_tile = softmax_backward(dP_tile, P_tile)

            # Backward through S = Q @ K.T
            dQ[i:i+tile_size] += dS_tile @ K_tile
            dK[j:j+tile_size] += dS_tile.T @ Q_tile

    return dQ, dK, dV

16.13 Benchmarks
Real-world performance on an A100:
Sequence length 2048, head dim 64:

| Method | Memory (MB) | Time (ms) |
|---|---|---|
| Standard attention | 1,024 | 2.1 |
| FlashAttention | 8 | 0.8 |

Sequence length 16384, head dim 64:

| Method | Memory (MB) | Time (ms) |
|---|---|---|
| Standard attention | OOM | — |
| FlashAttention | 64 | 42 |

Sequence length 65536, head dim 64:

| Method | Memory (MB) | Time (ms) |
|---|---|---|
| Standard attention | OOM | — |
| FlashAttention | 256 | 680 |
FlashAttention enables sequences that were previously impossible, while being faster on feasible sequences.
16.14 The Derivation Pattern
Let’s trace how we derived FlashAttention:
1. Identify the problem: O(n²) memory from materializing the attention matrix
2. Ask the key question: Must we materialize it? What do we actually need?
3. Discover the structure: Softmax has associative structure via a (max, sum) state
4. Extend the insight: Output accumulation has the same structure with (max, sum, output_sum)
5. Apply hardware constraints: Tile to fit in SRAM, minimize HBM traffic
6. Handle the full system: The backward pass recomputes to preserve the memory savings
This is the investigation pattern. Not “here’s FlashAttention” but “how would you find FlashAttention if it didn’t exist?”
16.15 FlashAttention-3: Hopper-Specific Optimizations
FlashAttention-2 was architecture-agnostic. FlashAttention-3 exploits Hopper’s unique features for another 1.5-2× speedup.
16.15.1 The Three Innovations
FlashAttention-3 introduces three advances:
1. TMA (Tensor Memory Accelerator)
- Hardware-managed async memory transfers
- Replaces software-orchestrated loads
2. Warp Specialization
- Dedicated producer/consumer warps
- Overlaps memory and compute
3. FP8 Support
- 2× throughput vs FP16
- Block-level quantization for accuracy
16.15.2 TMA: Hardware Async Copies
On Ampere (A100), software orchestrates memory transfers:
# A100: Software-managed async copy
# Step 1: Issue load
cp.async.ca.shared.global [smem], [gmem]
# Step 2: Commit group
cp.async.commit_group
# Step 3: Wait
cp.async.wait_group 0
# Software tracks addresses, handles boundariesOn Hopper, TMA handles this in hardware:
# H100: TMA-managed copy
# Step 1: Create TMA descriptor (once)
tma_desc = create_tma_descriptor(
    tensor_ptr, shape, strides, tile_shape
)
# Step 2: Issue copy (hardware manages everything)
cp.async.bulk.tensor [smem], [tma_desc], [coords]
# Step 3: Wait on barrier
arrive.expect_tx barrier, bytes
wait barrier
# Hardware handles: addressing, boundaries, caching, coalescing

Benefits:
- Fewer instructions (freeing registers)
- Better memory access patterns
- Automatic handling of edge tiles
16.15.3 Warp Specialization
FlashAttention-2: All warps do the same work
Time ─→
Warp 0: [Load Q][Load K][Compute][Load V][Compute][Store]
Warp 1: [Load Q][Load K][Compute][Load V][Compute][Store]
│ │ │ │
└─────┴────────┴─ Synchronization barriers ─┘
FlashAttention-3: Warps specialize into roles
Time ─→
Producer warp: [Load K₀][Load V₀][Load K₁][Load V₁][Load K₂]...
│ │ │ │
▼ ▼ ▼ ▼
Consumer warps: [Compute₀][Compute₀][Compute₁][Compute₁]...
[Store₀───][Store₀───][Store₁───]...
Overlapped! Producer loads next tile while consumers compute current.
Implementation concept:
// Warp specialization in FA3 (conceptual sketch)
if (warp_id == PRODUCER_WARP) {
    // This warp only does memory operations
    for (int i = 0; i < num_tiles; i++) {
        // Load next K, V tiles via TMA
        tma_load_async(K_smem[i % 2], K_gmem + i * TILE_K);
        tma_load_async(V_smem[i % 2], V_gmem + i * TILE_K);
        // Signal consumers
        arrive(barrier[i % 2]);
    }
} else {
    // Consumer warps only compute
    for (int i = 0; i < num_tiles; i++) {
        // Wait for producer
        wait(barrier[i % 2]);
        // Compute attention for this tile (pseudocode)
        S = Q @ K_smem[i % 2].T
        O_acc = online_softmax_update(O_acc, S, V_smem[i % 2])
    }
}

16.15.4 Pingpong Scheduling
The pattern above uses pingpong buffers (double buffering) in shared memory:
Shared Memory Layout:
┌──────────────────────────────────────┐
│ K buffer 0 │ V buffer 0 │ │
├──────────────┼──────────────┤ │
│ K buffer 1 │ V buffer 1 │ Q │
└──────────────┴──────────────┴────────┘
Time step 0: Compute with buffer 0, load into buffer 1
Time step 1: Compute with buffer 1, load into buffer 0
Time step 2: Compute with buffer 0, load into buffer 1
...
Memory operations and compute never contend for the same buffer.
16.15.5 FP8 Attention
Hopper’s tensor cores support FP8 for 2× the FLOPS of FP16.
The challenge: FP8 has limited range (±448 for E4M3).
Solution: Block-wise quantization
import math
import torch

def fp8_attention(Q, K, V):
    d = Q.shape[-1]

    # Per-block scaling factors (E4M3 tops out at ±448)
    scale_Q = Q.abs().amax(dim=-1, keepdim=True) / 448.0
    scale_K = K.abs().amax(dim=-1, keepdim=True) / 448.0

    # Quantize
    Q_fp8 = (Q / scale_Q).to(torch.float8_e4m3fn)
    K_fp8 = (K / scale_K).to(torch.float8_e4m3fn)

    # Compute in FP8 (2× faster on Hopper tensor cores; shown conceptually)
    S_fp8 = Q_fp8 @ K_fp8.T  # Tensor core MMA

    # Dequantize for softmax (needs range)
    S = S_fp8.float() * scale_Q * scale_K.T

    # Softmax in FP32 for accuracy
    P = torch.softmax(S / math.sqrt(d), dim=-1)

    # V can stay in FP16/BF16
    return P @ V

Accuracy: Within 0.1% of FP16 attention with proper scaling.
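To get a feel for the quantization error on its own, a small sketch (assuming a PyTorch build with float8 dtypes, 2.1 or newer; the exact error depends on the data):

import torch

x = torch.randn(16, 128) * 3.0

# Per-row scaling into the E4M3 range, as in the block-wise scheme above
scale = x.abs().amax(dim=-1, keepdim=True) / 448.0
x_fp8 = (x / scale).to(torch.float8_e4m3fn)
x_back = x_fp8.to(torch.float32) * scale

rel_err = (x - x_back).abs().max() / x.abs().max()
print(f"max relative error: {rel_err:.4f}")  # on the order of a few percent (3 mantissa bits)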
16.15.6 Performance Comparison
FlashAttention variants on H100 (sequence length 8192, head dim 128):
| Version | Time (ms) | TFLOPS | % of Peak | Key Feature |
|---|---|---|---|---|
| FA1 | 12.4 | 180 | 18% | Online softmax |
| FA2 | 5.8 | 385 | 39% | Better parallelism |
| FA3 FP16 | 3.2 | 700 | 71% | TMA + warp specialization |
| FA3 FP8 | 1.9 | 1,180 | 60% (of FP8 peak) | FP8 tensor cores |

Theoretical peak: 989 TFLOPS (FP16), 1,979 TFLOPS (FP8)
16.15.7 When to Use Each Version
flowchart TD
A{Hardware?} -->|Hopper H100/H200| B[FA3]
A -->|Ampere A100| C[FA2]
A -->|Older/Non-NVIDIA| D[FA1 or PyTorch SDPA]
B --> E{Precision requirement?}
E -->|Need FP16 accuracy| F[FA3 FP16]
E -->|Can tolerate FP8| G[FA3 FP8<br/>2× faster]
style B fill:#dcfce7,stroke:#16a34a
style C fill:#e0f2fe,stroke:#0284c7
style D fill:#fef3c7,stroke:#d97706
style F fill:#f3e8ff,stroke:#9333ea
style G fill:#dcfce7,stroke:#16a34a
Quick reference:
| Version | Hardware | Best For |
|---|---|---|
| FA1 | Legacy, non-NVIDIA | Compatibility |
| FA2 | A100 and earlier | Production stability |
| FA3 FP16 | H100/H200 | Maximum quality |
| FA3 FP8 | H100/H200 | Maximum speed |
16.15.8 Using FlashAttention-3
# Installation
# pip install flash-attn --no-build-isolation

import math
import torch
from flash_attn import flash_attn_func

# FP16 attention
output = flash_attn_func(
    q, k, v,
    causal=True,
    softmax_scale=1.0 / math.sqrt(d)
)

# FP8 attention (Hopper only)
from flash_attn import flash_attn_func_fp8

output = flash_attn_func_fp8(
    q.to(torch.float8_e4m3fn),
    k.to(torch.float8_e4m3fn),
    v,  # V can stay FP16
    descale_q=scale_q,
    descale_k=scale_k,
)

16.15.9 The Lesson: Hardware-Algorithm Co-Design
FlashAttention-3 demonstrates that maximum performance requires exploiting specific hardware features:
Generic algorithm (FA1/FA2):
+ Portable
+ Easier to maintain
- Leaves performance on the table
Hardware-specific (FA3):
+ Maximum performance
- Requires Hopper
- More complex implementation
The best systems offer both: portable default, hardware-specific fast paths.
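A hedged sketch of what that dispatch can look like (the function name and fallback choice are mine; it assumes the flash_attn package shown earlier and falls back to PyTorch's built-in scaled_dot_product_attention; the two APIs also expect different tensor layouts, which a real wrapper would normalize):

import torch
import torch.nn.functional as F

def attention_dispatch(q, k, v, causal=True):
    if q.is_cuda:
        major, _ = torch.cuda.get_device_capability(q.device)
        if major >= 8:  # Ampere or newer: prefer the FlashAttention kernels if installed
            try:
                from flash_attn import flash_attn_func  # expects (batch, seq, heads, dim)
                return flash_attn_func(q, k, v, causal=causal)
            except ImportError:
                pass
    # Portable default: PyTorch's fused SDPA (expects (batch, heads, seq, dim))
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)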
16.16 Connections
Chapter 1 (Memory Hierarchy): FlashAttention is fundamentally about keeping data in fast memory (SRAM) instead of slow memory (HBM).
Chapter 2 (Bandwidth): The algorithm is faster despite more FLOPs because memory bandwidth, not compute, is the bottleneck.
Chapter 4 (Associativity): The (max, sum, output) state forms a monoid—the mathematical foundation for chunking.
16.17 Key Takeaways
The question matters: “Must we materialize the n×n matrix?” led to the algorithm. Without the question, you’d never find the answer.
Look for hidden structure: Softmax doesn’t look associative. But it is, with the right state representation.
Hardware context is essential: FlashAttention is optimized for GPU memory hierarchy. On different hardware, different tradeoffs might apply.
Recomputation is a tool: Trading compute for memory (via recomputation) is sometimes the right trade.
The derivation teaches more than the result: Understanding why FlashAttention works lets you find the next FlashAttention.
16.18 Further Reading
- Dao et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”
- Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”
- Shah et al. (2024). “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision”
- Rabe & Staats (2021). “Self-attention Does Not Need O(n²) Memory” - Independent discovery of online softmax
- Milakov & Gimelshein (2018). “Online normalizer calculation for softmax” - The mathematical foundation