The One Property That Makes FlashAttention Possible
Originally published on Software Bits — subscribe for updates.
FlashAttention is everywhere.
It’s in PyTorch, JAX, and most LLM serving stacks. Speedups of 2-4x are commonly reported, and it cuts attention’s extra memory from O(n²) to O(n) in the sequence length.
But ask practitioners why it works, and you get hand-wavy answers about “tiling” and “recomputation.”
Those are implementation details. Today I want to show you the actual why—the mathematical property that makes it all possible.
We’re not starting with attention, though. We’re starting with something simpler: computing an average.
A Simple Problem: The Average
You have a billion numbers. You want their average.
The obvious approach:
```python
total = 0
count = 0
for x in all_billion_numbers:
    total += x
    count += 1
average = total / count
```
One number at a time. One core. One long wait.
But here’s the thing: you don’t have to process them in order.
Split the data into chunks. Compute each chunk’s sum and count separately:
```
Chunk A:  sum = 1,234,567   count = 500,000,000
Chunk B:  sum = 2,345,678   count = 500,000,000
```
Then combine:
```
total   = sum_a + sum_b     = 3,580,245
count   = count_a + count_b = 1,000,000,000
average = total / count
```
Same answer. But now you can process chunks in parallel.
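Here is a minimal Python sketch of the chunked version. It assumes the data fits in a list. The helpers `chunk_state` and `combine_avg` are illustrative names, not a library API. The thread pool only shows that chunk states can be computed independently. For real CPU speedups in pure Python you would use processes or vectorized NumPy instead.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import reduce


def chunk_state(chunk):
    # Partial state for one chunk: (sum, count)
    return (sum(chunk), len(chunk))


def combine_avg(a, b):
    # Merge two partial states by adding both components
    return (a[0] + b[0], a[1] + b[1])


def parallel_average(data, n_chunks=4):
    size = -(-len(data) // n_chunks)  # ceiling division
    chunks = [data[i:i + size] for i in range(0, len(data), size)]
    # Each chunk's state depends only on that chunk
    with ThreadPoolExecutor() as pool:
        states = list(pool.map(chunk_state, chunks))
    total, count = reduce(combine_avg, states, (0, 0))
    return total / count


data = list(range(1, 1_000_001))
print(parallel_average(data))  # 500000.5
```

Changing `n_chunks` changes how the work is split, not the result.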
Why This Works: Associativity
This works because addition is associative:
\[(a + b) + c = a + (b + c)\]
Grouping doesn’t change the result.
More precisely: the pair (sum, count) can be combined, and that combination is associative. This single property unlocks three capabilities:
```
                     ASSOCIATIVITY
  combine(combine(A,B), C) = combine(A, combine(B,C))
                           │
            ┌──────────────┼──────────────┐
            ▼              ▼              ▼
      ┌───────────┐  ┌───────────┐  ┌───────────┐
      │ CHUNKING  │  │ PARALLEL  │  │ STREAMING │
      └─────┬─────┘  └─────┬─────┘  └─────┬─────┘
            ▼              ▼              ▼
         Process        Process        Process
         pieces         pieces        pieces as
       separately       at once      they arrive
```
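All three capabilities in the diagram come from the same combine step. Below is a small sketch using a hypothetical `combine_avg` helper for (sum, count) pairs. Regrouping three partial states gives the same result. Streaming is just folding in one-element states as values arrive, starting from the identity state (0, 0).

```python
from functools import reduce


def combine_avg(a, b):
    # (sum, count) pairs merge by adding both components
    return (a[0] + b[0], a[1] + b[1])


def singleton(x):
    # A single value is a chunk of size one
    return (x, 1)


# Three partial states, e.g. from three chunks
A, B, C = (10, 4), (7, 2), (33, 6)

# Associativity: the grouping doesn't change the combined state
left = combine_avg(combine_avg(A, B), C)
right = combine_avg(A, combine_avg(B, C))
print(left, right)  # (50, 12) (50, 12)

# Streaming: fold values in as they arrive, starting from the identity (0, 0)
stream = iter([3, 1, 4, 1, 5, 9, 2, 6])
state = reduce(combine_avg, map(singleton, stream), (0, 0))
print(state, state[0] / state[1])  # (31, 8) 3.875
```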
This might seem obvious for averages. But the same principle—find combinable state, verify associativity—applies to surprisingly complex operations.
Let’s try something harder.
Why Softmax Seems Different
Softmax is the heart of attention:
\[\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
To compute any output, you need the sum over all elements.
If I show you half the values, you can’t produce final outputs. You’re missing half the denominator. Unlike our average, where partial sums combine cleanly, softmax seems to require everything upfront.
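A quick NumPy check of that claim, using a plain illustrative `softmax` helper. Normalizing each half by its own partial denominator gives different numbers from the true softmax, because each half ends up summing to 1 on its own.

```python
import numpy as np


def softmax(x):
    # Direct definition; fine for small inputs like these
    e = np.exp(x)
    return e / e.sum()


x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
full = softmax(x)

# Normalize each half using only its own partial denominator
halves = np.concatenate([softmax(x[:3]), softmax(x[3:])])

print(np.round(full, 3))
print(np.round(halves, 3))
print(np.allclose(full, halves))  # False
```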
Can we find hidden associative structure anyway?
Before we answer, we need to address a practical problem.
The Overflow Problem
Here’s what breaks naive softmax:
```python
>>> import numpy as np
>>> np.exp(1000)
inf
```
When values are large, exp() overflows to infinity: past about 709.8 in float64, about 88.7 in float32, and about 11.1 in float16. And attention scores can be large.
The fix exploits a mathematical identity—softmax is translation-invariant:
\[\text{softmax}(x) = \text{softmax}(x - c) \quad \text{for any constant } c\]
Subtracting the same value from every element doesn’t change the output. So we subtract the maximum:
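\[\text{softmax}(x)_i = \frac{\exp(x_i - \max(x))}{\sum_j \exp(x_j - \max(x))}\]
Now every exponent is ≤ 0, so exp() never exceeds 1. No overflow. The largest element also contributes exactly $\exp(0) = 1$, so the denominator is at least 1 and can’t underflow to zero.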
This is why you see max everywhere in softmax code. It’s not algorithmic—it’s survival.
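A side-by-side NumPy sketch of the naive and max-shifted versions. Both functions are illustrative. The `np.errstate` context only silences the overflow and invalid-value warnings that the naive version triggers.

```python
import numpy as np


def softmax_naive(x):
    e = np.exp(x)            # overflows to inf for large x
    return e / e.sum()       # inf / inf -> nan


def softmax_stable(x):
    z = x - x.max()          # every exponent is now <= 0
    e = np.exp(z)            # largest term is exactly exp(0) = 1
    return e / e.sum()       # denominator >= 1, so it cannot underflow to 0


x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(x))  # [nan nan nan]
print(softmax_stable(x))

# Translation invariance: shifting every input by a constant changes nothing
y = np.array([1.0, 2.0, 3.0])
print(np.allclose(softmax_stable(y), softmax_stable(y + 1000.0)))  # True
```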
But now we have two things to track: the max and the sum. Let’s see if they combine.
The State for Softmax
For our average, the state was (sum, count).
For softmax, given values $x_1, x_2, \ldots, x_n$, we need:
- m = max of all values
- s = sum of $\exp(x_i - m)$
Can we combine two such pairs from separate chunks?
Why Simple Addition Fails
Say we have:
```
Chunk A: values [1, 2, 3]
  m_a = 3
  s_a = exp(1-3) + exp(2-3) + exp(3-3) ≈ 1.50

Chunk B: values [4, 5]
  m_b = 5
  s_b = exp(4-5) + exp(5-5) ≈ 1.37
```
The combined max is max(3, 5) = 5.
But we can’t just add the sums.
Chunk A’s sum was computed relative to max=3. Chunk B’s sum relative to max=5. They’re in different “units.”
Adding 1.50 + 1.37 = 2.87 would be wrong.
The Correction Factor
Here’s the key insight.
Chunk A computed $s_a = \sum \exp(x - 3)$.
But relative to the true max of 5, it should be $\sum \exp(x - 5)$.
These relate by:
\[\exp(x - 5) = \exp(x - 3) \times \exp(3 - 5) = \exp(x - 3) \times \exp(-2)\]
So we correct chunk A’s sum:
\[s_a^{\text{corrected}} = s_a \times \exp(m_a - m_{\text{combined}}) = 1.50 \times \exp(3 - 5) \approx 0.20\]
Chunk B needs no correction (its max equals the combined max):
\[s_b^{\text{corrected}} = 1.37 \times \exp(5 - 5) = 1.37 \times 1 = 1.37\]
Combined:
\[s_{\text{combined}} = 0.20 + 1.37 = 1.57\]

```
      CHUNK A                            CHUNK B
 values: [1, 2, 3]                   values: [4, 5]
  ┌─────────────┐                    ┌─────────────┐
  │ m_a = 3     │                    │ m_b = 5     │
  │ s_a = 1.50  │                    │ s_b = 1.37  │
  └──────┬──────┘                    └──────┬──────┘
         │                                  │
         │  ┌────────────────────────────┐  │
         └─▶│ m_combined = max(3,5) = 5  │◀─┘
            └─────────────┬──────────────┘
                          │
         ┌────────────────┴─────────────────┐
         ▼                                  ▼
┌─────────────────┐                ┌─────────────────┐
│   CORRECTION    │                │  NO CORRECTION  │
│ 1.50 × exp(3-5) │                │ 1.37 × exp(5-5) │
│     = 0.20      │                │     = 1.37      │
└────────┬────────┘                └────────┬────────┘
         │                                  │
         └────────────────┬─────────────────┘
                          ▼
                ┌───────────────────┐
                │   s_combined =    │
                │    0.20 + 1.37    │
                │     = 1.57 ✓      │
                └───────────────────┘
```

The correction factor $\exp(m_{\text{local}} - m_{\text{global}})$ rescales the local sum to the global reference frame.
Let’s verify by computing directly on [1, 2, 3, 4, 5]:
```
max = 5
sum = exp(-4) + exp(-3) + exp(-2) + exp(-1) + exp(0)
    ≈ 0.018 + 0.050 + 0.135 + 0.368 + 1.000
    ≈ 1.571 ✓
```
Same answer. The combination works.
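The same worked example in a few lines of standard-library Python. `softmax_state` is an illustrative helper that returns the (max, scaled sum) pair for one chunk. The last lines contrast naive addition with the corrected sum.

```python
from math import exp, isclose


def softmax_state(values):
    # (max, sum of exp(x - max)) for one chunk
    m = max(values)
    return (m, sum(exp(x - m) for x in values))


m_a, s_a = softmax_state([1, 2, 3])
m_b, s_b = softmax_state([4, 5])
m = max(m_a, m_b)

naive = s_a + s_b                                    # mixes two reference frames
corrected = s_a * exp(m_a - m) + s_b * exp(m_b - m)  # both rescaled to max 5
direct = softmax_state([1, 2, 3, 4, 5])[1]

print(round(naive, 2), round(corrected, 2), round(direct, 2))  # 2.87 1.57 1.57
print(isclose(corrected, direct))  # True
```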
The Combine Rule
We’ve derived the combination for softmax state:
```python
from math import exp

def combine(state_a, state_b):
    m_a, s_a = state_a
    m_b, s_b = state_b
    m = max(m_a, m_b)                            # new reference max
    s = s_a * exp(m_a - m) + s_b * exp(m_b - m)  # rescale both sums to it
    return (m, s)
```
Is this associative?
Yes. The algebra works out:
\[\text{combine}(\text{combine}(A, B), C) = \text{combine}(A, \text{combine}(B, C))\]
The pair (max, scaled_sum) forms a monoid, with $(-\infty, 0)$ as the identity element. That is the same structure as (sum, count) for averages, just with a more complex combination.
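Writing out the algebra shows why. Let $m_{ab} = \max(m_a, m_b)$, $m_{bc} = \max(m_b, m_c)$, and $M = \max(m_a, m_b, m_c)$, and write $e^{u}$ for $\exp(u)$. Expanding both groupings:

\[
\begin{aligned}
\text{combine}(\text{combine}(A, B), C)
  &= \big(M,\; (s_a e^{m_a - m_{ab}} + s_b e^{m_b - m_{ab}})\, e^{m_{ab} - M} + s_c e^{m_c - M}\big) \\
  &= \big(M,\; s_a e^{m_a - M} + s_b e^{m_b - M} + s_c e^{m_c - M}\big) \\
\text{combine}(A, \text{combine}(B, C))
  &= \big(M,\; s_a e^{m_a - M} + (s_b e^{m_b - m_{bc}} + s_c e^{m_c - m_{bc}})\, e^{m_{bc} - M}\big) \\
  &= \big(M,\; s_a e^{m_a - M} + s_b e^{m_b - M} + s_c e^{m_c - M}\big)
\end{aligned}
\]

The intermediate maxes $m_{ab}$ and $m_{bc}$ cancel out of the exponents. Since max is itself associative, both groupings reduce to the global max paired with $\sum_i \exp(x_i - M)$ over every element.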
The Pattern, Generalized
| | Average | Softmax |
|---|---|---|
| State | (sum, count) | (max, scaled_sum) |
| Combine | Add both | Max + corrected sum |
| Correction? | None needed | Rescale for new max |
| Associative? | Yes | Yes |
| Parallelizable? | Yes | Yes |
Same structure. Different complexity.
The softmax case requires correction because changing the max changes the meaning of the sum. But the fundamental pattern—combinable state, associative operation—is identical.
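To make “same structure” concrete, here is a sketch of one generic chunked reduction that works with either state. `reduce_chunks` and the four state helpers are illustrative names. Only the `to_state`, `combine`, and identity arguments change between the two problems.

```python
import math
from functools import reduce


def reduce_chunks(chunks, to_state, combine, identity):
    # Generic skeleton: map each chunk to a partial state, then fold them together
    return reduce(combine, map(to_state, chunks), identity)


# Average: state is (sum, count), identity (0, 0)
def avg_state(chunk):
    return (sum(chunk), len(chunk))


def avg_combine(a, b):
    return (a[0] + b[0], a[1] + b[1])


# Softmax denominator: state is (max, scaled_sum), identity (-inf, 0)
def sm_state(chunk):
    m = max(chunk)
    return (m, sum(math.exp(x - m) for x in chunk))


def sm_combine(a, b):
    m = max(a[0], b[0])
    return (m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m))


chunks = [[1.0, 2.0, 3.0], [4.0, 5.0], [0.5]]
flat = [x for chunk in chunks for x in chunk]

total, count = reduce_chunks(chunks, avg_state, avg_combine, (0.0, 0))
m, s = reduce_chunks(chunks, sm_state, sm_combine, (-math.inf, 0.0))

print(math.isclose(total / count, sum(flat) / len(flat)))                     # True
print(m == max(flat), math.isclose(s, sum(math.exp(x - m) for x in flat)))  # True True
```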
FlashAttention: Applying the Pattern
FlashAttention applies this to attention:
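\[\text{Attention}(Q, K, V) = \text{softmax}(QK^T) \times V\]
(The usual $1/\sqrt{d}$ scaling of the scores is left out for clarity, and softmax is applied to each row separately.) Instead of materializing the full n×n score matrix, FlashAttention: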
- Processes tiles of K and V one at a time
- Maintains running state: (max, sum, output)
- Applies correction when a new tile reveals a larger max
- Never materializes the full matrix
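```python
import numpy as np

# FlashAttention core loop, emulated in NumPy for one attention head
# (no 1/sqrt(d) scaling, no masking). The real kernel runs on the GPU
# and keeps each tile in fast on-chip memory.
def flash_attention_forward(Q, K, V, tile_size=64):
    n_q = Q.shape[0]
    # Initialize running state: one max and one sum per query row
    m = np.full((n_q, 1), -np.inf)        # running max
    s = np.zeros((n_q, 1))                # running sum
    output = np.zeros((n_q, V.shape[1]))  # running (unnormalized) output

    for start in range(0, K.shape[0], tile_size):
        K_tile = K[start:start + tile_size]
        V_tile = V[start:start + tile_size]
        # Compute attention scores for this tile
        scores = Q @ K_tile.T
        # Find new max, row by row (softmax normalizes each query separately)
        m_new = np.maximum(m, scores.max(axis=1, keepdims=True))
        # Correction factor (exp(-inf) = 0 on the first tile)
        correction = np.exp(m - m_new)
        # Rescale previous accumulations
        output = output * correction
        s = s * correction
        # Add this tile's contribution
        weights = np.exp(scores - m_new)
        s = s + weights.sum(axis=1, keepdims=True)
        output = output + weights @ V_tile
        # Update max
        m = m_new

    # Final normalization
    return output / s
```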
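Every tile might reveal a larger max. When it does, the correction factor rescales everything accumulated so far with a single multiply. Earlier tiles never need to be revisited, and the associative structure guarantees that the final result matches the full softmax.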
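The (max, sum, output) state combines the same way, so the tiles don’t have to be visited in one sequential loop. Separate blocks of K and V can be processed independently and merged afterwards. The NumPy sketch below uses a single query vector and hypothetical `partial_attention` and `merge` helpers. It illustrates the merge rule, not any library’s kernel.

```python
import numpy as np


def partial_attention(q, K, V):
    # State for one block of keys/values: (max, sum, unnormalized output)
    scores = K @ q                      # one score per key
    m = scores.max()
    w = np.exp(scores - m)
    return m, w.sum(), w @ V


def merge(a, b):
    # Same correction as the softmax combine, applied to the output too
    m_a, s_a, o_a = a
    m_b, s_b, o_b = b
    m = max(m_a, m_b)
    c_a, c_b = np.exp(m_a - m), np.exp(m_b - m)
    return m, s_a * c_a + s_b * c_b, o_a * c_a + o_b * c_b


rng = np.random.default_rng(0)
n, d = 256, 32
q = rng.standard_normal(d)
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Process the two halves independently, then merge
m, s, o = merge(partial_attention(q, K[:128], V[:128]),
                partial_attention(q, K[128:], V[128:]))
out_split = o / s

# Reference: materialize all scores at once
scores = K @ q
weights = np.exp(scores - scores.max())
out_ref = (weights / weights.sum()) @ V

print(np.allclose(out_split, out_ref))  # True
```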
The Skill to Develop
FlashAttention’s genius isn’t tiling or memory tricks.
It’s recognizing that softmax has hidden associative structure.
This is a learnable skill. When facing a “global” computation:
1. What state do I need?
   - Average: (sum, count)
   - Softmax: (max, sum)
   - Your problem: (?, ?)
2. Can partial states combine?
   - Write the combine function
   - Check: does combining partial states reproduce the whole-input answer?
3. Is correction needed?
   - Does new information change the meaning of old state?
   - If so, what’s the adjustment factor?
4. Is it associative?
   - If yes: parallelize, chunk, stream
   - If no: fundamentally different algorithm needed
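Steps 2 and 4 can be spot-checked mechanically. Below is a sketch of a hypothetical randomized tester, `looks_associative` (not a library function). Passing it is evidence rather than proof, while a single failure is a genuine counterexample. The second candidate, averaging two averages without their counts, is a common mistake that fails.

```python
import math
import random


def looks_associative(combine, random_state, trials=1000, seed=0):
    # Randomized check of combine(combine(a, b), c) == combine(a, combine(b, c))
    rng = random.Random(seed)
    for _ in range(trials):
        a, b, c = (random_state(rng) for _ in range(3))
        left = combine(combine(a, b), c)
        right = combine(a, combine(b, c))
        if not all(math.isclose(x, y, rel_tol=1e-9) for x, y in zip(left, right)):
            return False
    return True


def softmax_combine(a, b):
    m = max(a[0], b[0])
    return (m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m))


def random_softmax_state(rng):
    # (max, scaled_sum): the sum is at least 1 because the max term contributes exp(0)
    return (rng.uniform(-10, 10), rng.uniform(1, 50))


def mean_of_means(a, b):
    # Averages two averages without their counts (a 1-tuple so zip works)
    return ((a[0] + b[0]) / 2,)


def random_mean(rng):
    return (rng.uniform(-10, 10),)


print(looks_associative(softmax_combine, random_softmax_state))  # True
print(looks_associative(mean_of_means, random_mean))             # False
```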
The Counter-Example
Not everything has this structure.
Median:
```
median([1, 2, 3])          = 2
median([4, 5, 6])          = 5
median([1, 2, 3, 4, 5, 6]) = 3.5
```
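The chunk medians 2 and 5 can’t determine the answer. Averaging them happens to give 3.5 here. But the chunks [0, 2, 3] and [1, 5, 6] have the same medians, 2 and 5, while the median of all six values is 2.5. No fixed-size summary of a chunk combines the way (sum, count) does.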
That’s why median requires sorting or selection—fundamentally different algorithms.
Knowing what can’t be expressed this way is as valuable as knowing what can.
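For contrast, here is a sketch of the state an exact chunked median has to carry: the sorted chunk itself. Merging sorted lists (here with the standard library’s `heapq.merge`) is associative. But the merged state is as large as the input, which is exactly what fixed-size states like (sum, count) and (max, scaled_sum) avoid.

```python
import heapq
from functools import reduce
from statistics import median


def sorted_state(chunk):
    # The "state" an exact median needs: the whole chunk, sorted
    return sorted(chunk)


def merge_sorted(a, b):
    # Merging two sorted lists is associative, but the result keeps every element
    return list(heapq.merge(a, b))


chunks = [[3, 1, 2], [6, 4, 5], [9, 7, 8]]
state = reduce(merge_sorted, map(sorted_state, chunks), [])
print(median(state))  # 5
print(len(state))     # 9: the state is as large as the input
```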
The Takeaway
FlashAttention works because of a mathematical property.
Softmax has associative structure hidden beneath its global-looking surface. The (max, scaled_sum) pair combines across chunks. The combination requires correction, but it’s still associative.
Associativity is the license to parallelize, chunk, and stream.
The property came first. The algorithm followed.
This is the difference between collecting tricks and understanding foundations. Tricks solve one problem. Properties solve a category of problems.
When you face a seemingly global computation, don’t reach for a trick.
Reach for the algebra.
See also: Commutativity: Why Transformers Need Positional Encodings