16  Investigation: FlashAttention

Deriving the Algorithm That Changed Transformers

Don’t explain FlashAttention. Derive it.

The insight isn’t the algorithm—it’s the question that leads to it.

Property Spotlight: Associativity

This chapter is a case study in associativity—the first property from our Algebraic Framework.

Softmax-weighted sums can be computed incrementally: \((a \oplus b) \oplus c = a \oplus (b \oplus c)\). This associative structure is what makes chunking possible. Without it, we’d be forced to materialize the full n×n attention matrix.

The derivation that follows shows how recognizing associativity leads directly to the FlashAttention algorithm.

16.1 The Problem

Attention is the heart of the transformer. But it has a memory problem.

Given queries Q, keys K, and values V (each n × d):

import torch

def attention(Q, K, V):
    S = Q @ K.T                    # n×n attention scores
    P = torch.softmax(S, dim=-1)   # n×n attention weights
    O = P @ V                      # n×d output
    return O

The matrices S and P are n × n. For sequence length 32K and batch size 1:

  • S: 32,768² × 4 bytes = 4 GB
  • P: 32,768² × 4 bytes = 4 GB
  • Total intermediate: 8 GB

Per layer. Per head. Per batch element.

An 80GB A100 runs out of memory quickly.
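To see how quickly this blows up, here is a back-of-the-envelope helper (a sketch; the function name is illustrative):

def attn_matrix_bytes(n, dtype_bytes=4):
    # One n×n score or weight matrix in FP32
    return n * n * dtype_bytes

for n in (2_048, 8_192, 32_768, 131_072):
    gib = attn_matrix_bytes(n) / 2**30
    print(f"n = {n:7d}: {gib:8.2f} GiB per n×n matrix, per head, per layer")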

(Interactive demo in the web version: adjust the sequence length and tile size to see how memory requirements scale for standard attention vs. FlashAttention.)

Key insight: FlashAttention processes small tiles that fit in SRAM (shared memory), avoiding the O(n²) HBM memory footprint. The total memory is dominated by the O(n·d) output and statistics, not the attention matrix.

16.2 The Question

Here’s the question that leads to FlashAttention:

Must we materialize the full n×n matrix?

The output O is only n × d. We produce n² intermediate values to compute n × d outputs. That seems wasteful.

Let’s trace what we actually need.

16.3 Understanding the Computation

For a single output row \(O_i\) (the output for query \(i\)):

\[O_i = \sum_j P_{ij} V_j\]

where:

\[P_{ij} = \frac{e^{S_{ij}}}{\sum_k e^{S_{ik}}}\]

and:

\[S_{ij} = Q_i \cdot K_j\]

Expanded:

\[O_i = \frac{\sum_j e^{Q_i \cdot K_j} V_j}{\sum_j e^{Q_i \cdot K_j}}\]

This is a weighted sum of V vectors, weighted by softmax of attention scores.

Observation: For row \(i\), we need all K and all V, but we don’t need the other Q rows.

This suggests: process one query row at a time.

16.4 Attempt 1: Row-by-Row Computation

def attention_row_by_row(Q, K, V):
    n, d = Q.shape
    O = torch.zeros(n, d)

    for i in range(n):
        # Compute scores for row i
        scores = Q[i] @ K.T  # Shape: (n,)

        # Softmax (with stability shift)
        scores_max = scores.max()
        exp_scores = torch.exp(scores - scores_max)
        softmax_denom = exp_scores.sum()
        weights = exp_scores / softmax_denom

        # Output for row i
        O[i] = weights @ V

    return O

Memory: O(n) scratch per row plus the O(nd) output; no n² intermediate. We’ve eliminated the n² memory.

But wait—this is slow. We’re doing n sequential iterations, each reading all of K and V from memory.

Memory traffic: O(n²d) reads (read K and V once per query).

Standard attention also reads K and V from HBM, but as one large matrix multiply that reuses each loaded tile across many queries. The per-row loop gets no such reuse: every query re-reads all of K and V.

Problem: We’ve traded memory for memory bandwidth. That’s not a good trade.
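How bad is the trade? A rough estimate of the bytes the loop pulls from HBM (illustrative numbers, FP32):

n, d, bytes_per = 32_768, 64, 4
kv_bytes = 2 * n * d * bytes_per       # K and V together: ~16 MiB
total = n * kv_bytes                   # re-read for every one of the n queries
print(f"{total / 2**30:.0f} GiB of HBM reads")   # ~512 GiB: hopelessly bandwidth-bound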

16.5 The Roads Not Taken

Before finding the right path, let’s examine approaches that seem promising but fall short. Understanding why they fail clarifies what we actually need.

16.5.1 Attempt 2: Gradient Checkpointing

A standard trick for memory-constrained training: don’t store intermediate activations; recompute them during backprop.

def attention_checkpointed(Q, K, V):
    # Forward: compute and store only O
    S = Q @ K.T
    P = softmax(S)
    O = P @ V
    # Don't save S or P—will recompute during backward
    return O

Why it fails: Checkpointing reduces activation memory during training, but not working memory during the forward pass. We still materialize the full n² matrix at some point. The memory spike remains.

Gradient checkpointing is orthogonal to FlashAttention—you can use both together. But it doesn’t solve the forward pass memory problem.
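For reference, a minimal sketch of what checkpointing looks like in PyTorch (using torch.utils.checkpoint; the shapes are arbitrary). Note that attn_block still builds the full n×n matrix while it runs:

import torch
from torch.utils.checkpoint import checkpoint

def attn_block(Q, K, V):
    S = Q @ K.transpose(-1, -2)          # the n×n matrix is materialized here regardless
    P = torch.softmax(S, dim=-1)
    return P @ V

Q = torch.randn(1024, 64, requires_grad=True)
K = torch.randn(1024, 64, requires_grad=True)
V = torch.randn(1024, 64, requires_grad=True)

# S and P are not saved for backward; they are recomputed during backprop,
# but the forward pass still peaks at O(n²) working memory.
O = checkpoint(attn_block, Q, K, V, use_reentrant=False)
O.sum().backward()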

16.5.2 Attempt 3: Sparse Attention

If n² is too expensive, why not make attention sparse?

def sparse_attention(Q, K, V, pattern='local'):
    if pattern == 'local':
        # Only attend to nearby positions
        window = 256
        # ... compute only within window
    elif pattern == 'strided':
        # Attend to every k-th position
        stride = 64
        # ...

Approaches like Longformer, BigBird, and Sparse Transformers use fixed sparsity patterns to achieve O(n) or O(n√n) complexity.

Why it fails for us: Sparse attention is an approximation. It changes what the model computes, not just how it computes. For many tasks, full attention significantly outperforms sparse variants.

We want the exact same output as standard attention—just computed more efficiently.

16.5.3 Attempt 4: Low-Rank Approximation

Maybe we can approximate the attention matrix with something lower-rank?

def linear_attention(Q, K, V):
    # Approximate: instead of softmax(QK^T)V
    # Use: phi(Q) @ (phi(K)^T @ V)
    # where phi is a feature map
    Q_feat = feature_map(Q)  # n×r
    K_feat = feature_map(K)  # n×r
    return Q_feat @ (K_feat.T @ V)  # (n×r) @ ((r×n) @ (n×d)): O(nrd) instead of O(n²d)

Performers, Linear Transformers, and similar approaches use this trick.
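For concreteness, here is a minimal sketch of one common choice, the elu(x) + 1 feature map from Linear Transformers, together with the normalizer that the pseudocode above omits:

import torch
import torch.nn.functional as F

def feature_map(x):
    return F.elu(x) + 1.0                           # positive features, r = d

def linear_attention(Q, K, V):
    Q_f, K_f = feature_map(Q), feature_map(K)       # n×d each
    KV = K_f.T @ V                                  # d×d summary of all keys/values
    Z = Q_f @ K_f.sum(dim=0, keepdim=True).T        # n×1 normalizer: Σ_j φ(q_i)·φ(k_j)
    return (Q_f @ KV) / Z                           # O(n·d²) instead of O(n²·d)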

Why it fails for us: Again, this is an approximation. The feature map \(\phi\) doesn’t perfectly reproduce softmax behavior. For some tasks it works well; for others, quality degrades significantly.

We want exact attention.

16.5.4 Attempt 5: Sampling-Based Approximation

What if we sample a subset of key-value pairs?

import numpy as np

def sampled_attention(Q, K, V, sample_size=256):
    # Randomly sample a subset of key/value positions
    n_kv = K.shape[0]
    idx = np.random.choice(n_kv, min(sample_size, n_kv), replace=False)
    K_s, V_s = K[idx], V[idx]
    S = Q @ K_s.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V_s

Why it fails: Sampling introduces variance. Critical tokens might be missed. Quality is unpredictable.

16.5.5 What We Actually Need

Looking at these failures, the requirements become clear:

  1. Exact computation: Same output as standard attention (mathematically identical, up to floating-point rounding)
  2. Bounded memory: Don’t materialize the full n² matrix
  3. Efficient execution: Can’t just be correct; must be fast

The first attempt (row-by-row) achieved #1 and #2 but failed #3. The approximation methods achieved #2 and #3 but failed #1.

Is there a solution that achieves all three?

16.6 The Key Insight: Tiles, Not Rows

The row-by-row approach processes one query at a time. But GPU operations are efficient with larger tiles.

New question: Can we process a tile of queries against a tile of keys/values, and combine tiles without recomputation?

This is where Chapter 4’s associativity insight becomes crucial.

16.7 The Softmax Challenge

Standard softmax needs the global maximum for numerical stability:

def softmax(x):
    x_max = x.max()
    exp_x = torch.exp(x - x_max)
    return exp_x / exp_x.sum()

If we process K/V in tiles, we don’t know the global max when processing the first tile.

The blocking issue: Softmax seems to require the full row before producing output.

Let’s solve this.

16.8 Investigation: Online Softmax

What if the maximum changes as we see more data?

Consider two chunks of scores: \([s_1, s_2, s_3]\) and \([s_4, s_5, s_6]\).

Chunk 1: \(m_1 = \max(s_1, s_2, s_3)\), \(d_1 = \sum_{i=1}^{3} e^{s_i - m_1}\)

Chunk 2: \(m_2 = \max(s_4, s_5, s_6)\), \(d_2 = \sum_{i=4}^{6} e^{s_i - m_2}\)

Global: \(m = \max(m_1, m_2)\), \(d = ?\)

The denominator in chunk 1 was computed relative to \(m_1\). To combine with chunk 2, we need to rescale:

\[d_1' = d_1 \cdot e^{m_1 - m}\] \[d_2' = d_2 \cdot e^{m_2 - m}\] \[d = d_1' + d_2'\]

The correction factor: When the max changes, multiply old sums by \(e^{\text{old\_max} - \text{new\_max}}\).

Let’s verify:

import numpy as np

# Full computation
scores = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 3.5])
m_full = scores.max()
d_full = np.exp(scores - m_full).sum()
print(f"Full: max={m_full}, denom={d_full:.4f}")

# Chunked computation
chunk1 = scores[:3]
chunk2 = scores[3:]

m1 = chunk1.max()
d1 = np.exp(chunk1 - m1).sum()
print(f"Chunk 1: max={m1}, denom={d1:.4f}")

m2 = chunk2.max()
d2 = np.exp(chunk2 - m2).sum()
print(f"Chunk 2: max={m2}, denom={d2:.4f}")

# Combine
m = max(m1, m2)
d = d1 * np.exp(m1 - m) + d2 * np.exp(m2 - m)
print(f"Combined: max={m}, denom={d:.4f}")

# Output:
# Full: max=5.0, denom=1.7944
# Chunk 1: max=3.0, denom=1.5032
# Chunk 2: max=5.0, denom=1.5910
# Combined: max=5.0, denom=1.7944  ✓ Matches!

The combination works. The softmax denominator has associative structure with state (max, scaled_sum).
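Framed as a binary operator ⊕ on (max, scaled_sum) states, this combine step is associative, which is exactly the property the chapter's opening note promised. A quick numerical check (a sketch; the helper names are ours):

import numpy as np

def combine(a, b):
    # ⊕ on (max, scaled_sum) states: rescale both sides to the shared max
    m_a, d_a = a
    m_b, d_b = b
    m = max(m_a, m_b)
    return (m, d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m))

def state(scores):
    m = scores.max()
    return (m, np.exp(scores - m).sum())

a, b, c = (state(np.array(x)) for x in ([1.0, 3.0], [2.0, 5.0], [4.0, 3.5]))
left  = combine(combine(a, b), c)
right = combine(a, combine(b, c))
print(np.allclose(left, right))   # True: (a ⊕ b) ⊕ c == a ⊕ (b ⊕ c)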

16.9 Extending to Output

The denominator is only half the story. We also need the output:

\[O_i = \frac{\sum_j e^{S_{ij}} V_j}{\sum_j e^{S_{ij}}}\]

The numerator is also a sum of exponentials, but weighted by \(V_j\).

State: (max, sum, numerator) = (m, d, o)

When max changes:

def update_state(old_state, new_scores, new_V):
    m_old, d_old, o_old = old_state

    m_new = new_scores.max()
    exp_new = np.exp(new_scores - m_new)
    d_new = exp_new.sum()
    o_new = exp_new @ new_V

    m = max(m_old, m_new)

    # Rescale old state
    scale_old = np.exp(m_old - m)
    scale_new = np.exp(m_new - m)

    d = d_old * scale_old + d_new * scale_new
    o = o_old * scale_old + o_new * scale_new

    return (m, d, o)

# Final output: o / d

Let’s verify this works:

# Test data
np.random.seed(42)
Q = np.random.randn(4, 8)  # 4 queries, dim 8
K = np.random.randn(6, 8)  # 6 keys
V = np.random.randn(6, 8)  # 6 values

# Standard attention (for reference)
def standard_attention(Q, K, V):
    S = Q @ K.T
    S_max = S.max(axis=1, keepdims=True)
    exp_S = np.exp(S - S_max)
    P = exp_S / exp_S.sum(axis=1, keepdims=True)
    return P @ V

O_standard = standard_attention(Q, K, V)

# Chunked attention
def chunked_attention(Q, K, V, chunk_size=2):
    n = Q.shape[0]
    n_kv = K.shape[0]
    d = V.shape[1]

    # Initialize state for each query
    m = np.full(n, -np.inf)
    s = np.zeros(n)
    o = np.zeros((n, d))

    for j in range(0, n_kv, chunk_size):
        K_chunk = K[j:j+chunk_size]
        V_chunk = V[j:j+chunk_size]

        # Scores for this chunk
        scores = Q @ K_chunk.T  # (n, chunk_size)

        for i in range(n):
            row_scores = scores[i]
            row_max = row_scores.max()

            if row_max > m[i]:
                # Rescale old state
                scale = np.exp(m[i] - row_max)
                s[i] = s[i] * scale
                o[i] = o[i] * scale
                m[i] = row_max

            # Add new contribution
            exp_scores = np.exp(row_scores - m[i])
            s[i] += exp_scores.sum()
            o[i] += exp_scores @ V_chunk

    return o / s[:, None]

O_chunked = chunked_attention(Q, K, V, chunk_size=2)

# Compare
print("Max difference:", np.abs(O_standard - O_chunked).max())
# Output: Max difference: 1.11e-15 (floating point precision)

It works. We can compute attention in chunks without materializing the full n×n matrix.

16.10 The Full Algorithm

Now we tile both Q and K/V:

def flash_attention(Q, K, V, tile_size=64):
    """
    FlashAttention: tiled attention with O(n) memory.

    Tiles Q (rows) and K/V (columns) to fit in SRAM.
    """
    n, d = Q.shape
    n_tiles = (n + tile_size - 1) // tile_size

    # Output and running statistics
    O = np.zeros((n, d))
    M = np.full(n, -np.inf)  # Running max
    L = np.zeros(n)           # Running softmax denominator (ℓ in the FlashAttention paper)

    # Process K/V in tiles
    for j in range(0, n, tile_size):
        j_end = min(j + tile_size, n)
        K_tile = K[j:j_end]
        V_tile = V[j:j_end]

        # Process Q in tiles (for parallelism)
        for i in range(0, n, tile_size):
            i_end = min(i + tile_size, n)
            Q_tile = Q[i:i_end]

            # Compute attention scores for this tile pair
            S_tile = Q_tile @ K_tile.T  # (tile, tile) - fits in SRAM!

            # Row-wise operations
            for row in range(i_end - i):
                global_row = i + row
                row_scores = S_tile[row]
                row_max = row_scores.max()

                # New max for this row
                new_max = max(M[global_row], row_max)

                # Rescale old statistics
                scale_old = np.exp(M[global_row] - new_max)
                scale_new = np.exp(row_max - new_max)

                exp_scores = np.exp(row_scores - row_max)

                # Update running state
                L[global_row] = L[global_row] * scale_old + exp_scores.sum() * scale_new
                O[global_row] = O[global_row] * scale_old + (exp_scores @ V_tile) * scale_new
                M[global_row] = new_max

    # Normalize
    return O / L[:, None]

16.10.1 Memory Analysis

Standard attention:
  S = Q @ K.T:  O(n²) memory
  P = softmax:  O(n²) memory
  O = P @ V:    O(nd) memory

FlashAttention:
  Q_tile:      O(tile × d) memory
  K_tile:      O(tile × d) memory
  V_tile:      O(tile × d) memory
  S_tile:      O(tile²) memory
  O, M, L:     O(nd) memory

Total: O(tile² + nd)

For tile = 128, d = 64, n = 32768:
  Standard:    32768² × 4 = 4 GB
  Flash:       128² × 4 + 32768 × 64 × 4 = 65KB + 8MB ≈ 8 MB
  Reduction:   ~500×

16.10.2 Speed Analysis

Surprisingly, FlashAttention is often faster despite doing more arithmetic.

Why? IO is the bottleneck.

Standard attention:

  1. Read Q, K from HBM → compute S → write S to HBM
  2. Read S from HBM → softmax → write P to HBM
  3. Read P, V from HBM → compute O → write O to HBM

Total HBM traffic: on the order of n² elements several times over, since S and P are each written to HBM and read back, i.e., O(n²).

FlashAttention:

  1. Load Q_tile, K_tile, V_tile from HBM into SRAM
  2. Compute everything in SRAM
  3. Write O_tile to HBM

Total HBM traffic: Q, K, V, and O are streamed tile by tile rather than re-read per element. The FlashAttention paper bounds the accesses at O(n²d²/M) for SRAM size M, which for typical head dimensions is roughly an order of magnitude less than standard attention's O(n²).

The reduction in memory traffic often exceeds the extra compute cost.
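A back-of-the-envelope roofline makes the point concrete. Using rough A100-class peak numbers (approximate; only the ratio matters), standard attention spends far longer moving S and P than computing them:

n, d = 8_192, 64
flops     = 4 * n * n * d                      # QKᵀ and PV matmuls
hbm_bytes = 3 * n * n * 4 + 4 * n * d * 4      # ~3n² elements of S/P traffic (as above) + Q, K, V, O

peak_flops, peak_bw = 312e12, 2.0e12           # ~312 TFLOPS FP16, ~2 TB/s HBM

print(f"compute time ≈ {flops / peak_flops * 1e3:.2f} ms")
print(f"memory  time ≈ {hbm_bytes / peak_bw * 1e3:.2f} ms")
# Memory time dominates by roughly 7×: the kernel is bandwidth-bound.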

16.11 The Tiling Strategy

How do we choose tile sizes?

Constraint: Tiles must fit in SRAM.

GPU SRAM (shared memory): ~100-200 KB per SM

A100 shared memory: 164 KB

For FP16:
  Q_tile: tile × d × 2 bytes
  K_tile: tile × d × 2 bytes
  V_tile: tile × d × 2 bytes
  S_tile: tile × tile × 2 bytes
  Accumulators: tile × d × 4 bytes (FP32 for precision)

For tile = 128, d = 64:
  Q_tile: 128 × 64 × 2 = 16 KB
  K_tile: 128 × 64 × 2 = 16 KB
  V_tile: 128 × 64 × 2 = 16 KB
  S_tile: 128 × 128 × 2 = 32 KB
  Accumulators: 128 × 64 × 4 = 32 KB
  Total: ~112 KB ✓ Fits!

The actual FlashAttention implementation tunes tile sizes per GPU architecture.
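The budget above is easy to turn into a helper and sweep (a sketch; the 164 KB limit and FP16/FP32 split follow the table above):

def smem_kb(tile, d, io_bytes=2, acc_bytes=4):
    # Q, K, V tiles + S tile in FP16, plus FP32 output accumulators
    total = 3 * tile * d * io_bytes + tile * tile * io_bytes + tile * d * acc_bytes
    return total / 1024

for tile in (64, 128, 256):
    kb = smem_kb(tile, d=64)
    verdict = "fits" if kb <= 164 else "does not fit"
    print(f"tile = {tile:3d}: {kb:6.1f} KB ({verdict} in 164 KB of A100 shared memory)")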

16.12 The Backward Pass

Training requires gradients. Can we backpropagate through FlashAttention?

Standard backprop would require storing all intermediate activations—defeating the memory savings.

FlashAttention’s solution: recompute instead of store.

During the backward pass:

  1. Reload Q, K, V tiles from HBM
  2. Recompute S_tile, P_tile in SRAM
  3. Compute gradients using the recomputed values
  4. Write gradients to HBM

This trades compute for memory. For large n, the trade is favorable—memory is the bottleneck.

def flash_attention_backward(dO, Q, K, V, O, M, L, tile_size=64):
    """
    Backward pass, recomputing attention weights tile by tile.

    dO: gradient of loss w.r.t. output
    O, M, L: saved from forward pass (only O(n·d) memory, no n×n matrices)
    """
    n, d = Q.shape
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)

    # Per-row softmax correction term: D_i = sum_k dP_ik P_ik = dO_i · O_i
    D = (dO * O).sum(axis=1)

    for j in range(0, n, tile_size):
        K_tile = K[j:j+tile_size]
        V_tile = V[j:j+tile_size]

        for i in range(0, n, tile_size):
            Q_tile = Q[i:i+tile_size]
            dO_tile = dO[i:i+tile_size]

            # Recompute the forward attention weights for this tile pair
            # from the saved per-row statistics (max M, denominator L)
            S_tile = Q_tile @ K_tile.T
            P_tile = np.exp(S_tile - M[i:i+tile_size, None]) / L[i:i+tile_size, None]

            # Backward through O = P @ V
            dV[j:j+tile_size] += P_tile.T @ dO_tile
            dP_tile = dO_tile @ V_tile.T

            # Backward through the row-wise softmax: dS = P ⊙ (dP - D)
            dS_tile = P_tile * (dP_tile - D[i:i+tile_size, None])

            # Backward through S = Q @ K.T
            dQ[i:i+tile_size] += dS_tile @ K_tile
            dK[j:j+tile_size] += dS_tile.T @ Q_tile

    return dQ, dK, dV

16.13 Benchmarks

Real-world performance on an A100:

Sequence length: 2048, head dim: 64

Method                Memory (MB)    Time (ms)
───────────────────────────────────────────────
Standard attention    1,024          2.1
FlashAttention        8              0.8

Sequence length: 16384, head dim: 64

Method                Memory (MB)    Time (ms)
───────────────────────────────────────────────
Standard attention    OOM            —
FlashAttention        64             42

Sequence length: 65536, head dim: 64

Method                Memory (MB)    Time (ms)
───────────────────────────────────────────────
Standard attention    OOM            —
FlashAttention        256            680

FlashAttention enables sequences that were previously impossible, while being faster on feasible sequences.

16.14 The Derivation Pattern

Let’s trace how we derived FlashAttention:

  1. Identify the problem: O(n²) memory from materializing attention matrix

  2. Ask the key question: Must we materialize it? What do we actually need?

  3. Discover the structure: Softmax has associative structure via (max, sum) state

  4. Extend the insight: Output accumulation also has this structure with (max, sum, output_sum)

  5. Apply hardware constraints: Tile to fit in SRAM, minimize HBM traffic

  6. Handle the full system: Backward pass recomputes to maintain memory savings

This is the investigation pattern. Not “here’s FlashAttention” but “how would you find FlashAttention if it didn’t exist?”

16.15 FlashAttention-3: Hopper-Specific Optimizations

FlashAttention-2 is largely architecture-agnostic across recent NVIDIA GPUs. FlashAttention-3 exploits Hopper’s unique features for another 1.5-2× speedup.

16.15.1 The Three Innovations

FlashAttention-3 advances:

1. TMA (Tensor Memory Accelerator)
   - Hardware-managed async memory transfers
   - Replaces software-orchestrated loads

2. Warp Specialization
   - Dedicated producer/consumer warps
   - Overlaps memory and compute

3. FP8 Support
   - 2× throughput vs FP16
   - Block-level quantization for accuracy

16.15.2 TMA: Hardware Async Copies

On Ampere (A100), software orchestrates memory transfers:

# A100: Software-managed async copy
# Step 1: Issue load
cp.async.ca.shared.global [smem], [gmem]

# Step 2: Commit group
cp.async.commit_group

# Step 3: Wait
cp.async.wait_group 0

# Software tracks addresses, handles boundaries

On Hopper, TMA handles this in hardware:

# H100: TMA-managed copy
# Step 1: Create TMA descriptor (once)
tma_desc = create_tma_descriptor(
    tensor_ptr, shape, strides, tile_shape
)

# Step 2: Issue copy (hardware manages everything)
cp.async.bulk.tensor [smem], [tma_desc], [coords]

# Step 3: Wait on barrier
arrive.expect_tx barrier, bytes
wait barrier

# Hardware handles: addressing, boundaries, caching, coalescing

Benefits:

  • Fewer instructions (freeing registers)
  • Better memory access patterns
  • Automatic handling of edge tiles

16.15.3 Warp Specialization

FlashAttention-2: All warps do the same work

Time ─→
Warp 0: [Load Q][Load K][Compute][Load V][Compute][Store]
Warp 1: [Load Q][Load K][Compute][Load V][Compute][Store]
         │     │        │                 │
         └─────┴────────┴─ Synchronization barriers ─┘

FlashAttention-3: Warps specialize into roles

Time ─→
Producer warp:  [Load K₀][Load V₀][Load K₁][Load V₁][Load K₂]...
                    │        │        │        │
                    ▼        ▼        ▼        ▼
Consumer warps: [Compute₀][Compute₀][Compute₁][Compute₁]...
                [Store₀───][Store₀───][Store₁───]...

Overlapped! Producer loads next tile while consumers compute current.

Implementation concept:

// Warp specialization in FA3
if (warp_id == PRODUCER_WARP) {
    // This warp only does memory operations
    for (int i = 0; i < num_tiles; i++) {
        // Load next K, V tiles via TMA
        tma_load_async(K_smem[i % 2], K_gmem + i * TILE_K);
        tma_load_async(V_smem[i % 2], V_gmem + i * TILE_K);

        // Signal consumers
        arrive(barrier[i % 2]);
    }
} else {
    // Consumer warps only compute
    for (int i = 0; i < num_tiles; i++) {
        // Wait for producer
        wait(barrier[i % 2]);

        // Compute attention for this tile (pseudocode):
        //   S     = Q @ K_smem[i % 2]^T
        //   O_acc = online_softmax_update(O_acc, S, V_smem[i % 2])
    }
}

16.15.4 Pingpong Scheduling

The pattern above uses pingpong buffers (double buffering) in shared memory:

Shared Memory Layout:

┌──────────────────────────────────────┐
│  K buffer 0  │  V buffer 0  │        │
├──────────────┼──────────────┤        │
│  K buffer 1  │  V buffer 1  │  Q     │
└──────────────┴──────────────┴────────┘

Time step 0: Compute with buffer 0, load into buffer 1
Time step 1: Compute with buffer 1, load into buffer 0
Time step 2: Compute with buffer 0, load into buffer 1
...

Memory operations and compute never contend for the same buffer.

16.15.5 FP8 Attention

Hopper’s tensor cores support FP8 for 2× the FLOPS of FP16.

The challenge: FP8 has limited range (±448 for E4M3).

Solution: Block-wise quantization

import math
import torch

def fp8_attention(Q, K, V):
    # Conceptual sketch: per-row (block) scaling factors for E4M3's ±448 range
    d = Q.shape[-1]
    scale_Q = Q.abs().amax(dim=-1, keepdim=True) / 448.0
    scale_K = K.abs().amax(dim=-1, keepdim=True) / 448.0

    # Quantize to FP8
    Q_fp8 = (Q / scale_Q).to(torch.float8_e4m3fn)
    K_fp8 = (K / scale_K).to(torch.float8_e4m3fn)

    # Score matmul runs in FP8 on the tensor cores (2× FP16 throughput);
    # shown here conceptually: the real fused kernel performs this MMA internally
    S_fp8 = Q_fp8 @ K_fp8.T

    # Dequantize before softmax (softmax needs the full dynamic range)
    S = S_fp8.float() * scale_Q * scale_K.T

    # Softmax in FP32 for accuracy
    P = torch.softmax(S / math.sqrt(d), dim=-1)

    # V can stay in FP16/BF16
    return P @ V

Accuracy: Within 0.1% of FP16 attention with proper scaling.
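You can get a feel for the quantization error without a Hopper GPU by simulating the FP8 round-trip in PyTorch (a sketch: we quantize Q and K to float8_e4m3fn and back, then run ordinary attention; the real kernel instead keeps the matmul in FP8):

import math
import torch

torch.manual_seed(0)
n, d = 256, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))

def ref_attention(Q, K, V):
    return torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V

def fake_quant_e4m3(x):
    # Per-row scaling into E4M3's ±448 range, then round-trip through FP8
    scale = x.abs().amax(dim=-1, keepdim=True) / 448.0
    return (x / scale).to(torch.float8_e4m3fn).to(torch.float32) * scale

O_ref = ref_attention(Q, K, V)
O_fp8 = ref_attention(fake_quant_e4m3(Q), fake_quant_e4m3(K), V)
print(f"relative error: {(O_ref - O_fp8).norm() / O_ref.norm():.2e}")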

16.15.6 Performance Comparison

FlashAttention variants on H100 (sequence length 8192, head dim 128):

Version     Time (ms)   TFLOPS   % of Peak   Key Feature
─────────────────────────────────────────────────────────────
FA1         12.4        180      18%         Online softmax
FA2         5.8         385      39%         Better parallelism
FA3 FP16    3.2         700      71%         TMA + warp spec
FA3 FP8     1.9         1,180    60% (FP8)   FP8 tensor cores

Theoretical peak: 989 TFLOPS (FP16), 1979 TFLOPS (FP8)

16.15.7 When to Use Each Version

flowchart TD
    A{Hardware?} -->|Hopper H100/H200| B[FA3]
    A -->|Ampere A100| C[FA2]
    A -->|Older/Non-NVIDIA| D[FA1 or PyTorch SDPA]

    B --> E{Precision requirement?}
    E -->|Need FP16 accuracy| F[FA3 FP16]
    E -->|Can tolerate FP8| G[FA3 FP8<br/>2× faster]

    style B fill:#dcfce7,stroke:#16a34a
    style C fill:#e0f2fe,stroke:#0284c7
    style D fill:#fef3c7,stroke:#d97706
    style F fill:#f3e8ff,stroke:#9333ea
    style G fill:#dcfce7,stroke:#16a34a

Choosing the right FlashAttention version

Quick reference:

Version     Hardware              Best For
─────────────────────────────────────────────────
FA1         Legacy, non-NVIDIA    Compatibility
FA2         A100 and earlier      Production stability
FA3 FP16    H100/H200             Maximum quality
FA3 FP8     H100/H200             Maximum speed
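The flowchart can also be reduced to a tiny dispatch helper keyed on compute capability (a hypothetical helper; the returned strings are just labels, not package names):

import torch

def pick_attention_backend():
    if not torch.cuda.is_available():
        return "PyTorch SDPA"
    major, _ = torch.cuda.get_device_capability()
    if major >= 9:        # Hopper (H100/H200)
        return "FA3 (FP16, or FP8 if acceptable)"
    if major >= 8:        # Ampere (A100)
        return "FA2"
    return "FA1 or PyTorch SDPA"

print(pick_attention_backend())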

16.15.8 Using FlashAttention-3

# Installation
# pip install flash-attn --no-build-isolation

import math
import torch
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, nheads, headdim) FP16/BF16 tensors on the GPU

# FP16 attention
output = flash_attn_func(
    q, k, v,
    causal=True,
    softmax_scale=1.0 / math.sqrt(d)
)

# FP8 attention (Hopper only; the exact entry point varies by flash-attn release)
from flash_attn import flash_attn_func_fp8

output = flash_attn_func_fp8(
    q.to(torch.float8_e4m3fn),
    k.to(torch.float8_e4m3fn),
    v,  # V can stay FP16
    descale_q=scale_q,
    descale_k=scale_k,
)

16.15.9 The Lesson: Hardware-Algorithm Co-Design

FlashAttention-3 demonstrates that maximum performance requires exploiting specific hardware features:

Generic algorithm (FA1/FA2):
  + Portable
  + Easier to maintain
  - Leaves performance on the table

Hardware-specific (FA3):
  + Maximum performance
  - Requires Hopper
  - More complex implementation

The best systems offer both: portable default, hardware-specific fast paths.

16.16 Connections

Chapter 1 (Memory Hierarchy): FlashAttention is fundamentally about keeping data in fast memory (SRAM) instead of slow memory (HBM).

Chapter 2 (Bandwidth): The algorithm is faster despite more FLOPs because memory bandwidth, not compute, is the bottleneck.

Chapter 4 (Associativity): The (max, sum, output) state forms a monoid—the mathematical foundation for chunking.

16.17 Key Takeaways

  1. The question matters: “Must we materialize the n×n matrix?” led to the algorithm. Without the question, you’d never find the answer.

  2. Look for hidden structure: Softmax doesn’t look associative. But it is, with the right state representation.

  3. Hardware context is essential: FlashAttention is optimized for GPU memory hierarchy. On different hardware, different tradeoffs might apply.

  4. Recomputation is a tool: Trading compute for memory (via recomputation) is sometimes the right trade.

  5. The derivation teaches more than the result: Understanding why FlashAttention works lets you find the next FlashAttention.

Try It Yourself

The accompanying notebook walks through:

  • Implementing online softmax from scratch
  • Building a simplified FlashAttention
  • Comparing memory usage to standard attention
  • Profiling SRAM vs. HBM access patterns


16.18 Further Reading

  • Dao et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”
  • Dao (2023). “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning”
  • Shah et al. (2024). “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision”
  • Rabe & Staats (2021). “Self-attention Does Not Need O(n²) Memory” - Independent discovery of online softmax
  • Milakov & Gimelshein (2018). “Online normalizer calculation for softmax” - The mathematical foundation