8  Associativity

How Regrouping Enables Parallelism and Streaming

You have a billion numbers. You need their average.

The naive approach: load all billion into memory, sum them, divide.

The problem: a billion float64s is 8 GB. What if you only have 1 GB of RAM?

Historical Note: Egyptian Multiplication (2000 BCE)

The exploitation of associativity for efficiency dates back four millennia. Egyptian scribes multiplied by repeated doubling, using O(log n) doublings instead of O(n) repeated additions. Their insight: to compute, say, 13 × 24, write 13 = 8 + 4 + 1 and split the product into 8×24 + 4×24 + 1×24; because addition is associative, each term can be built by doubling 24 (and doubling the result again) rather than adding 24 over and over.

This is the same principle behind parallel scan, distributed reduction, and FlashAttention’s streaming softmax. The algorithm is ancient; the scale is modern.

8.1 The Property That Enables Everything

Some operations can be computed in any order:

\[a + (b + c) = (a + b) + c\]

This is associativity. It seems abstract, almost trivial. It’s neither.

Associativity is the license to:

  • Chunk: Process data in pieces
  • Parallelize: Combine partial results from multiple workers
  • Stream: Process data as it arrives, without storing everything
  • Checkpoint: Save intermediate state and resume later

Without associativity, you must process everything at once. With it, you can process anything, no matter how large.

8.2 From Addition to Architecture

Let’s trace how this abstract property becomes concrete performance.

8.2.1 Summing a Billion Numbers

The naive sum:

def naive_sum(numbers):
    total = 0
    for x in numbers:
        total += x
    return total

This works, but it requires all numbers in memory. Can we do better?

Associativity says: the grouping doesn’t matter.

# These are mathematically identical:
(a + b + c + d) + (e + f + g + h)  # Two chunks
((a + b) + (c + d)) + ((e + f) + (g + h))  # Four chunks

So we can process chunks:

from itertools import islice

def chunked_sum(numbers, chunk_size=1_000_000):
    total = 0
    it = iter(numbers)
    while chunk := list(islice(it, chunk_size)):
        total += sum(chunk)  # Process one chunk, then discard it
    return total

Same answer. But now we only need chunk_size numbers in memory, not all of them.

8.2.2 The Combinable State

The key insight: we can represent the “state” of a partial computation as something that combines with more data.

For summation, the state is just the running sum. But this pattern generalizes.

Operation   State                   Combine Rule
────────────────────────────────────────────────────────────────
Sum         sum                     sum₁ + sum₂
Count       count                   count₁ + count₂
Average     (sum, count)            (sum₁ + sum₂, count₁ + count₂)
Max         max                     max(max₁, max₂)
Variance    (sum, sum_sq, count)    combine pairwise (see 8.5.1)

The pattern: find the state that makes your operation associative.
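
As a minimal sketch of the pattern (the helper names here are illustrative, not from the table), the average reduces each chunk to a (sum, count) pair, and pairs combine by element-wise addition:

def partial_mean_state(chunk):
    """Reduce one chunk to a combinable (sum, count) state."""
    return (sum(chunk), len(chunk))

def combine_mean_states(s1, s2):
    """Merge two partial states; the grouping of the merges doesn't matter."""
    return (s1[0] + s2[0], s1[1] + s2[1])

# Same mean whether computed in one pass or chunk by chunk
data = [2.0, 4.0, 6.0, 8.0]
state = combine_mean_states(partial_mean_state(data[:2]), partial_mean_state(data[2:]))
assert state[0] / state[1] == sum(data) / len(data)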

8.3 Investigation: The Softmax Challenge

Softmax seems sequential. It needs the global maximum for numerical stability:

\[\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}\]

Naive implementation:

import numpy as np

def naive_softmax(x):
    # Shift for numerical stability (subtract max)
    x_max = x.max()  # Need ALL of x
    exp_x = np.exp(x - x_max)
    return exp_x / exp_x.sum()  # Need ALL of exp_x

This requires two passes over all the data:

  1. Find the maximum
  2. Compute the exponentials and their sum

Can we do it in one pass? Can we stream it?

8.3.1 The Insight: Tracking the Right State

The trick is recognizing that softmax has hidden associative structure.

Consider computing \(\sum_i e^{x_i}\) with numerical stability:

\[\sum_i e^{x_i} = e^{m} \sum_i e^{x_i - m}\]

where \(m = \max_i x_i\).

If we’re streaming and see new elements, the max might change. When it does:

\[e^{m_{old}} \cdot s_{old} = e^{m_{new}} \cdot s_{new}\]

So:

\[s_{new} = s_{old} \cdot e^{m_{old} - m_{new}} + e^{x_{new} - m_{new}}\]

The state is (max, scaled_sum). Here’s the one-pass algorithm:

def streaming_softmax_sum(stream):
    """Compute sum(exp(x)) in one pass, numerically stable."""
    m = float('-inf')  # Running max
    s = 0.0            # Running sum (scaled by current max)

    for x in stream:
        if x > m:
            # Max changed! Rescale the sum.
            s = s * np.exp(m - x) + 1.0
            m = x
        else:
            s = s + np.exp(x - m)

    return m, s  # Final sum is s * exp(m)

Let’s verify this works:

# Test
x = np.array([1.0, 3.0, 2.0, 5.0, 4.0])

# Naive (two-pass)
m_naive = x.max()
s_naive = np.exp(x - m_naive).sum()

# Streaming (one-pass)
m_stream, s_stream = streaming_softmax_sum(x)

print(f"Naive:     max={m_naive}, sum={s_naive:.6f}")
print(f"Streaming: max={m_stream}, sum={s_stream:.6f}")
# Both give: max=5.0, sum=1.571317

8.3.2 The Combine Operation

Now the crucial question: can two partial results be combined?

If we have two chunks with states \((m_1, s_1)\) and \((m_2, s_2)\):

def combine_softmax_states(state1, state2):
    m1, s1 = state1
    m2, s2 = state2
    m = max(m1, m2)

    # Rescale both sums to the new max
    s = s1 * np.exp(m1 - m) + s2 * np.exp(m2 - m)
    return (m, s)

Let’s verify:

# Split the array and combine
x1, x2 = x[:3], x[3:]

# Process each chunk
state1 = streaming_softmax_sum(x1)
state2 = streaming_softmax_sum(x2)

# Combine
m_combined, s_combined = combine_softmax_states(state1, state2)

print(f"Combined:  max={m_combined}, sum={s_combined:.6f}")
# Same result: max=5.0, sum=1.571317

The softmax denominator is associative. Not obviously, but once you find the right state, it combines.

8.4 Preview: From Softmax to FlashAttention

This streaming softmax is the mathematical foundation of FlashAttention.

The key insight: we can extend the (max, sum) state to include the output accumulator. Standard attention needs O(n²) memory to store the attention matrix; FlashAttention uses the streaming approach to reduce this to O(n).

Standard attention memory:
  S = Q @ K.T:  O(n²)  ← The killer
  P (softmax):  O(n²)  ← Also killer

FlashAttention memory:
  State (max, sum, output): O(n × d)
  Per-block intermediates:  O(block_size × d)
  Total:                    O(n × d)

For n = 32,768, d = 128:
  Standard: 4 GB per attention layer
  Flash:    16 MB per attention layer
  Reduction: 256×
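
To preview the shape of that extension (a hedged sketch under simplified assumptions, not the actual FlashAttention kernel; the function name is made up for illustration), each query row carries a state (m, s, o), where o is the value-weighted sum accumulated under the current max, and two blocks combine much like the softmax states above:

import numpy as np

def combine_attention_states(state1, state2):
    """Merge two partial per-row attention states (max, scaled_sum, scaled_output)."""
    m1, s1, o1 = state1
    m2, s2, o2 = state2
    m = max(m1, m2)
    c1, c2 = np.exp(m1 - m), np.exp(m2 - m)  # rescale both blocks to the new max
    s = s1 * c1 + s2 * c2                    # softmax denominator
    o = o1 * c1 + o2 * c2                    # unnormalized output (a length-d vector)
    return (m, s, o)                         # final row output = o / s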

Chapter 10 derives FlashAttention in full detail, including:

  • The complete tiled algorithm
  • Why failed approaches (sparse attention, gradient checkpointing) don’t solve this problem
  • The backward pass with recomputation
  • Hardware-specific tuning

The core insight—streaming softmax via (max, sum) state—is what we’ve developed here.

8.5 The General Pattern

Finding associative structure follows a pattern:

  1. Identify what you need at the end: For softmax, you need the normalized probabilities

  2. Ask what state enables incremental update: For softmax, it’s (max, scaled_sum). For attention, it’s (max, scaled_sum, scaled_output)

  3. Derive the correction factor: When state changes (max increases), how do you update? Usually involves a multiplicative correction

  4. Verify the combine operation: Can you merge two partial states? If yes, you have associativity
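
Concretely, the whole pattern fits one reusable skeleton (a sketch with illustrative names, not a library API): reduce each chunk to a state, fold the states together with an associative combine, then finalize.

from functools import reduce

def chunked_reduce(data_chunks, init, update, combine, finalize):
    """Reduce each chunk to a state, merge the states pairwise, then finalize."""
    def reduce_chunk(chunk):
        state = init
        for x in chunk:
            state = update(state, x)
        return state
    return finalize(reduce(combine, (reduce_chunk(c) for c in data_chunks)))

# Example: a global max, where the state, update, and combine are all just `max`
print(chunked_reduce(
    data_chunks=[[1, 7, 3], [9, 2], [4, 4]],
    init=float('-inf'),
    update=max,
    combine=max,
    finalize=lambda s: s,
))  # 9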

8.5.1 Examples Beyond Softmax

Online Variance (Welford’s Algorithm)

The naive variance computation needs two passes:

  1. Compute the mean
  2. Compute the squared deviations from the mean

But there’s a one-pass algorithm with state (count, mean, M2):

def welford_update(state, x):
    count, mean, M2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    delta2 = x - mean
    M2 += delta * delta2
    return (count, mean, M2)

def welford_combine(state1, state2):
    """Combine two partial variance computations."""
    n1, mean1, M2_1 = state1
    n2, mean2, M2_2 = state2

    n = n1 + n2
    delta = mean2 - mean1
    mean = mean1 + delta * n2 / n
    M2 = M2_1 + M2_2 + delta * delta * n1 * n2 / n

    return (n, mean, M2)

# Variance is M2 / count
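
A quick check (this snippet is illustrative; the rng seed and chunk split are arbitrary) that the chunked combination matches a direct computation:

import numpy as np

def welford_state(xs):
    """Fold a chunk of values into a (count, mean, M2) state."""
    state = (0, 0.0, 0.0)
    for x in xs:
        state = welford_update(state, x)
    return state

data = np.random.default_rng(0).normal(size=1000)
n, mean, M2 = welford_combine(welford_state(data[:400]), welford_state(data[400:]))

assert np.isclose(mean, data.mean())
assert np.isclose(M2 / n, data.var())  # population variance (ddof=0)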

Parallel Prefix Sum

Given [a, b, c, d, e, f, g, h], compute running sums [a, a+b, a+b+c, …].

Seems sequential. But associativity enables a parallel algorithm:

Step 1: Pairwise sums
  [a,   b,   c,   d,   e,   f,   g,   h]
   └─+──┘    └─+──┘    └─+──┘    └─+──┘
  [a, a+b,   c, c+d,   e, e+f,   g, g+h]

Step 2: Sums of pairs
  [a, a+b,   c, c+d,   e, e+f,   g, g+h]
       └──────+──┘         └──────+──┘
  [a, a+b,   c, a..d,  e, e+f,   g, e..h]

Step 3: Continue pattern...

This is the basis of GPU parallel scan, enabling O(log n) depth with O(n) work.

An interactive visualization in the online version steps through how parallel scan (prefix sum) uses associativity to achieve O(log n) parallel depth.

Key insight: Because addition is associative, we can reorder the computation. Instead of n-1 sequential adds, we use O(log n) parallel rounds. The up-sweep computes partial sums; the down-sweep distributes them—all enabled by the freedom to regroup.
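
To make the idea concrete, here is a minimal sketch of a Blelloch-style scan (it computes the exclusive prefix sum, i.e. the running sums shifted by one position, and assumes the length is a power of two):

def blelloch_scan(values):
    """Exclusive prefix sum via an up-sweep and a down-sweep (length must be a power of two)."""
    x = list(values)
    n = len(x)

    # Up-sweep: build a tree of partial sums in place
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            x[i] += x[i - step]
        step *= 2

    # Down-sweep: push the prefixes back down the tree
    x[n - 1] = 0
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            x[i - step], x[i] = x[i], x[i] + x[i - step]
        step //= 2
    return x

print(blelloch_scan([1, 2, 3, 4, 5, 6, 7, 8]))
# [0, 1, 3, 6, 10, 15, 21, 28]

Here the loops run sequentially; on a GPU, each pass over the indices becomes one parallel round, which is where the O(log n) depth comes from.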

8.6 When Associativity Breaks

Not everything associates. Recognize these cases:

Median: No associative structure. The median of medians is not the global median.

# Counterexample: same chunk medians, different global medians
chunk1 = [1, 2, 3]      # median = 2
chunk2 = [4, 5, 6]      # median = 5
# combined: [1, 2, 3, 4, 5, 6]    → global median = 3.5

other1 = [2, 2, 100]    # median = 2
other2 = [4, 5, 6]      # median = 5
# combined: [2, 2, 4, 5, 6, 100]  → global median = 4.5

# The chunk medians (2, 5) are identical in both cases, so no combine
# rule built from them alone can recover the global median.

Mode: The mode of modes is not the global mode.

Percentiles: Generally not associative (though there are approximate streaming algorithms).

Floating-Point Gotchas: Mathematically, addition is associative. On computers:

>>> (1e-16 + 1.0) - 1.0
0.0
>>> 1e-16 + (1.0 - 1.0)
1e-16

Floating-point addition is not associative due to rounding. For most purposes, we pretend it is—but be aware when precision matters.
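
A small experiment (illustrative; the seed and chunk size are arbitrary) shows that regrouping a float sum changes the result slightly:

import math
import random

random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(1_000_000)]

naive = sum(xs)                                                      # left-to-right
chunked = sum(sum(xs[i:i + 1000]) for i in range(0, len(xs), 1000))  # regrouped
exact = math.fsum(xs)                                                # correctly rounded reference

print(naive - exact, chunked - exact)
# Both differences are tiny, but typically nonzero and not equal to each other.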

8.7 The Hardware Connection

Associativity’s value comes from how it interacts with hardware:

Memory Hierarchy (Chapter 1): Chunking lets you fit working sets in cache. FlashAttention’s blocks are sized to fit in GPU SRAM.

Bandwidth (Chapter 2): Streaming algorithms read data once rather than multiple times. FlashAttention reduces memory traffic by 4-8× beyond just reducing memory footprint.

Parallelism (Chapter 3): Associative operations enable tree reduction, the fundamental parallel primitive.

The Connection:

Mathematical Property     Hardware Constraint      Exploitation
───────────────────────────────────────────────────────────────
Associativity        →    Limited SRAM         →  Tiled/blocked algorithms
                     →    Memory bandwidth     →  Single-pass streaming
                     →    Parallel cores       →  Tree reduction

8.8 Key Takeaways

  1. Associativity is a license: It permits chunking, streaming, and parallelization

  2. The challenge is finding state: The operation itself might not look associative, but there may be hidden structure. Softmax’s (max, sum) is the canonical example

  3. Correction factors are the key: When state changes (max shifts), the correction factor lets you update without recomputing

  4. The pattern is learnable: Ask “what state would let me combine partial results?” This question guides discovery

  5. Hardware rewards associativity: The property aligns with every level of the memory hierarchy and parallel execution

Try It Yourself

The accompanying notebook lets you:

  • Implement and verify streaming softmax
  • Explore the combine operation
  • Build a simplified FlashAttention from scratch
  • Measure memory savings


8.9 Further Reading

  • Milakov & Gimelshein (2018). “Online normalizer calculation for softmax” - The mathematical foundation
  • Dao et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”
  • Welford (1962). “Note on a method for calculating corrected sums of squares and products” - The classic streaming variance