19 State-Space Models
Linear-Time Alternatives to Attention
Attention is O(n²). For a million tokens, that's a trillion pairwise interactions per layer.
State-space models achieve O(n) complexity. What’s the catch?
19.1 The Attention Problem
Transformers dominate language modeling, but attention has fundamental scaling issues:
Attention complexity per layer (illustrative numbers: fp32 attention scores, hidden size ≈ 1K):

Sequence   Memory (score matrix)   FLOPs
─────────────────────────────────────────────
1K         4 MB                    2 billion
4K         64 MB                   32 billion
16K        1 GB                    512 billion
64K        16 GB                   8 trillion
256K       256 GB                  128 trillion
Even with FlashAttention, which avoids materializing the score matrix and so removes the memory blow-up, compute is still O(n²).
State-space models (SSMs) promise O(n) complexity with competitive quality.
19.2 The RNN Connection
RNNs are naturally O(n)—each token updates a fixed-size hidden state:
import torch

def rnn_forward(x, h0, W_h, W_x):
    """
    Simple RNN: O(n) in sequence length.
    """
    h = h0
    outputs = []
    for t in range(len(x)):
        h = torch.tanh(W_h @ h + W_x @ x[t])  # O(d²) per step
        outputs.append(h)
    return outputs  # Total: O(n * d²)

Problem: RNNs are hard to train (vanishing gradients) and slow to run (the recurrence is strictly sequential).
Insight: What if we could have RNN-like O(n) inference with transformer-like parallel training?
19.3 State-Space Models Fundamentals
19.3.1 The Continuous-Time View
SSMs come from control theory. A linear time-invariant system:
\[ \begin{aligned} h'(t) &= Ah(t) + Bx(t) \\ y(t) &= Ch(t) + Dx(t) \end{aligned} \]
Where:

- \(x(t)\) is the input signal
- \(h(t)\) is the hidden state
- \(y(t)\) is the output
- \(A, B, C, D\) are learnable matrices
19.3.2 Discretization
For discrete sequences, we discretize with step size \(\Delta\):
\[ \begin{aligned} \bar{A} &= \exp(\Delta A) \\ \bar{B} &= (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \\ h_t &= \bar{A} h_{t-1} + \bar{B} x_t \\ y_t &= C h_t + D x_t \end{aligned} \]
This is now a recurrence—can be computed sequentially like an RNN.
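For concreteness, here is one minimal way these formulas could be implemented. The helper names `discretize_A` and `discretize_B` match the calls in the forward pass of the next subsection, but this particular implementation is only a sketch (matrix-valued \(A\), vector-valued \(B\), no batching):

import torch

def discretize_A(A, delta):
    """Zero-order hold: A_bar = exp(Δ·A), via the matrix exponential."""
    return torch.matrix_exp(delta * A)

def discretize_B(A, B, delta):
    """B_bar = (Δ·A)^{-1} (exp(Δ·A) - I) · Δ·B."""
    dA = delta * A
    I = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    return torch.linalg.solve(dA, torch.matrix_exp(dA) - I) @ (delta * B)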
19.3.3 The Parallel Training Trick
The recurrence can be unrolled as a convolution:
\[ y = x * \bar{K} \]
where \(\bar{K}\) is the SSM kernel:
\[ \bar{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \ldots, C\bar{A}^{n-1}\bar{B}) \]
Key insight: During training, compute via FFT convolution. During inference, compute via recurrence.
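The forward pass below also calls an `fft_conv` helper that this chapter does not define; a minimal causal FFT convolution, assuming a single channel (x of shape [batch, length], kernel K of shape [length]), might look like this:

import torch

def fft_conv(x, K):
    """Causal convolution of x with kernel K in O(n log n) via the FFT."""
    L = x.shape[-1]
    # Zero-pad to length 2L so the circular convolution equals the linear (causal) one
    x_f = torch.fft.rfft(x, n=2 * L)
    K_f = torch.fft.rfft(K, n=2 * L)
    return torch.fft.irfft(x_f * K_f, n=2 * L)[..., :L]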
def ssm_forward(x, A, B, C, D, delta, training=True):
    """
    SSM forward pass for a single input/output channel.
    x: [batch, length], A: [N, N], B: [N], C: [N], D: scalar.
    Training: use convolution (parallel).
    Inference: use recurrence (sequential, fixed-size state per step).
    """
    # Discretize the continuous-time parameters
    A_bar = discretize_A(A, delta)        # [N, N]
    B_bar = discretize_B(A, B, delta)     # [N]
    if training:
        # Compute the kernel, then convolve in parallel via FFT
        K = compute_kernel(A_bar, B_bar, C, length=x.shape[-1])  # [length]
        y = fft_conv(x, K)
    else:
        # Recurrence (sequential): one O(1) state update per token
        h = torch.zeros(x.shape[0], A_bar.shape[0])   # [batch, N]
        y = []
        for t in range(x.shape[-1]):
            h = h @ A_bar.T + x[:, t, None] * B_bar   # h_t = A_bar h_{t-1} + B_bar x_t
            y.append(h @ C + D * x[:, t])             # y_t = C h_t + D x_t
        y = torch.stack(y, dim=-1)
    return y
def compute_kernel(A_bar, B_bar, C, length):
    """Compute SSM convolution kernel."""
    K = []
    A_power = torch.eye(A_bar.shape[0])
    for _ in range(length):
        K.append(C @ A_power @ B_bar)
        A_power = A_power @ A_bar
    return torch.stack(K)

19.4 S4: Structured State Spaces
19.4.1 The HiPPO Matrix
S4 (Structured State Space) uses a special initialization for \(A\) called HiPPO:
\[ A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]
Why HiPPO? It optimally compresses history into fixed-size state through Legendre polynomial projections.
def make_hippo_matrix(N):
    """
    Create HiPPO matrix for optimal history compression.
    """
    A = torch.zeros(N, N)
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -((2*n + 1) ** 0.5) * ((2*k + 1) ** 0.5)
            elif n == k:
                A[n, k] = -(n + 1)
    return A

19.4.2 Diagonal Approximation
Full matrix \(A\) is expensive. S4 diagonalizes:
\[ A = V \Lambda V^{-1} \]
Then the recurrence becomes element-wise:
\[ h_t = \Lambda \odot h_{t-1} + \tilde{B} x_t \]
Complexity: with a diagonal \(A\), the per-step state update drops from O(d²) to O(d), so a layer costs O(nd) instead of O(nd²).
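A minimal sketch of this diagonal recurrence (the helper name `diagonal_ssm_scan` is ours, and everything is kept real-valued for readability, whereas S4 actually works with complex eigenvalues):

import torch

def diagonal_ssm_scan(x, Lambda_bar, B_bar, C):
    """
    Recurrence after diagonalizing A: every state update is element-wise,
    O(N) per step instead of O(N²) for a dense A.
    x: [batch, length]; Lambda_bar, B_bar, C: [N].
    """
    h = torch.zeros(x.shape[0], Lambda_bar.shape[0])
    ys = []
    for t in range(x.shape[-1]):
        h = Lambda_bar * h + B_bar * x[:, t, None]   # h_t = Λ ⊙ h_{t-1} + B̃ x_t
        ys.append((C * h).sum(-1))                   # y_t = C h_t
    return torch.stack(ys, dim=-1)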
19.5 Mamba: Selective State Spaces
19.5.1 The Selectivity Problem
Standard SSMs have fixed \(A, B, C\)—they can’t adapt to input content.
Attention advantage: Query-key matching is content-dependent.
Mamba insight: Make SSM parameters input-dependent.
19.5.2 Input-Dependent Parameters
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """
    Mamba's selective state space layer (simplified).
    Key innovation: B, C, and delta depend on the input.
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # A is still input-independent (for stability), stored as log magnitudes
        self.A_log = nn.Parameter(torch.randn(d_state))
        # B, C, delta are input-dependent
        self.B_proj = nn.Linear(d_model, d_state)
        self.C_proj = nn.Linear(d_model, d_state)
        self.delta_proj = nn.Linear(d_model, 1)
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        batch, seq_len, d = x.shape
        # Input-dependent parameters
        B = self.B_proj(x)                       # [batch, seq, d_state]
        C = self.C_proj(x)                       # [batch, seq, d_state]
        delta = F.softplus(self.delta_proj(x))   # [batch, seq, 1]
        # Fixed A (log-parameterized so A stays negative)
        A = -torch.exp(self.A_log)               # [d_state]
        # Discretize with input-dependent delta
        A_bar = torch.exp(delta * A)             # [batch, seq, d_state]
        B_bar = delta * B                        # Simplified (Euler) discretization
        # Selective scan (parallel algorithm in practice)
        y = selective_scan(x, A_bar, B_bar, C, self.D)
        return y

19.5.3 The Selective Scan
With input-dependent parameters, the system is no longer time-invariant, so there is no single convolution kernel and the FFT trick no longer applies. Mamba instead uses a custom, hardware-aware parallel scan:
def selective_scan(x, A, B, C, D):
    """
    Selective scan, written sequentially for illustration
    (the actual implementation uses an associative parallel scan).
    x: [batch, seq, d_model]; A, B, C: [batch, seq, d_state]; D: [d_model].
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[-1]
    # One state vector of size d_state per channel
    h = torch.zeros(batch, d_model, d_state, device=x.device, dtype=x.dtype)
    outputs = []
    for t in range(seq_len):
        # State update: h = A[t] * h + B[t] * x[t]   (element-wise)
        h = A[:, t].unsqueeze(1) * h + B[:, t].unsqueeze(1) * x[:, t, :, None]
        # Output: y = C[t] · h + D * x[t]
        y = (C[:, t].unsqueeze(1) * h).sum(-1) + D * x[:, t]
        outputs.append(y)
    return torch.stack(outputs, dim=1)

Parallel scan: the recurrence \(h_t = a_t h_{t-1} + b_t\) has an associative combine operator, so it can be computed with a parallel prefix scan:
def parallel_scan(a, b):
    """
    Parallel prefix scan for the linear recurrence h[t] = a[t] * h[t-1] + b[t].
    The combine rule (a1, b1) ⊕ (a2, b2) = (a1*a2, a2*b1 + b2) is associative,
    so only O(log n) parallel steps are needed. Production Mamba uses fused
    CUDA kernels; this is a reference version over the last dimension.
    """
    L = a.shape[-1]
    A, B = a.clone(), b.clone()
    step = 1
    while step < L:
        # Fold in the partial result sitting `step` positions to the left
        A_prev = F.pad(A[..., :-step], (step, 0), value=1.0)
        B_prev = F.pad(B[..., :-step], (step, 0), value=0.0)
        B = A * B_prev + B
        A = A * A_prev
        step *= 2
    return B  # B[t] now equals h[t], assuming initial state h[-1] = 0

19.6 Mamba Architecture
19.6.1 The Full Block
class MambaBlock(nn.Module):
    """
    Full Mamba block with gating.
    """
    def __init__(self, d_model, d_state=16, expand=2):
        super().__init__()
        d_inner = d_model * expand
        # Input projection (produces both the SSM path and the gate)
        self.in_proj = nn.Linear(d_model, d_inner * 2)
        # Depthwise conv for local context
        # (padding + truncation in forward keep it causal)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner,
            kernel_size=4,
            groups=d_inner,
            padding=3
        )
        # Selective SSM
        self.ssm = SelectiveSSM(d_inner, d_state)
        # Output projection
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):
        # x: [batch, seq, d_model]
        seq_len = x.shape[1]
        # Split into the SSM path (x) and the gate path (z)
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)
        # Conv + SSM path
        x = x.transpose(1, 2)                  # [batch, d_inner, seq]
        x = self.conv1d(x)[:, :, :seq_len]     # drop the acausal tail
        x = x.transpose(1, 2)
        x = F.silu(x)
        x = self.ssm(x)
        # Gating
        z = F.silu(z)
        x = x * z
        # Output
        return self.out_proj(x)

19.6.2 Mamba vs Transformer Comparison
Mamba-3B vs Transformer-3B (approximate relative cost, lower is better):

                       Mamba    Transformer
──────────────────────────────────────────────
Inference (1K ctx)     1.0x     1.0x
Inference (8K ctx)     1.0x     2.5x  (KV cache)
Inference (64K ctx)    1.0x     10x+  (KV cache explosion)
Training (1K)          0.9x     1.0x
Training (8K)          0.7x     1.0x
Training (64K)         0.5x     3.0x  (memory bound)

Quality (avg benchmarks): ~95% of an equivalent transformer
19.7 RWKV: RNN Reinvented
19.7.1 The RWKV Approach
RWKV (Receptance Weighted Key Value) is another linear-time architecture:
class RWKVBlock(nn.Module):
    """
    RWKV time-mixing block (simplified).
    Linear attention with a learned per-channel decay; the real RWKV
    formulation also normalizes the weighted sum.
    """
    def __init__(self, d_model):
        super().__init__()
        # Learnable decay and "bonus" weight for the current token
        self.time_decay = nn.Parameter(torch.randn(d_model))
        self.time_first = nn.Parameter(torch.randn(d_model))
        # Projections
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.receptance = nn.Linear(d_model, d_model)
        self.output = nn.Linear(d_model, d_model)

    def forward(self, x, state=None):
        batch, seq_len, d = x.shape
        # Recurrent state: a decayed running sum of k * v
        if state is None:
            state = torch.zeros(batch, d, device=x.device)
        k = self.key(x)
        v = self.value(x)
        r = torch.sigmoid(self.receptance(x))
        decay = torch.exp(-torch.exp(self.time_decay))  # in (0, 1)
        outputs = []
        for t in range(seq_len):
            # WKV: running sum plus a time_first bonus for the current token
            wkv = self.time_first * k[:, t] * v[:, t] + state
            state = decay * state + k[:, t] * v[:, t]
            # Receptance-gated output
            out = r[:, t] * wkv
            outputs.append(out)
        return self.output(torch.stack(outputs, dim=1)), state

19.7.2 RWKV Advantages
- True RNN inference: O(1) memory per token (see the streaming sketch after this list)
- Parallel training: Can be formulated as convolution
- Competitive quality: Matches transformers up to 14B scale
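As a rough usage sketch of the simplified block above (not the official RWKV API), streaming inference just threads the fixed-size state through one token at a time:

import torch

block = RWKVBlock(d_model=512)
tokens = torch.randn(1, 16, 512)   # stand-in for a sequence of token embeddings
state = None                       # fixed-size state, independent of history length
for t in range(tokens.shape[1]):
    out, state = block(tokens[:, t:t+1], state)   # one token per step, O(1) memory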
19.8 Hybrid Architectures
19.8.1 The Best of Both Worlds
Pure SSMs sometimes underperform on tasks requiring precise retrieval. Solution: combine SSM and attention.
class JambaBlock(nn.Module):
    """
    Jamba-style hybrid: SSM + Attention + MoE.
    Pattern: 7 Mamba blocks, then 1 Attention block.
    """
    def __init__(self, d_model, use_attention=False):
        super().__init__()
        if use_attention:
            self.mixer = MultiHeadAttention(d_model)
        else:
            self.mixer = MambaBlock(d_model)
        # Separate pre-norms for the mixer and FFN sub-blocks
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.ffn = MoEFFN(d_model)  # Can also be dense

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class JambaModel(nn.Module):
    def __init__(self, d_model, num_layers=32):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Every 8th layer is attention
            use_attention = (i % 8 == 7)
            self.layers.append(JambaBlock(d_model, use_attention))

19.8.2 StripedHyena
Another hybrid approach: interleaved SSM and attention at finer granularity.
class StripedHyenaBlock(nn.Module):
    """
    Alternating Hyena (SSM-like) and Attention.
    """
    def __init__(self, d_model, layer_idx):
        super().__init__()
        # Alternate between Hyena and Attention
        if layer_idx % 2 == 0:
            self.mixer = HyenaOperator(d_model)
        else:
            self.mixer = MultiHeadAttention(d_model)

19.9 When to Use SSMs
19.9.1 SSM Advantages
- O(n) complexity: Linear scaling with sequence length
- O(1) inference memory: Fixed state size, no KV cache (see the back-of-the-envelope comparison after this list)
- Long sequences: Naturally handles 100K+ tokens
- Continuous signals: Better for time series, audio, video
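To make the memory point concrete, here is a back-of-the-envelope comparison with illustrative sizes chosen for this sketch (32 layers, d_model = 2048, 16 KV heads of dimension 128, fp16, batch size 1):

# KV cache grows linearly with context; the SSM state does not grow at all.
layers, d_model, d_state, expand = 32, 2048, 16, 2
n_kv_heads, head_dim, bytes_per = 16, 128, 2

def kv_cache_bytes(context_len):
    # K and V per layer, each [context_len, n_kv_heads, head_dim]
    return 2 * layers * context_len * n_kv_heads * head_dim * bytes_per

def ssm_state_bytes():
    # One [d_inner, d_state] state per layer, independent of context length
    return layers * (expand * d_model) * d_state * bytes_per

print(f"{kv_cache_bytes(256_000) / 1e9:.0f} GB KV cache at 256K tokens")  # ~67 GB
print(f"{ssm_state_bytes() / 1e6:.1f} MB SSM state at any length")        # ~4.2 MB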
19.9.2 SSM Limitations
- Retrieval tasks: Attention excels at “needle in haystack”
- Copying/recall: SSMs struggle with exact reproduction
- Maturity: Less tooling, fewer optimized kernels
- Quality ceiling: May not match best transformers (yet)
19.9.3 Decision Framework
Use SSM when:
✓ Very long sequences (>64K tokens)
✓ Streaming/real-time inference
✓ Memory-constrained deployment
✓ Continuous signals (audio, sensors)
✓ Latency matters more than peak quality
Use Transformer when:
✓ Retrieval-heavy tasks (RAG, QA)
✓ Maximum quality required
✓ Moderate sequence lengths (<32K)
✓ Existing infrastructure/tooling
Use Hybrid when:
✓ Long context + some retrieval needed
✓ Want SSM efficiency with attention accuracy
✓ Production systems where both matter
19.10 Implementation Tips
19.10.1 Efficient Mamba Kernels
# Use the official Mamba CUDA kernels
from mamba_ssm import Mamba
model = Mamba(
    d_model=2048,
    d_state=16,
    d_conv=4,
    expand=2
)
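# A quick forward-pass sketch (shapes assumed here: input and output are
# [batch, length, d_model]); the fused selective-scan kernels need a CUDA device.
import torch
model = model.cuda()
x = torch.randn(2, 1024, 2048, device="cuda")
y = model(x)  # [2, 1024, 2048]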
# The same package also provides a full Mamba language model
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

19.10.2 State Caching for Inference
class MambaCache:
    """
    Cache the per-layer SSM state for efficient autoregressive generation.
    Unlike a KV cache, the state size is fixed regardless of context length.
    """
    def __init__(self, model, batch_size):
        self.states = []
        for layer in model.layers:
            # One fixed-size state per layer
            # (real Mamba implementations also cache a small conv buffer)
            state = torch.zeros(batch_size, layer.d_state)
            self.states.append(state)

    def update(self, layer_idx, new_state):
        self.states[layer_idx] = new_state

    def get(self, layer_idx):
        return self.states[layer_idx]

19.11 Key Takeaways
SSMs are O(n): Linear complexity via state-space formulation.
Training vs inference: Convolution for parallel training, recurrence for efficient inference.
Mamba = selective SSM: Input-dependent parameters enable content-aware processing.
Hybrids may be optimal: Combining SSM efficiency with attention precision.
Different strengths: SSMs for length/efficiency, attention for retrieval/quality.
Rapidly evolving: Mamba-2, RWKV-6, and new architectures emerging regularly.
19.12 Further Reading
- Gu et al. (2022). “Efficiently Modeling Long Sequences with Structured State Spaces” (S4)
- Gu & Dao (2023). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”
- Peng et al. (2023). “RWKV: Reinventing RNNs for the Transformer Era”
- Lieber et al. (2024). “Jamba: A Hybrid Transformer-Mamba Language Model”
- Dao & Gu (2024). “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality” (Mamba-2)