2 Thinking in Arrays
The Mental Model That Makes Performance Visible
The same computation, written three ways:
```python
# Version A: 47 seconds
for i in range(n):
    for j in range(n):
        C[i, j] = sum(A[i, k] * B[k, j] for k in range(n))

# Version B: 0.3 seconds
C = A @ B

# Version C: 0.3 seconds
C = einsum('ik,kj->ij', A, B)
```
All three compute matrix multiplication. Versions B and C are 150× faster.
The difference isn’t just syntax—it’s a different way of thinking about computation.
2.2 The APL Inheritance
The idea of “thinking in arrays” isn’t new. It dates to Kenneth Iverson’s APL (1962), a language where the primitive unit is the array, not the scalar.
APL programmers learned to see problems differently:
| Scalar Thinking | Array Thinking |
|---|---|
| “For each element, do X” | “Apply X to the array” |
| “Loop until condition” | “Mask and reduce” |
| “Accumulate in variable” | “Scan operation” |
This mental shift enabled concise expression of complex algorithms. More importantly, it exposed structure that compilers could exploit.
The lineage continues:
APL (1962) → J/K → MATLAB → NumPy → PyTorch/JAX
Modern deep learning frameworks are array languages. When you write PyTorch or JAX, you’re writing in the APL tradition—whether you know it or not.
2.3 From Loops to Operations
Let’s trace the mental shift through progressively more complex examples.
2.3.1 Example 1: Element-wise Operations
Scalar thinking:
```python
result = []
for x in data:
    result.append(x * 2 + 1)
```
Array thinking:
```python
result = data * 2 + 1
```
The array version isn't just shorter; it's semantically richer. It says "scale and shift the entire array," which can be vectorized, parallelized, and fused.
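The gap is easy to measure. A small timing sketch (absolute numbers vary by machine; the ratio is typically one to two orders of magnitude):
```python
import timeit
import numpy as np

values = list(range(100_000))                 # plain Python list
data = np.arange(100_000, dtype=np.float64)   # NumPy array

loop_time = timeit.timeit(lambda: [x * 2 + 1 for x in values], number=100)
array_time = timeit.timeit(lambda: data * 2 + 1, number=100)

print(f"loop: {loop_time:.3f}s   array: {array_time:.3f}s")
```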
2.3.2 Example 2: Aggregations
Scalar thinking:
```python
total = 0
for x in data:
    total += x
mean = total / len(data)
```
Array thinking:
```python
mean = data.mean()
```
The aggregation pattern (a reduce with +) is explicit. The runtime can use tree reduction for parallelism, numerical stability tricks, and optimized implementations.
2.3.3 Example 3: Conditional Operations
Scalar thinking:
```python
result = []
for x in data:
    if x > 0:
        result.append(x)
    else:
        result.append(0)
```
Array thinking:
```python
result = np.maximum(data, 0)  # or: np.where(data > 0, data, 0)
```
The conditional becomes a mask operation, a pattern with known optimizations.
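The same boolean mask covers the related patterns too: selection, counting, and conditional reduction. A short sketch:
```python
import numpy as np

data = np.random.randn(1_000_000)
mask = data > 0                       # one boolean flag per element

clipped = np.where(mask, data, 0.0)   # conditional value (same as np.maximum above)
count = mask.sum()                    # how many elements satisfy the condition
pos_sum = data[mask].sum()            # reduce over only the selected elements
```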
2.3.4 Example 4: Matrix Operations
Scalar thinking:
```python
# Attention scores
scores = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        for k in range(d):
            scores[i, j] += Q[i, k] * K[j, k]
```
Array thinking:
```python
scores = Q @ K.T
```
The nested loops obscure that this is matrix multiplication. The array version makes it explicit, enabling optimized GEMM kernels.
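A quick equivalence check, with sizes kept small so the loop version finishes quickly:
```python
import numpy as np

n, d = 64, 16
Q, K = np.random.randn(n, d), np.random.randn(n, d)

scores_loop = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        for k in range(d):
            scores_loop[i, j] += Q[i, k] * K[j, k]

# Same values, but one GEMM call instead of n*n*d scalar operations
assert np.allclose(scores_loop, Q @ K.T)
```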
2.4 The einsum Revolution
Einstein summation notation (einsum) takes array thinking to its logical conclusion: declare the index structure, let the system figure out the computation.
```python
# Matrix multiply
C = einsum('ik,kj->ij', A, B)

# Batch matrix multiply
C = einsum('bik,bkj->bij', A, B)

# Attention scores
S = einsum('bqd,bkd->bqk', Q, K)

# Weighted sum
O = einsum('bqk,bkd->bqd', weights, V)
```
2.4.1 Why einsum Matters
- Explicit index structure: You see exactly which dimensions interact (see the sketch after this list)
- Automatic optimization: Libraries like opt_einsum find optimal contraction orders
- Composability: Complex operations are compositions of simple contractions
- Backend flexibility: Same notation works on CPU, GPU, TPU
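A minimal sketch (shapes are illustrative) showing that the attention-score contraction written as einsum matches the explicit batched matmul, with the index string making the contracted dimension d visible:
```python
import numpy as np

b, q, k, d = 2, 5, 7, 4                      # illustrative batch/sequence/feature sizes
Q = np.random.randn(b, q, d)
K = np.random.randn(b, k, d)

S_einsum = np.einsum('bqd,bkd->bqk', Q, K)   # contract over d; keep b, q, k
S_matmul = Q @ K.transpose(0, 2, 1)          # same contraction as a batched matmul

assert np.allclose(S_einsum, S_matmul)
print(S_einsum.shape)                        # (2, 5, 7)
```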
2.4.2 Contraction Order Matters
Consider a chain of three matrices:
```python
# A: (10, 100), B: (100, 5), C: (5, 50)
result = einsum('ij,jk,kl->il', A, B, C)
```
Two possible contraction orders:
Order 1: (A @ B) @ C
Step 1: (10, 100) @ (100, 5) = (10, 5) → 5,000 multiplies
Step 2: (10, 5) @ (5, 50) = (10, 50) → 2,500 multiplies
Total: 7,500 multiplies
Order 2: A @ (B @ C)
Step 1: (100, 5) @ (5, 50) = (100, 50) → 25,000 multiplies
Step 2: (10, 100) @ (100, 50) = (10, 50) → 50,000 multiplies
Total: 75,000 multiplies
Order 1 is 10× faster. The opt_einsum library automatically finds optimal orderings:
```python
import opt_einsum as oe

# Find an optimal contraction order, then contract along it
path, info = oe.contract_path('ij,jk,kl->il', A, B, C)
result = oe.contract('ij,jk,kl->il', A, B, C, optimize=path)
```
This matters enormously for attention and transformer computations with multiple tensor contractions.
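The multiply counts above follow directly from the shapes, and the two orders differ only in cost, never in the value they compute. A quick check:
```python
import numpy as np

A = np.random.rand(10, 100)
B = np.random.rand(100, 5)
C = np.random.rand(5, 50)

flops_order_1 = 10 * 100 * 5 + 10 * 5 * 50     # (A @ B) @ C -> 7,500 multiplies
flops_order_2 = 100 * 5 * 50 + 10 * 100 * 50   # A @ (B @ C) -> 75,000 multiplies
print(flops_order_1, flops_order_2)

assert np.allclose((A @ B) @ C, A @ (B @ C))   # same result either way
```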
2.5 Memory Layout: Where Data Lives
Array-oriented thinking extends beyond operations to how data is organized in memory.
2.5.1 Array of Structures vs Structure of Arrays
Consider representing particles with position and velocity:
```python
# Array of Structures (AoS)
particles = [
    {'x': 1.0, 'y': 2.0, 'vx': 0.1, 'vy': 0.2},
    {'x': 3.0, 'y': 4.0, 'vx': 0.3, 'vy': 0.4},
    # ...
]

# Structure of Arrays (SoA)
particles = {
    'x': np.array([1.0, 3.0, ...]),
    'y': np.array([2.0, 4.0, ...]),
    'vx': np.array([0.1, 0.3, ...]),
    'vy': np.array([0.2, 0.4, ...]),
}
```
Memory layout:
AoS memory: [x₀ y₀ vx₀ vy₀] [x₁ y₁ vx₁ vy₁] [x₂ y₂ vx₂ vy₂] ...
SoA memory: [x₀ x₁ x₂ ...] [y₀ y₁ y₂ ...] [vx₀ vx₁ vx₂ ...] [vy₀ vy₁ vy₂ ...]
2.5.2 Why SoA Wins for Vectorization
If you only need x-coordinates:
AoS access: Load x₀, skip 3, load x₁, skip 3, load x₂, ...
(Strided access, cache lines wasted)
SoA access: Load [x₀ x₁ x₂ x₃ x₄ x₅ x₆ x₇]
(Contiguous access, cache lines fully used)
For GPU memory coalescing, the difference is even starker:
```python
# GPU kernel accessing AoS (slow)
# Each thread loads non-contiguous memory
for i in range(n):
    x = particles[i].x  # Scattered memory access

# GPU kernel accessing SoA (fast)
# Adjacent threads load adjacent memory
x_batch = particles.x[start:end]  # Coalesced memory access
```
2.5.3 Layout in Deep Learning
Image tensors in deep learning frameworks are stored either channels-first or channels-last:
```python
# NCHW (channels-first, PyTorch default)
tensor.shape = (batch, channels, height, width)

# NHWC (channels-last, TensorFlow default, better for some GPU ops)
tensor.shape = (batch, height, width, channels)
```
The "right" layout depends on the operation:
- Convolutions often prefer NHWC (better memory access patterns)
- Batch normalization prefers NCHW (reduction over spatial dims)
Modern frameworks can convert between layouts, but conversions cost memory bandwidth. Thinking about layout upfront avoids this overhead.
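In PyTorch, for example, the logical NCHW shape can be kept while the physical layout switches to channels-last; only the strides change. A sketch (strides shown for this particular shape):
```python
import torch

x = torch.randn(8, 3, 224, 224)             # NCHW, default contiguous layout
print(x.stride())                           # (150528, 50176, 224, 1)

# Same logical shape, channels-last physical layout (NHWC in memory)
x_cl = x.to(memory_format=torch.channels_last)
print(x_cl.shape)                           # torch.Size([8, 3, 224, 224])
print(x_cl.stride())                        # (150528, 1, 672, 3)
```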
2.6 Transformation Composition
JAX crystallizes array-oriented thinking into composable transformations:
```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

# Automatic differentiation
grad_fn = jax.grad(loss_fn)

# Vectorization: process batches automatically
batched_loss = jax.vmap(loss_fn, in_axes=(None, 0, 0))

# JIT compilation
fast_loss = jax.jit(batched_loss)

# Compose them all
fast_grad = jax.jit(jax.grad(lambda p, x, y: jax.vmap(
    loss_fn, in_axes=(None, 0, 0)
)(p, x, y).mean()))
```
2.6.1 The vmap Mental Model
vmap (vectorizing map) eliminates batch dimensions from your thinking:
```python
# Without vmap: manually handle the batch dimension
def batched_predict(params, X):
    return jnp.stack([predict(params, x) for x in X])

# With vmap: think about single examples
# (in_axes=(None, 0) broadcasts params and maps over the batch axis of x)
batched_predict = jax.vmap(predict, in_axes=(None, 0))
```
This is profound: you write predict for a single example, and vmap automatically vectorizes it. The batching structure becomes a transformation, not embedded logic.
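A self-contained sketch, with a hypothetical single-example predict (a plain linear layer) standing in for a real model:
```python
import jax
import jax.numpy as jnp

def predict(params, x):                  # x: (d_in,)
    W, b = params
    return W @ x + b                     # returns (d_out,)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_w, (4, 3)), jnp.zeros(4))
X = jax.random.normal(key_x, (32, 3))    # batch of 32 examples

# Broadcast params (None), map over the leading axis of X (0)
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict(params, X).shape)  # (32, 4)
```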
2.6.2 Why Composition Matters
Each transformation exposes different structure:
| Transformation | What It Exposes |
|---|---|
| jit | Static computation graph → kernel fusion, memory planning |
| grad | Differentiable structure → backward pass optimization |
| vmap | Parallel structure → SIMD, GPU parallelism |
| pmap | Distributed structure → multi-device parallelism |
By composing transformations, you let the compiler see all the structure at once:
```python
# Compiler sees: "Differentiate a parallelizable, vectorized function"
# and can optimize across all these dimensions simultaneously
# (pmap already JIT-compiles, so no extra jit wrapper is needed)
optimized = jax.pmap(jax.vmap(jax.grad(f)))
```
2.7 Data-Oriented Design Principles
Mike Acton’s Data-Oriented Design (CppCon 2014) articulates principles that complement array-oriented thinking:
2.7.1 Principle 1: “The purpose of all programs is to transform data”
Not to model objects, manage state, or encapsulate behavior. To transform data from one form to another.
Implication for ML: A neural network is a data transformation pipeline. Think about what flows through it, not what it “is.”
2.7.2 Principle 2: “If you don’t understand the data, you don’t understand the problem”
Before optimizing, understand:
- What is the data? (Types, shapes, distributions)
- How is it accessed? (Patterns, frequencies, dependencies)
- Where does it live? (Memory hierarchy, device placement)
Implication for ML: Profile your data access patterns. The model architecture determines data flow; understanding flow reveals optimization opportunities.
2.7.3 Principle 3: “Different problems have different data”
There’s no universal “best” layout or algorithm. The optimal choice depends on your specific data characteristics.
Implication for ML:
- Sparse models need sparse layouts
- Sequential processing needs streaming access
- Random access needs different optimization than sequential
2.7.4 Principle 4: “Where there is one, there are many”
If you’re processing one thing, you’re probably processing many. Design for batches from the start.
Implication for ML: This is why batch processing dominates. Individual samples are the exception, not the rule.
2.8 Applying the Mindset: A Case Study
Let’s trace how data-oriented thinking leads to FlashAttention.
2.8.1 Step 1: Understand the Data Flow
Standard attention:
```python
S = Q @ K.T       # (n, n) attention scores
P = softmax(S)    # (n, n) attention weights
O = P @ V         # (n, d) output
```
Data sizes for n=4096, d=128, in fp32:
Q, K, V: 4096 × 128 × 4 bytes = 2 MB each
S, P: 4096 × 4096 × 4 bytes = 64 MB each
O: 4096 × 128 × 4 bytes = 2 MB
Observation: We create 128 MB of intermediate data to produce 2 MB of output.
2.8.2 Step 2: Trace Access Patterns
For each output row O[i], the computation:
- Reads all of K (2 MB)
- Reads all of V (2 MB)
- Reads only Q[i] (512 bytes)
The access pattern is row-wise independent. Each output row can be computed without the others.
2.8.3 Step 3: Ask the Data-Oriented Question
“Do we need to materialize S and P?”
Analysis:
- S[i,:] is only used to compute P[i,:] (the softmax couples entries within a row, not across rows)
- P[i,:] is only used to compute O[i]
- Neither S nor P is an output; both are intermediates
Conclusion: S and P are working memory, not essential data. We might eliminate them.
2.8.4 Step 4: Find the Algebraic Enabler
Can softmax be computed incrementally? Let’s check:
Standard softmax: P[i,j] = exp(S[i,j]) / Σₖ exp(S[i,k])
For a subset of columns: partial_sum = Σₖ∈subset exp(S[i,k])
The denominator is a sum—sums are associative. We can compute partial sums and combine them.
This is the insight: the algebraic property (associativity) becomes visible because we asked the right question (can we eliminate the intermediate?).
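A minimal NumPy check of that claim: the denominator computed in one pass matches the denominator assembled from per-block partial sums.
```python
import numpy as np

s = np.random.default_rng(0).normal(size=8)   # one row of scores, S[i, :]

full_denom = np.exp(s).sum()                  # one-pass denominator

# Same denominator from per-block partial sums (associativity of +)
# (in practice a running max is also tracked for stability, as in the blocked sketch below)
partials = [np.exp(blk).sum() for blk in (s[:3], s[3:6], s[6:])]
assert np.isclose(full_denom, sum(partials))
```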
2.8.5 Step 5: Design for the Data Flow
FlashAttention reorganizes computation around memory:
```python
# Instead of: compute all of S, then all of P, then O
# Do: for each block of K, V, update a running O
for k_block, v_block in blocks(K, V):
    partial_scores = Q @ k_block.T  # Small: (n, block_size)
    partial_weights = stable_softmax_update(partial_scores)
    O = update_output(O, partial_weights, v_block)
```
The algorithm follows from the data-oriented analysis:
1. Identified unnecessary intermediates
2. Found algebraic property enabling elimination
3. Reorganized computation around memory access
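To see that the blocked update reproduces standard attention, here is a minimal NumPy sketch of the pattern (illustrative only: the real FlashAttention kernel also tiles Q and keeps blocks in on-chip SRAM):
```python
import numpy as np

def standard_attention(Q, K, V):
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def blocked_attention(Q, K, V, block=64):
    n = Q.shape[0]
    O = np.zeros((n, V.shape[1]))      # running (unnormalized) output
    m = np.full(n, -np.inf)            # running row-wise max
    l = np.zeros(n)                    # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T                                  # (n, block)
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)                     # rescale old accumulators
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=1)
        O = O * scale[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(128, 16)) for _ in range(3))
assert np.allclose(standard_attention(Q, K, V), blocked_attention(Q, K, V))
```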
2.9 The Synthesis
Array-oriented thinking and data-oriented design converge on the same insight:
See the structure of your data and computation. The optimizations follow.
The Algebraic Framework (next chapter) provides the vocabulary for structure: associativity, separability, locality, sparsity, redundancy, symmetry.
This chapter provides the mental model for seeing that structure:
- Think in arrays, not loops → Structure becomes visible
- Consider memory layout → Access patterns become visible
- Compose transformations → Optimization opportunities become visible
- Ask what data flows where → Unnecessary work becomes visible
With both the mental model and the vocabulary, you’re equipped to derive optimizations rather than memorize them.
2.10 Key Takeaways
- Loops hide structure. Array operations expose it. Write A @ B, not nested loops.
- einsum is declarative algebra. State the index structure; let the system optimize.
- Layout determines performance. SoA for vectorization; choose formats that match access patterns.
- Transformations compose. vmap, grad, jit reveal different structure. Compose them to expose all of it.
- Data-oriented questions:
  - What is the data?
  - How is it accessed?
  - What intermediates are truly necessary?
  - What structure can we exploit?