19 Investigation: State-Space Models
Linear-Time Alternatives to Attention
Attention is O(n²). For a million tokens, that’s a trillion operations per layer.
State-space models achieve O(n) complexity. What’s the catch?
19.1 The Attention Problem
Transformers dominate language modeling, but attention has fundamental scaling issues:
Attention complexity per layer:
| Sequence length | Memory (fp32 scores) | FLOPs |
|---|---|---|
| 1K | 4 MB | 2 billion |
| 4K | 64 MB | 32 billion |
| 16K | 1 GB | 512 billion |
| 64K | 16 GB | 8 trillion |
| 256K | 256 GB | 128 trillion |
Even with FlashAttention, which reduces memory to O(n) by never materializing the score matrix, compute is still O(n²).
State-space models (SSMs) promise O(n) complexity with competitive quality.
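The numbers in the table above follow directly from the shape of the computation; a quick sketch to reproduce them (assumptions: fp32 scores at 4 bytes per entry, counting 2·n²·d FLOPs for the QKᵀ matmul alone, and d = 1000 for round numbers):

import torch

def attention_cost(n, d=1000):
    """Per-layer attention cost for sequence length n, model dimension d."""
    memory_bytes = n * n * 4   # n x n fp32 score matrix
    flops = 2 * n * n * d      # QK^T matmul: n^2 dot products of length d
    return memory_bytes, flops

for n in [1024, 4096, 16_384, 65_536, 262_144]:
    mem, flops = attention_cost(n)
    print(f"{n:>7} tokens: {mem / 2**20:>9,.0f} MiB, {flops:.1e} FLOPs")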
19.2 The RNN Connection
RNNs are naturally O(n)—each token updates a fixed-size hidden state:
import torch

def rnn_forward(x, h0, W_h, W_x):
    """
    Simple RNN: O(n) in sequence length.
    x: sequence of input vectors; h0: initial hidden state.
    """
    h = h0
    outputs = []
    for t in range(len(x)):
        h = torch.tanh(W_h @ h + W_x @ x[t])  # O(d^2) per step
        outputs.append(h)
    return outputs  # Total: O(n * d^2)

Problem: RNNs are hard to train (vanishing gradients) and slow, because the recurrence is inherently sequential: no parallelism across time steps.
Insight: What if we could have RNN-like O(n) inference with transformer-like parallel training?
19.3 State-Space Models Fundamentals
19.3.1 The Continuous-Time View
SSMs come from control theory. A linear time-invariant system:
\[ \begin{aligned} h'(t) &= Ah(t) + Bx(t) \\ y(t) &= Ch(t) + Dx(t) \end{aligned} \]
Where:

- \(x(t)\) is the input signal
- \(h(t)\) is the hidden state
- \(y(t)\) is the output
- \(A, B, C, D\) are learnable matrices
19.3.2 Discretization
For discrete sequences, we discretize with step size \(\Delta\):
\[ \begin{aligned} \bar{A} &= \exp(\Delta A) \\ \bar{B} &= (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \\ h_t &= \bar{A} h_{t-1} + \bar{B} x_t \\ y_t &= C h_t + D x_t \end{aligned} \]
This is now a recurrence—can be computed sequentially like an RNN.
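The code sketches later in this chapter call helpers named discretize_A and discretize_B without defining them; a minimal implementation of the zero-order-hold formulas above (assuming \(\Delta A\) is invertible, as the formula requires) could be:

import torch

def discretize_A(A, delta):
    """Zero-order hold: A_bar = exp(delta * A)."""
    return torch.matrix_exp(delta * A)

def discretize_B(A, B, delta):
    """Zero-order hold: B_bar = (delta A)^{-1} (exp(delta A) - I) delta B."""
    dA = delta * A
    identity = torch.eye(A.shape[0], dtype=A.dtype)
    return torch.linalg.solve(dA, torch.matrix_exp(dA) - identity) @ (delta * B)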
19.3.3 The Parallel Training Trick
The recurrence can be unrolled as a convolution:
\[ y = x * \bar{K} \]
where \(\bar{K}\) is the SSM kernel:
\[ \bar{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \ldots, C\bar{A}^{n-1}\bar{B}) \]
Key insight: During training, compute via FFT convolution. During inference, compute via recurrence.
def ssm_forward(x, A, B, C, D, delta, training=True):
    """
    SSM forward pass for a single scalar channel.
    Training: use convolution (parallel).
    Inference: use recurrence (sequential but O(1) per step).
    x: [length]; A: [d_state, d_state]; B, C: [d_state]; D: scalar.
    """
    # Discretize (zero-order hold, Section 19.3.2)
    A_bar = discretize_A(A, delta)
    B_bar = discretize_B(A, B, delta)
    if training:
        # Compute kernel
        K = compute_kernel(A_bar, B_bar, C, length=x.shape[-1])
        # Convolve (parallel via FFT); D is the skip connection
        y = fft_conv(x, K) + D * x
    else:
        # Recurrence (sequential)
        h = torch.zeros(A.shape[0])
        y = []
        for t in range(x.shape[-1]):
            h = A_bar @ h + B_bar * x[t]
            y.append(C @ h + D * x[t])
        y = torch.stack(y)
    return y
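The fft_conv helper used above is also left undefined by the chapter; a minimal causal version (an assumption: it treats the signal and kernel as 1-D, whereas practical S4 implementations batch this over channels) might be:

def fft_conv(x, K):
    """Causal linear convolution of signal x with kernel K via FFT, O(n log n)."""
    n = x.shape[-1]
    # Zero-pad to 2n so circular convolution equals linear convolution,
    # then keep the first n (causal) outputs.
    x_f = torch.fft.rfft(x, n=2 * n)
    K_f = torch.fft.rfft(K, n=2 * n)
    return torch.fft.irfft(x_f * K_f, n=2 * n)[..., :n]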
def compute_kernel(A_bar, B_bar, C, length):
    """Compute the SSM convolution kernel (C B, C A B, C A^2 B, ...)."""
    K = []
    A_power = torch.eye(A_bar.shape[0])
    for _ in range(length):
        K.append(C @ A_power @ B_bar)
        A_power = A_power @ A_bar
    return torch.stack(K)

19.4 S4: Structured State Spaces
19.4.1 The HiPPO Matrix
S4 (Structured State Space) uses a special initialization for \(A\) called HiPPO:
\[ A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]
Why HiPPO? It optimally compresses history into fixed-size state through Legendre polynomial projections.
def make_hippo_matrix(N):
    """
    Create the HiPPO matrix for optimal history compression.
    """
    A = torch.zeros(N, N)
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -((2 * n + 1) ** 0.5) * ((2 * k + 1) ** 0.5)
            elif n == k:
                A[n, k] = -(n + 1)
    return A

19.4.2 Diagonal Approximation
Full matrix \(A\) is expensive. S4 diagonalizes:
\[ A = V \Lambda V^{-1} \]
Then the recurrence becomes element-wise:
\[ h_t = \Lambda \odot h_{t-1} + \tilde{B} x_t \]
Complexity: with a diagonal state matrix, the per-step state update is O(d) elementwise work instead of O(d²), so a length-n sequence costs O(nd) rather than O(nd²).
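A toy sketch of the diagonal update (values are illustrative; a real S4 layer learns \(\Lambda\) and batches this over channels):

import torch

d_state = 64
Lambda = torch.exp(-torch.rand(d_state))   # decay factors in (0, 1): stable
B_tilde = torch.randn(d_state)

h = torch.zeros(d_state)
for x_t in torch.randn(100):               # a stream of scalar inputs
    h = Lambda * h + B_tilde * x_t         # elementwise: O(d), no matrix multiply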
19.5 Mamba: Selective State Spaces
19.5.1 The Selectivity Problem
Standard SSMs have fixed \(A, B, C\)—they can’t adapt to input content.
Attention advantage: Query-key matching is content-dependent.
Mamba insight: Make SSM parameters input-dependent.
19.5.2 Input-Dependent Parameters
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """
    Mamba's selective state-space layer (simplified).
    Key innovation: B, C, and delta depend on the input.
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # A is still input-independent (for stability)
        self.A_log = nn.Parameter(torch.randn(d_state))
        # B, C, delta are input-dependent
        self.B_proj = nn.Linear(d_model, d_state)
        self.C_proj = nn.Linear(d_model, d_state)
        self.delta_proj = nn.Linear(d_model, 1)  # real Mamba uses a per-channel delta
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        batch, seq_len, d_model = x.shape
        # Input-dependent parameters
        B = self.B_proj(x)                      # [batch, seq, d_state]
        C = self.C_proj(x)                      # [batch, seq, d_state]
        delta = F.softplus(self.delta_proj(x))  # [batch, seq, 1], positive
        # Fixed A, log-parameterized so -exp(A_log) is always negative (stable)
        A = -torch.exp(self.A_log)              # [d_state]
        # Discretize with input-dependent delta
        A_bar = torch.exp(delta * A)            # [batch, seq, d_state]
        B_bar = delta * B                       # simplified (Euler) discretization
        # Selective scan (parallel algorithm)
        y = selective_scan(x, A_bar, B_bar, C, self.D)
        return y

19.5.3 The Selective Scan
With input-dependent parameters, we can’t use FFT convolution: the kernel is no longer the same at every position, so the computation is not a time-invariant convolution. Mamba instead uses a custom parallel scan:
def selective_scan(x, A, B, C, D):
    """
    Selective scan, written sequentially for clarity.
    The real implementation uses an associative parallel scan (below).
    x: [batch, seq, d_model]; A, B, C: [batch, seq, d_state]; D: [d_model].
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[-1]
    # One d_state-dimensional state per model channel
    h = torch.zeros(batch, d_model, d_state, device=x.device)
    outputs = []
    for t in range(seq_len):
        # State update: h = A[t] * h + B[t] * x[t], broadcast over channels
        h = A[:, t, None] * h + B[:, t, None] * x[:, t, :, None]
        # Output: y = C[t] . h + D * x[t]
        y = (C[:, t, None] * h).sum(-1) + D * x[:, t]
        outputs.append(y)
    return torch.stack(outputs, dim=1)

Parallel scan: the recurrence \(h_t = a_t h_{t-1} + b_t\) can be expressed with an associative combine operator, so it can be parallelized:
def parallel_scan(a, b):
    """
    Parallel prefix scan for the linear recurrence h[t] = a[t] * h[t-1] + b[t].
    The operator (a1, b1) ⊕ (a2, b2) = (a1*a2, a2*b1 + b2) is associative,
    so a Hillis-Steele doubling scan resolves the recurrence in O(log n)
    parallel steps. (Reference version; production Mamba fuses this scan
    into custom CUDA kernels.)
    a, b: [batch, seq, d]. Returns h with h[:, t] = a[t] * h[t-1] + b[t], h[-1] = 0.
    """
    seq_len = a.shape[1]
    step = 1
    while step < seq_len:
        # Combine each position with the partial result `step` positions back
        a_prev, b_prev = a[:, :-step], b[:, :-step]
        b = torch.cat([b[:, :step], a[:, step:] * b_prev + b[:, step:]], dim=1)
        a = torch.cat([a[:, :step], a[:, step:] * a_prev], dim=1)
        step *= 2
    return b

19.6 Mamba Architecture
19.6.1 The Full Block
class MambaBlock(nn.Module):
    """
    Full Mamba block with gating.
    """
    def __init__(self, d_model, d_state=16, expand=2):
        super().__init__()
        d_inner = d_model * expand
        # Input projection produces both the main path and the gate
        self.in_proj = nn.Linear(d_model, d_inner * 2)
        # Depthwise causal conv for local context
        self.conv1d = nn.Conv1d(
            d_inner, d_inner,
            kernel_size=4,
            groups=d_inner,
            padding=3,
        )
        # Selective SSM
        self.ssm = SelectiveSSM(d_inner, d_state)
        # Output projection
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):
        # x: [batch, seq, d_model]
        seq_len = x.shape[1]
        # Split into the SSM path and the gating path
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)
        # Conv + SSM path
        x = x.transpose(1, 2)                # [batch, d_inner, seq]
        x = self.conv1d(x)[:, :, :seq_len]   # trim the causal padding
        x = x.transpose(1, 2)
        x = F.silu(x)
        x = self.ssm(x)
        # Gating
        z = F.silu(z)
        x = x * z
        # Output
        return self.out_proj(x)

19.6.2 Mamba vs Transformer Comparison
Mamba-3B vs Transformer-3B:
| Relative cost (lower is better) | Mamba | Transformer |
|---|---|---|
| Inference (1K ctx) | 1.0x | 1.0x |
| Inference (8K ctx) | 1.0x | 2.5x (KV cache) |
| Inference (64K ctx) | 1.0x | 10x+ (KV cache explosion) |
| Training (1K) | 0.9x | 1.0x |
| Training (8K) | 0.7x | 1.0x |
| Training (64K) | 0.5x | 3.0x (memory bound) |
Quality (avg benchmarks): ~95% of equivalent transformer
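A back-of-the-envelope calculation shows where the inference gap comes from (a sketch; the layer count and widths below are merely plausible for a ~3B model, not measured from any released checkpoint):

def kv_cache_bytes(seq_len, n_layers=32, d_model=2560, bytes_per=2):
    """Transformer KV cache (fp16): grows linearly with context length."""
    return 2 * n_layers * seq_len * d_model * bytes_per  # keys + values

def ssm_state_bytes(n_layers=32, d_inner=5120, d_state=16, bytes_per=2):
    """SSM recurrent state (fp16): fixed size at any context length."""
    return n_layers * d_inner * d_state * bytes_per

print(f"{kv_cache_bytes(65_536) / 1e9:.1f} GB")  # ~21 GB at 64K tokens
print(f"{ssm_state_bytes() / 1e6:.1f} MB")       # ~5 MB, independent of length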
19.7 RWKV: RNN Reinvented
19.7.1 The RWKV Approach
RWKV (Receptance Weighted Key Value) is another linear-time architecture:
class RWKVBlock(nn.Module):
    """
    RWKV time-mixing block (simplified).
    Linear attention with a learned per-channel decay; the real RWKV
    also tracks a matching denominator to normalize the WKV term.
    """
    def __init__(self, d_model):
        super().__init__()
        # Learnable decay and "current token bonus"
        self.time_decay = nn.Parameter(torch.randn(d_model))
        self.time_first = nn.Parameter(torch.randn(d_model))
        # Projections
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.receptance = nn.Linear(d_model, d_model)
        self.output = nn.Linear(d_model, d_model)

    def forward(self, x, state=None):
        batch, seq_len, d = x.shape
        # Recurrent state carries the decayed sum of past k*v products
        if state is None:
            state = torch.zeros(batch, d, device=x.device)
        k = self.key(x)
        v = self.value(x)
        r = torch.sigmoid(self.receptance(x))
        outputs = []
        for t in range(seq_len):
            # WKV: current token (with bonus) plus decayed history
            wkv = self.time_first * k[:, t] * v[:, t] + state
            state = torch.exp(-torch.exp(self.time_decay)) * state + k[:, t] * v[:, t]
            # Receptance gates the output
            out = r[:, t] * wkv
            outputs.append(out)
        return self.output(torch.stack(outputs, dim=1)), state

19.7.2 RWKV Advantages
- True RNN inference: O(1) memory per token
- Parallel training: Can be formulated as convolution
- Competitive quality: Matches transformers up to 14B scale
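The first advantage is easy to see with the simplified block above: generation carries only the fixed-size state from step to step (a sketch with illustrative shapes):

import torch

block = RWKVBlock(d_model=64)
state = None
for _ in range(10):                   # token-by-token decoding loop
    token = torch.randn(1, 1, 64)     # [batch=1, seq=1, d_model]
    out, state = block(token, state)  # only `state` persists across steps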
19.8 Hybrid Architectures
19.8.1 The Best of Both Worlds
Pure SSMs sometimes underperform on tasks requiring precise retrieval. Solution: combine SSM and attention.
class JambaBlock(nn.Module):
    """
    Jamba-style hybrid: SSM + Attention + MoE.
    Pattern: 7 Mamba blocks, then 1 Attention block.
    (MultiHeadAttention, RMSNorm, and MoEFFN are assumed defined elsewhere.)
    """
    def __init__(self, d_model, use_attention=False):
        super().__init__()
        if use_attention:
            self.mixer = MultiHeadAttention(d_model)
        else:
            self.mixer = MambaBlock(d_model)
        # Separate norms for the two residual branches
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.ffn = MoEFFN(d_model)  # can also be dense

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class JambaModel(nn.Module):
    def __init__(self, d_model, num_layers=32):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Every 8th layer is attention
            use_attention = (i % 8 == 7)
            self.layers.append(JambaBlock(d_model, use_attention))

19.8.2 StripedHyena
Another hybrid approach: interleaved SSM and attention at finer granularity.
class StripedHyenaBlock(nn.Module):
    """
    Alternating Hyena (SSM-like) and Attention.
    """
    def __init__(self, d_model, layer_idx):
        super().__init__()
        # Alternate between Hyena and attention by layer index
        if layer_idx % 2 == 0:
            self.mixer = HyenaOperator(d_model)
        else:
            self.mixer = MultiHeadAttention(d_model)

19.9 When to Use SSMs
19.9.1 SSM Advantages
- O(n) complexity: Linear scaling with sequence length
- O(1) inference memory: Fixed state size, no KV cache
- Long sequences: Naturally handles 100K+ tokens
- Continuous signals: Better for time series, audio, video
19.9.2 SSM Limitations
- Retrieval tasks: Attention excels at “needle in haystack”
- Copying/recall: SSMs struggle with exact reproduction
- Maturity: Less tooling, fewer optimized kernels
- Quality ceiling: May not match best transformers (yet)
19.9.3 Decision Framework
Use SSM when:
✓ Very long sequences (>64K tokens)
✓ Streaming/real-time inference
✓ Memory-constrained deployment
✓ Continuous signals (audio, sensors)
✓ Latency matters more than peak quality
Use Transformer when:
✓ Retrieval-heavy tasks (RAG, QA)
✓ Maximum quality required
✓ Moderate sequence lengths (<32K)
✓ Existing infrastructure/tooling
Use Hybrid when:
✓ Long context + some retrieval needed
✓ Want SSM efficiency with attention accuracy
✓ Production systems where both matter
19.10 Implementation Tips
19.10.1 Efficient Mamba Kernels
# Use the official Mamba CUDA kernels
from mamba_ssm import Mamba

model = Mamba(
    d_model=2048,
    d_state=16,
    d_conv=4,
    expand=2,
)

# Or use the full language-model wrapper from the same package
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

19.10.2 State Caching for Inference
class MambaCache:
    """
    Cache per-layer SSM state for efficient generation.
    (Sketch: a full cache would also track each layer's conv state.)
    """
    def __init__(self, model, batch_size):
        self.states = []
        for layer in model.layers:
            # Each layer has a fixed-size state, regardless of context length
            state = torch.zeros(batch_size, layer.d_state)
            self.states.append(state)

    def update(self, layer_idx, new_state):
        self.states[layer_idx] = new_state

    def get(self, layer_idx):
        return self.states[layer_idx]

19.11 Property Audit
| Property | Applies? | How it’s exploited |
|---|---|---|
| Associativity | Primary | The linear recurrence \((A_t, b_t) \oplus (A_{t+1}, b_{t+1}) = (A_{t+1} A_t, A_{t+1} b_t + b_{t+1})\) is associative, enabling parallel scan |
| Separability | Secondary | The SSM kernel (convolution view) factors time and state dimensions; low state dimension d_state is a separability choice |
| Locality | Yes | Chunked scan keeps working state in SRAM |
| Sparsity | No | Dense state transitions |
| Redundancy | Partial | State caching during generation (analogous to KV cache) |
| Symmetry | No | Mamba’s selective mechanism breaks time-invariance |
Dominant property: Associativity — the parallel scan transforms O(N) sequential steps into O(log N) parallel steps, which is the fundamental algorithmic advance.
19.12 Key Takeaways
SSMs are O(n): Linear complexity via state-space formulation.
Training vs inference: Convolution for parallel training, recurrence for efficient inference.
Mamba = selective SSM: Input-dependent parameters enable content-aware processing.
Hybrids may be optimal: Combining SSM efficiency with attention precision.
Different strengths: SSMs for length/efficiency, attention for retrieval/quality.
Rapidly evolving: Mamba-2, RWKV-6, and new architectures emerging regularly.
The accompanying notebook will let you:
- Implement a basic SSM from scratch
- Compare SSM vs attention complexity empirically
- Experiment with Mamba on simple tasks
- Build a hybrid architecture
Notebook support for this chapter is in progress. For now, run the examples locally and compare SSM vs attention behavior on your hardware.
19.13 Further Reading
- Gu et al. (2022). “Efficiently Modeling Long Sequences with Structured State Spaces” (S4)
- Gu & Dao (2023). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”
- Peng et al. (2023). “RWKV: Reinventing RNNs for the Transformer Era”
- Lieber et al. (2024). “Jamba: A Hybrid Transformer-Mamba Language Model”
- Dao & Gu (2024). “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality” (Mamba-2)
19.14 Challenges
These problems have no provided solutions; they require applying the property-audit framework to new algorithms.
Challenge 1: Derive Ring Attention Given: multiple GPUs, each holding a shard of the KV cache. The full attention matrix doesn’t fit on any single GPU. Using the property audit approach, identify which properties enable distributed attention computation. Design an algorithm (hint: it involves a ring topology for KV passing). What’s the communication cost?
Challenge 2: Speculative Decoding as Property Exploitation Speculative decoding uses a draft model to generate K tokens, then verifies with the target model. Frame this optimization in terms of the six properties. Which properties explain why it works? Why does the acceptance rate determine the speedup?
Challenge 3: Beyond Attention Graph neural networks (GNNs) aggregate neighbor features: \(h_i' = \text{AGG}(\{h_j : j \in \mathcal{N}(i)\})\). Which properties does this operation have? How would you optimize it for power-law degree distributions (where a few nodes have millions of neighbors)?