8 Associativity
How Regrouping Enables Parallelism and Streaming
You have a billion numbers. You need their average.
The naive approach: load all billion into memory, sum them, divide.
The problem: a billion float64s is 8 GB. What if you only have 1 GB of RAM?
Historical Note: Egyptian Multiplication (2000 BCE)
The exploitation of associativity for efficiency dates back four millennia. Egyptian scribes computed multiplication using repeated doubling—O(log n) operations instead of O(n) additions. Their insight: because addition is associative, you can regroup (8+4+1)×24 into 8×24 + 4×24 + 1×24, computing each term via doubling.
This is the same principle behind parallel scan, distributed reduction, and FlashAttention’s streaming softmax. The algorithm is ancient; the scale is modern.
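For the curious, here is a minimal sketch of the doubling scheme in Python (the function name is ours):

```python
def egyptian_multiply(a, b):
    """Multiply by decomposing a into powers of two, e.g. 13 x 24 = (8 + 4 + 1) x 24."""
    total = 0
    while a > 0:
        if a & 1:        # this power of two appears in a
            total += b
        a >>= 1          # halve a ...
        b <<= 1          # ... while doubling b
    return total

print(egyptian_multiply(13, 24))  # 312, using O(log a) additions
```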
8.1 The Property That Enables Everything
Some operations can be computed in any order:
\[a + (b + c) = (a + b) + c\]
This is associativity. It seems abstract, almost trivial. It’s neither.
Associativity is the license to:
- Chunk: Process data in pieces
- Parallelize: Combine partial results from multiple workers
- Stream: Process data as it arrives, without storing everything
- Checkpoint: Save intermediate state and resume later
Without associativity, you must process everything at once. With it, you can process anything, no matter how large.
8.2 From Addition to Architecture
Let’s trace how this abstract property becomes concrete performance.
8.2.1 Summing a Billion Numbers
The naive sum:
```python
def naive_sum(numbers):
    total = 0
    for x in numbers:
        total += x
    return total
```

This works, but it requires all numbers in memory. Can we do better?
Associativity says: the grouping doesn’t matter.
```python
# These are mathematically identical:
(a + b + c + d) + (e + f + g + h)          # Two chunks
((a + b) + (c + d)) + ((e + f) + (g + h))  # Four chunks
```

So we can process chunks:
```python
def chunks(numbers, chunk_size):
    """Yield successive slices of at most chunk_size elements."""
    for i in range(0, len(numbers), chunk_size):
        yield numbers[i:i + chunk_size]

def chunked_sum(numbers, chunk_size=1000000):
    total = 0
    for chunk in chunks(numbers, chunk_size):
        total += sum(chunk)  # Process chunk, discard
    return total
```

Same answer. But now we only need chunk_size numbers in memory, not all of them.
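Because the partial sums combine associatively, the chunks can also be handed to separate workers and the partial results merged at the end. A minimal sketch using the standard library (the worker count and chunking strategy are illustrative choices):

```python
from concurrent.futures import ProcessPoolExecutor

def parallel_sum(numbers, n_workers=4):
    # One contiguous chunk per worker
    chunk_size = (len(numbers) + n_workers - 1) // n_workers
    pieces = [numbers[i:i + chunk_size]
              for i in range(0, len(numbers), chunk_size)]
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        partials = pool.map(sum, pieces)  # each worker reduces its own chunk
    return sum(partials)                  # combine the partial results
```

On platforms that spawn worker processes, call this under an `if __name__ == "__main__":` guard.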
8.2.2 The Combinable State
The key insight: we can represent the “state” of a partial computation as something that combines with more data.
For summation, the state is just the running sum. But this pattern generalizes.
| Operation | State | Combine Rule |
|---|---|---|
| Sum | sum | sum₁ + sum₂ |
| Count | count | count₁ + count₂ |
| Average | (sum, count) | (sum₁ + sum₂, count₁ + count₂) |
| Max | max | max(max₁, max₂) |
| Variance | (sum, sum_sq, count) | Combine pairwise |
The pattern: find the state that makes your operation associative.
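As a concrete instance, here is a minimal sketch of the Average row of the table (the helper names are ours):

```python
def avg_state(chunk):
    """Summarize a chunk as (sum, count)."""
    return (sum(chunk), len(chunk))

def combine_avg(state1, state2):
    """Merge two (sum, count) states; the grouping doesn't matter."""
    return (state1[0] + state2[0], state1[1] + state2[1])

state = (0.0, 0)
for chunk in ([1.0, 2.0], [3.0], [4.0, 5.0, 6.0]):
    state = combine_avg(state, avg_state(chunk))

total, count = state
print(total / count)  # 3.5, same as averaging all six numbers at once
```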
8.3 Investigation: The Softmax Challenge
Softmax seems sequential. It needs the global maximum for numerical stability:
\[\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}\]
Naive implementation:
def naive_softmax(x):
# Shift for numerical stability (subtract max)
x_max = x.max() # Need ALL of x
exp_x = np.exp(x - x_max)
return exp_x / exp_x.sum() # Need ALL of exp_xThis requires two passes over all data: 1. Find the maximum 2. Compute exponentials and sum
Can we do it in one pass? Can we stream it?
8.3.1 The Insight: Tracking the Right State
The trick is recognizing that softmax has hidden associative structure.
Consider computing \(\sum_i e^{x_i}\) with numerical stability:
\[\sum_i e^{x_i} = e^{m} \sum_i e^{x_i - m}\]
where \(m = \max_i x_i\).
If we're streaming and a new element arrives, the max might change. When it does, the old sum must be rescaled so that the value it represents is preserved:

\[e^{m_{old}} \cdot s_{old} = e^{m_{new}} \cdot \left(s_{old} \cdot e^{m_{old} - m_{new}}\right)\]

Adding the new element's own contribution \(e^{x_{new} - m_{new}}\) gives the update rule:

\[s_{new} = s_{old} \cdot e^{m_{old} - m_{new}} + e^{x_{new} - m_{new}}\]
The state is (max, scaled_sum). Here’s the one-pass algorithm:
```python
def streaming_softmax_sum(stream):
    """Compute sum(exp(x)) in one pass, numerically stable."""
    m = float('-inf')  # Running max
    s = 0.0            # Running sum (scaled by current max)
    for x in stream:
        if x > m:
            # Max changed! Rescale the sum.
            s = s * np.exp(m - x) + 1.0
            m = x
        else:
            s = s + np.exp(x - m)
    return m, s  # Final sum is s * exp(m)
```

Let's verify this works:
```python
# Test
x = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

# Naive (two-pass)
m_naive = x.max()
s_naive = np.exp(x - m_naive).sum()

# Streaming (one-pass)
m_stream, s_stream = streaming_softmax_sum(x)

print(f"Naive: max={m_naive}, sum={s_naive:.6f}")
print(f"Streaming: max={m_stream}, sum={s_stream:.6f}")
# Both give: max=5.0, sum=1.571317
```

8.3.2 The Combine Operation
Now the crucial question: can two partial results be combined?
If we have two chunks with states \((m_1, s_1)\) and \((m_2, s_2)\):
```python
def combine_softmax_states(state1, state2):
    m1, s1 = state1
    m2, s2 = state2
    m = max(m1, m2)
    # Rescale both sums to the new max
    s = s1 * np.exp(m1 - m) + s2 * np.exp(m2 - m)
    return (m, s)
```

Let's verify:
```python
# Split the array and combine
x1, x2 = x[:3], x[3:]

# Process each chunk
state1 = streaming_softmax_sum(x1)
state2 = streaming_softmax_sum(x2)

# Combine
m_combined, s_combined = combine_softmax_states(state1, state2)
print(f"Combined: max={m_combined}, sum={s_combined:.6f}")
# Same result: max=5.0, sum=1.571317
```

The softmax denominator is associative. Not obviously, but once you find the right state, it combines.
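Because the combine is associative, any grouping of chunks gives the same state. A small check using the definitions above (the chunking is arbitrary):

```python
from functools import reduce

# Three uneven chunks of the same array
pieces = [x[:2], x[2:3], x[3:]]
states = [streaming_softmax_sum(p) for p in pieces]

# Left-to-right reduction vs. a tree-shaped grouping
left_to_right = reduce(combine_softmax_states, states)
tree_shaped = combine_softmax_states(states[0],
                                     combine_softmax_states(states[1], states[2]))
print(left_to_right)  # max 5.0, scaled sum ≈ 1.571317
print(tree_shaped)    # same result, up to floating-point rounding
```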
8.4 Preview: From Softmax to FlashAttention
This streaming softmax is the mathematical foundation of FlashAttention.
The key insight: we can extend the (max, sum) state to include the output accumulator. Standard attention needs O(n²) memory to store the attention matrix; FlashAttention uses the streaming approach to reduce this to O(n).
Standard attention memory:

```
S = Q @ K.T:  O(n²)  ← The killer
P (softmax):  O(n²)  ← Also killer
```

FlashAttention memory:

```
State (max, sum, output):  O(n × d)
Per-block intermediates:   O(block_size × d)
Total:                     O(n × d)
```

For n = 32,768, d = 128:

```
Standard:  4 GB per attention layer
Flash:     16 MB per attention layer
Reduction: 256×
```
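A quick back-of-envelope check of these figures, assuming float32 (4 bytes per element):

```python
n, d, bytes_per_elem = 32_768, 128, 4

standard = n * n * bytes_per_elem  # one n × n score matrix
flash = n * d * bytes_per_elem     # streaming state: d values per position

print(standard / 2**30)   # 4.0  (GiB)
print(flash / 2**20)      # 16.0 (MiB)
print(standard // flash)  # 256
```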
Chapter 10 derives FlashAttention in full detail, including:
- The complete tiled algorithm
- Why seemingly promising approaches (sparse attention, gradient checkpointing) fail to solve the problem
- The backward pass with recomputation
- Hardware-specific tuning
The core insight—streaming softmax via (max, sum) state—is what we’ve developed here.
8.5 The General Pattern
Finding associative structure follows a pattern:
1. Identify what you need at the end: For softmax, you need the normalized probabilities
2. Ask what state enables incremental update: For softmax, it's (max, scaled_sum). For attention, it's (max, scaled_sum, scaled_output), sketched below
3. Derive the correction factor: When the state changes (the max increases), how do you update? Usually a multiplicative correction
4. Verify the combine operation: Can you merge two partial states? If yes, you have associativity
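Here is a minimal sketch of the attention-like state from step 2, with scalar values standing in for the value vectors of real attention (the function name and test data are ours; Chapter 10 develops the full tiled version):

```python
import numpy as np

def streaming_softmax_average(scores, values):
    """One-pass softmax-weighted average via the (max, sum, output) state."""
    m, s, o = float('-inf'), 0.0, 0.0
    for x, v in zip(scores, values):
        m_new = max(m, x)
        scale = np.exp(m - m_new)  # correction factor for the old state
        w = np.exp(x - m_new)      # weight of the new element
        s = s * scale + w
        o = o * scale + w * v
        m = m_new
    return o / s

scores = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
values = np.array([10.0, 20.0, 30.0, 40.0, 50.0])

# Two-pass reference: softmax(scores) · values
weights = np.exp(scores - scores.max())
weights /= weights.sum()
print(streaming_softmax_average(scores, values), weights @ values)  # both ≈ 39.95
```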
8.5.1 Examples Beyond Softmax
Online Variance (Welford’s Algorithm)
The naive variance needs two passes:

1. Compute the mean
2. Compute squared deviations from the mean
But there’s a one-pass algorithm with state (count, mean, M2):
```python
def welford_update(state, x):
    count, mean, M2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    delta2 = x - mean
    M2 += delta * delta2
    return (count, mean, M2)

def welford_combine(state1, state2):
    """Combine two partial variance computations."""
    n1, mean1, M2_1 = state1
    n2, mean2, M2_2 = state2
    n = n1 + n2
    delta = mean2 - mean1
    mean = mean1 + delta * n2 / n
    M2 = M2_1 + M2_2 + delta * delta * n1 * n2 / n
    return (n, mean, M2)

# Variance is M2 / count
```

Parallel Prefix Sum
Given [a, b, c, d, e, f, g, h], compute running sums [a, a+b, a+b+c, …].
Seems sequential. But associativity enables a parallel algorithm:
```
Step 1: Pairwise sums
[a,    b,    c,    d,    e,    f,    g,    h  ]
 └──+──┘     └──+──┘     └──+──┘     └──+──┘
[a,  a+b,   c,  c+d,   e,  e+f,   g,  g+h ]

Step 2: Sums of pairs
[a,  a+b,   c,  c+d,   e,  e+f,   g,  g+h ]
        └──────+──┘           └──────+──┘
[a,  a+b,   c,  a..d,  e,  e+f,   g,  e..h]

Step 3: Continue pattern...
```
This is the basis of GPU parallel scan, enabling O(log n) depth with O(n) work.
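For intuition, here is a minimal NumPy sketch of the simpler Hillis-Steele formulation: each round adds the array to a shifted copy of itself, so all updates within a round could run in parallel. (This variant does O(n log n) total work; the work-efficient up-sweep/down-sweep scan in the visualization below achieves O(n).)

```python
import numpy as np

def inclusive_scan(a):
    """Prefix sum in O(log n) rounds of whole-array (parallelizable) adds."""
    out = np.asarray(a, dtype=float)
    offset = 1
    while offset < len(out):
        shifted = np.concatenate([np.zeros(offset), out[:-offset]])
        out = out + shifted  # every element updates independently this round
        offset *= 2
    return out

print(inclusive_scan([1, 2, 3, 4, 5, 6, 7, 8]))
# [ 1.  3.  6. 10. 15. 21. 28. 36.]
```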
Interactive: Parallel Scan Visualization
The companion visualization steps through how parallel scan (prefix sum) uses associativity to achieve O(log n) parallel depth.
Key insight: Because addition is associative, we can reorder the computation. Instead of n-1 sequential adds, we use O(log n) parallel rounds. The up-sweep computes partial sums; the down-sweep distributes them—all enabled by the freedom to regroup.
8.6 When Associativity Breaks
Not everything associates. Recognize these cases:
Median: No associative structure. The median of medians is not the global median.
```python
# Counterexample:
chunk1 = [1, 2, 3]          # median = 2
chunk2 = [4, 5, 6]          # median = 5
combined = chunk1 + chunk2  # median of [1, 2, 3, 4, 5, 6] = 3.5
# median(2, 5) = 3.5 matches here only by coincidence. Shift one chunk:
chunk2 = [0, 4, 5]          # median = 4
combined = chunk1 + chunk2  # sorted: [0, 1, 2, 3, 4, 5], median = 2.5
# median(2, 4) = 3, but the true median is 2.5.
```

Mode: The mode of modes is not the global mode.
Percentiles: Generally not associative (though there are approximate streaming algorithms).
Floating-Point Gotchas: Mathematically, addition is associative. On computers:
```python
>>> (1e-16 + 1.0) - 1.0
0.0
>>> 1e-16 + (1.0 - 1.0)
1e-16
```

Floating-point addition is not associative due to rounding. For most purposes, we pretend it is—but be aware when precision matters.
8.7 The Hardware Connection
Associativity’s value comes from how it interacts with hardware:
Memory Hierarchy (Chapter 1): Chunking lets you fit working sets in cache. FlashAttention’s blocks are sized to fit in GPU SRAM.
Bandwidth (Chapter 2): Streaming algorithms read data once rather than multiple times. FlashAttention reduces memory traffic by 4-8× beyond just reducing memory footprint.
Parallelism (Chapter 3): Associative operations enable tree reduction, the fundamental parallel primitive.
The Connection:
| Mathematical Property | Hardware Constraint | Exploitation |
|---|---|---|
| Associativity | Limited SRAM | Tiled/blocked algorithms |
| Associativity | Memory bandwidth | Single-pass streaming |
| Associativity | Parallel cores | Tree reduction |
8.8 Key Takeaways
- Associativity is a license: It permits chunking, streaming, and parallelization
- The challenge is finding state: The operation itself might not look associative, but there may be hidden structure. Softmax's (max, sum) is the canonical example
- Correction factors are the key: When the state changes (the max shifts), the correction factor lets you update without recomputing
- The pattern is learnable: Ask "what state would let me combine partial results?" This question guides discovery
- Hardware rewards associativity: The property aligns with every level of the memory hierarchy and parallel execution
8.9 Further Reading
- Milakov & Gimelshein (2018). “Online normalizer calculation for softmax” - The mathematical foundation
- Dao et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”
- Welford (1962). “Note on a method for calculating corrected sums of squares and products” - The classic streaming variance