19  Investigation: State-Space Models

Linear-Time Alternatives to Attention

Attention is O(n²) in sequence length. For a million tokens, that’s a trillion query-key scores per layer, each costing O(d) to compute.

State-space models achieve O(n) complexity. What’s the catch?

19.1 The Attention Problem

Transformers dominate language modeling, but attention has fundamental scaling issues:

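Attention complexity per layer (memory for one n × n fp32 score matrix; FLOPs for QKᵀ with total width d = 1024; counts use powers of two, e.g. 1 GB = 2³⁰ bytes):

Sequence    Score memory    QKᵀ FLOPs
─────────────────────────────────────
1K          4 MB            2 billion
4K          64 MB           32 billion
16K         1 GB            512 billion
64K         16 GB           8 trillion
256K        256 GB          128 trillion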
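The table’s numbers can be reproduced under two assumptions inferred from the figures themselves (the source doesn’t state them): memory counts one n × n fp32 score matrix, and FLOPs count 2·n²·d for the QKᵀ product with d = 1024, with "billion" meaning 2³⁰. A minimal sketch:

```python
def attention_cost(n, d=1024, bytes_per_elem=4):
    """Bytes for one n x n fp32 score matrix, and FLOPs for Q @ K^T (2 * n * n * d)."""
    memory_bytes = n * n * bytes_per_elem
    flops = 2 * n * n * d
    return memory_bytes, flops

for n in [1024, 4096, 16384, 65536, 262144]:
    mem, flops = attention_cost(n)
    # Binary units, matching the table: 2**20 bytes per MB, 2**30 FLOPs per "billion"
    print(f"{n // 1024:>4}K  {mem / 2**20:>8.0f} MB  {flops / 2**30:>8.0f} billion FLOPs")
```

Both columns grow 16× for every 4× increase in sequence length, which is the O(n²) scaling in action.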

FlashAttention never materializes the n × n matrix, which cuts memory to O(n), but compute is still O(n²).

State-space models (SSMs) promise O(n) complexity with competitive quality.

19.2 The RNN Connection

RNNs are naturally O(n): each token updates a fixed-size hidden state:

import torch

def rnn_forward(x, h0, W_h, W_x):
    """
    Simple RNN: O(n) in sequence length.
    x: [seq_len, d_in], h0: [d], W_h: [d, d], W_x: [d, d_in]
    """
    h = h0
    outputs = []

    for t in range(len(x)):
        h = torch.tanh(W_h @ h + W_x @ x[t])  # O(d²) per step
        outputs.append(h)

    return torch.stack(outputs)  # [seq_len, d]; total: O(n * d²)

Problem: RNNs are hard to train (vanishing gradients) and slow to train (each step depends on the previous one, so training can’t be parallelized across the sequence).

Insight: What if we could have RNN-like O(n) inference with transformer-like parallel training?

19.3 State-Space Models Fundamentals

19.3.1 The Continuous-Time View

SSMs come from control theory. A linear time-invariant system:

\[ \begin{aligned} h'(t) &= Ah(t) + Bx(t) \\ y(t) &= Ch(t) + Dx(t) \end{aligned} \]

Where:

  • \(x(t)\) is the input signal
  • \(h(t)\) is the hidden state
  • \(y(t)\) is the output
  • \(A, B, C, D\) are learnable matrices

19.3.2 Discretization

For discrete sequences, we discretize with step size \(\Delta\) using the zero-order hold (ZOH) rule:

\[ \begin{aligned} \bar{A} &= \exp(\Delta A) \\ \bar{B} &= (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \\ h_t &= \bar{A} h_{t-1} + \bar{B} x_t \\ y_t &= C h_t + D x_t \end{aligned} \]

This is now a linear recurrence, so it can be computed sequentially like an RNN (minus the nonlinearity).
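A minimal sketch of the ZOH formulas above, checked on a small diagonal \(A\) with negative entries (an illustrative, stable choice; the function name `zoh_discretize` is ours). It shows two facts. For diagonal \(A\), ZOH reduces to the elementwise rule \(\bar{b}_i = (e^{\Delta a_i} - 1)/a_i \cdot b_i\). For small \(\Delta\), \(\bar{B} \approx \Delta B\), which is the simplified rule used later in the Mamba sketch.

```python
import torch

def zoh_discretize(A, B, delta):
    """Zero-order hold: A_bar = exp(dA), B_bar = (dA)^{-1} (exp(dA) - I) dB."""
    dA = delta * A
    A_bar = torch.matrix_exp(dA)
    I = torch.eye(A.shape[0], dtype=A.dtype)
    # Solve (dA) X = (exp(dA) - I) dB rather than forming the inverse explicitly
    B_bar = torch.linalg.solve(dA, (A_bar - I) @ (delta * B))
    return A_bar, B_bar

torch.manual_seed(0)
a = -torch.rand(4, dtype=torch.float64) - 0.5   # stable diagonal entries (all negative)
A = torch.diag(a)
B = torch.randn(4, 1, dtype=torch.float64)

# 1) Diagonal A: ZOH is elementwise
A_bar, B_bar = zoh_discretize(A, B, delta=0.1)
print(torch.allclose(torch.diag(A_bar), torch.exp(0.1 * a)))
print(torch.allclose(B_bar.squeeze(), (torch.exp(0.1 * a) - 1) / a * B.squeeze()))

# 2) Small delta: B_bar is close to delta * B
_, B_small = zoh_discretize(A, B, delta=1e-4)
print(torch.allclose(B_small, 1e-4 * B, rtol=1e-3))
```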

19.3.3 The Parallel Training Trick

The recurrence can be unrolled as a convolution:

\[ y = x * \bar{K} \]

where \(\bar{K}\) is the SSM kernel:

\[ \bar{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \ldots, C\bar{A}^{n-1}\bar{B}) \]
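The kernel follows from unrolling the recurrence from a zero initial state. The \(D x_t\) skip term is kept out of the convolution, as in the kernel above:

\[ \begin{aligned} h_0 &= \bar{B} x_0 \\ h_1 &= \bar{A}\bar{B} x_0 + \bar{B} x_1 \\ h_t &= \sum_{j=0}^{t} \bar{A}^{j} \bar{B}\, x_{t-j} \\ y_t - D x_t &= C h_t = \sum_{j=0}^{t} \underbrace{C\bar{A}^{j}\bar{B}}_{\bar{K}_j}\, x_{t-j} \end{aligned} \]

Because \(\bar{K}_j\) depends only on the lag \(j\) and not on \(t\), the output is a causal convolution of \(x\) with \(\bar{K}\). That holds only because \(\bar{A}, \bar{B}, C\) are the same at every step.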

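Key insight: During training, compute via FFT convolution, which costs O(n log n) and runs in parallel over the whole sequence. During inference, compute via recurrence, which costs O(1) per new token for a fixed state size.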

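import torch

def discretize(A, B, delta):
    """Zero-order hold (Section 19.3.2): returns A_bar, B_bar."""
    dA = delta * A
    A_bar = torch.matrix_exp(dA)
    I = torch.eye(A.shape[0], dtype=A.dtype)
    # (dA)^{-1} (exp(dA) - I) dB, computed with a linear solve
    B_bar = torch.linalg.solve(dA, (A_bar - I) @ (delta * B))
    return A_bar, B_bar

def fft_conv(x, K):
    """Causal convolution y[t] = sum_j K[j] x[t-j], via FFT zero-padded to 2L."""
    L = x.shape[-1]
    y = torch.fft.irfft(torch.fft.rfft(x, n=2 * L) * torch.fft.rfft(K, n=2 * L), n=2 * L)
    return y[..., :L]

def ssm_forward(x, A, B, C, D, delta, training=True):
    """
    SSM forward pass for one input channel.
    x: [batch, L], A: [N, N], B: [N, 1], C: [1, N], D: scalar.

    Training: Use convolution (parallel)
    Inference: Use recurrence (sequential but O(1) per step)
    """
    A_bar, B_bar = discretize(A, B, delta)
    batch, L = x.shape

    if training:
        # Compute kernel, then convolve (parallel via FFT); D * x is a skip term
        K = compute_kernel(A_bar, B_bar, C, length=L)
        y = fft_conv(x, K) + D * x
    else:
        # Recurrence (sequential); one state vector per batch element
        h = torch.zeros(batch, A.shape[0], dtype=x.dtype)
        y = []
        for t in range(L):
            h = h @ A_bar.T + x[:, t:t + 1] @ B_bar.T
            y.append((h @ C.T).squeeze(-1) + D * x[:, t])
        y = torch.stack(y, dim=-1)

    return y

def compute_kernel(A_bar, B_bar, C, length):
    """Compute SSM convolution kernel K[j] = C A_bar^j B_bar, shape [length]."""
    K = []
    A_power = torch.eye(A_bar.shape[0], dtype=A_bar.dtype)

    for _ in range(length):
        K.append((C @ A_power @ B_bar).squeeze())
        A_power = A_power @ A_bar

    return torch.stack(K)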
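A quick, self-contained numerical check that the two views agree, using a random diagonal \(\bar{A}\) with entries in [0, 0.9) (an illustrative stable choice). Padding both signals to length 2L is what makes the FFT product a causal linear convolution rather than a circular one:

```python
import torch

torch.manual_seed(0)
N, L = 4, 32                                                   # state size, sequence length
A_bar = torch.diag(torch.rand(N, dtype=torch.float64) * 0.9)   # stable discrete transition
B_bar = torch.randn(N, 1, dtype=torch.float64)
C = torch.randn(1, N, dtype=torch.float64)
x = torch.randn(L, dtype=torch.float64)

# Recurrent view: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t
h = torch.zeros(N, 1, dtype=torch.float64)
y_rec = []
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append((C @ h).item())
y_rec = torch.tensor(y_rec, dtype=torch.float64)

# Convolutional view: kernel K_j = C A_bar^j B_bar, applied with a padded FFT
K = torch.stack([(C @ torch.linalg.matrix_power(A_bar, j) @ B_bar).squeeze() for j in range(L)])
n = 2 * L
y_fft = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(K, n=n), n=n)[:L]

print(torch.max(torch.abs(y_rec - y_fft)))   # round-off-sized difference
```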

19.4 S4: Structured State Spaces

19.4.1 The HiPPO Matrix

S4 (Structured State Space) uses a special initialization for \(A\) called HiPPO:

\[ A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]

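Why HiPPO? It is derived so that the state holds the coefficients of the best Legendre-polynomial approximation of the input history seen so far (optimal in an L² sense under a chosen measure over the past). This compresses the entire history into a fixed-size state.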

import torch

def make_hippo_matrix(N):
    """
    Create HiPPO matrix for optimal history compression.
    """
    A = torch.zeros(N, N)

    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -((2*n + 1) ** 0.5) * ((2*k + 1) ** 0.5)
            elif n == k:
                A[n, k] = -(n + 1)

    return A
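A vectorized sketch of the same matrix (the name `make_hippo_matrix_vectorized` is ours). It makes one property easy to see: \(A\) is lower-triangular, so its eigenvalues are its diagonal entries \(-1, -2, \ldots, -N\). They are all negative, so the continuous-time dynamics are stable.

```python
import torch

def make_hippo_matrix_vectorized(N):
    n = torch.arange(N, dtype=torch.float64)
    s = torch.sqrt(2 * n + 1)
    # Strict lower triangle: -sqrt(2n+1) * sqrt(2k+1); diagonal: -(n+1); upper: 0
    return torch.tril(-torch.outer(s, s), diagonal=-1) - torch.diag(n + 1)

A = make_hippo_matrix_vectorized(8)
eigs = torch.linalg.eigvals(A)
print(torch.diag(A))                 # -1, -2, ..., -8
print(sorted(eigs.real.tolist()))    # numerically close to the diagonal
```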

19.4.2 Diagonal Approximation

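A full \(N \times N\) matrix \(A\) makes every recurrence step cost O(N²) for state size \(N\). Diagonalize it:

\[ A = V \Lambda V^{-1} \]

Working with the transformed state \(\tilde{h} = V^{-1} h\), the recurrence becomes element-wise:

\[ \tilde{h}_t = \bar{\Lambda} \odot \tilde{h}_{t-1} + \tilde{B} x_t, \qquad y_t = C V \tilde{h}_t + D x_t \]

where \(\bar{\Lambda} = \exp(\Delta \Lambda)\) is stored as the vector of its diagonal and \(\tilde{B} = V^{-1} \bar{B}\).

Complexity: O(N) per step instead of O(N²), i.e. O(nN) rather than O(nN²) for a length-\(n\) sequence. Caveat: the HiPPO matrix can’t be diagonalized stably (its eigenvector matrix has exponentially large entries), so S4 itself uses a normal-plus-low-rank (NPLR) decomposition of \(A\). Later variants such as DSS and S4D use a purely diagonal \(A\).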
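The change of basis can be checked numerically. This sketch assumes a symmetric \(A\) with negative eigenvalues so that `torch.linalg.eigh` returns a real orthogonal \(V\) with \(V^{-1} = V^\top\). That is an illustrative assumption; HiPPO itself is not symmetric. The dense recurrence and the element-wise one produce the same outputs:

```python
import torch

torch.manual_seed(0)
N, L, delta = 6, 20, 0.1
Q, _ = torch.linalg.qr(torch.randn(N, N, dtype=torch.float64))
A = Q @ torch.diag(-torch.rand(N, dtype=torch.float64) - 0.1) @ Q.T   # dense, symmetric, stable
A_bar = torch.matrix_exp(delta * A)                                   # dense discrete transition
B_bar = delta * torch.randn(N, 1, dtype=torch.float64)                # any input matrix works here
C = torch.randn(1, N, dtype=torch.float64)
x = torch.randn(L, dtype=torch.float64)

# Diagonalize once: A = V diag(evals) V^T
evals, V = torch.linalg.eigh(A)
lam_bar = torch.exp(delta * evals)   # exp(delta * Lambda), stored as a vector
B_tilde = V.T @ B_bar                # V^{-1} B_bar
C_tilde = C @ V                      # C V

h = torch.zeros(N, 1, dtype=torch.float64)        # dense state: O(N^2) per step
h_diag = torch.zeros(N, 1, dtype=torch.float64)   # eigenbasis state: O(N) per step
y_dense, y_diag = [], []
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    h_diag = lam_bar[:, None] * h_diag + B_tilde * x[t]   # element-wise update
    y_dense.append((C @ h).item())
    y_diag.append((C_tilde @ h_diag).item())

print(max(abs(p - q) for p, q in zip(y_dense, y_diag)))   # round-off-sized difference
```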

19.5 Mamba: Selective State Spaces

19.5.1 The Selectivity Problem

Standard SSMs (like S4) are time-invariant: \(A, B, C\) and \(\Delta\) are the same at every step, so they can’t adapt to input content.

Attention advantage: Query-key matching is content-dependent.

Mamba insight: Make SSM parameters input-dependent.
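To see why an input-dependent \(\Delta\) gives content-based control, consider a one-dimensional state with \(A = -1\), \(B = 1\) (illustrative values), discretized with the scalar form of the ZOH rule from Section 19.3.2, \(\bar{B} = (e^{\Delta A} - 1)/A \cdot B\). A small \(\Delta\) keeps the state and mostly ignores the token. A large \(\Delta\) forgets the state and absorbs the token. Mamba lets each token choose its own \(\Delta\).

```python
import math

def selective_step(h, x, delta, A=-1.0, B=1.0):
    """One scalar SSM step with a (per-token) step size delta, ZOH-discretized."""
    A_bar = math.exp(delta * A)       # decay factor in (0, 1) since A < 0
    B_bar = (A_bar - 1.0) / A * B     # scalar ZOH input weight
    return A_bar * h + B_bar * x

h = 5.0                                        # previously remembered content
print(selective_step(h, x=1.0, delta=1e-3))    # ~4.996: state kept, input nearly ignored
print(selective_step(h, x=1.0, delta=10.0))    # ~1.0002: state reset to the input
```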

19.5.2 Input-Dependent Parameters

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """
    Mamba's selective state space layer (simplified).

    Key innovation: B, C, and delta depend on input.
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # A is learned but input-independent; one d_state vector per channel.
        # Initialized to -[1, 2, ..., d_state] (via its log), as in Mamba.
        A_init = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_model, 1)
        self.A_log = nn.Parameter(torch.log(A_init))  # [d_model, d_state]

        # B, C, delta are input-dependent
        self.B_proj = nn.Linear(d_model, d_state)
        self.C_proj = nn.Linear(d_model, d_state)
        self.delta_proj = nn.Linear(d_model, d_model)  # one step size per channel

        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        batch, seq_len, d = x.shape

        # Input-dependent parameters
        B = self.B_proj(x)  # [batch, seq, d_state]
        C = self.C_proj(x)  # [batch, seq, d_state]
        delta = F.softplus(self.delta_proj(x))  # [batch, seq, d_model], > 0

        # A = -exp(A_log) < 0 keeps the recurrence stable
        A = -torch.exp(self.A_log)  # [d_model, d_state]

        # Discretize with input-dependent delta
        A_bar = torch.exp(delta[..., None] * A)       # [batch, seq, d_model, d_state]
        B_bar = delta[..., None] * B[:, :, None, :]   # simplified: B_bar ≈ delta * B

        # Selective scan
        return selective_scan(x, A_bar, B_bar, C, self.D)

19.5.3 The Selective Scan

With input-dependent parameters the system is no longer time-invariant, so there is no single kernel \(\bar{K}\) to precompute and FFT convolution no longer applies. Mamba instead uses a custom, hardware-aware parallel scan:

def selective_scan(x, A, B, C, D):
    """
    Reference (sequential) selective scan.
    x: [batch, seq, d_model]; A, B: [batch, seq, d_model, d_state];
    C: [batch, seq, d_state]; D: [d_model].

    Optimized kernels compute the same recurrence with an associative
    parallel scan: O(n) work and O(log n) parallel depth.
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[-1]

    # One d_state-dimensional state per channel
    h = torch.zeros(batch, d_model, d_state, device=x.device, dtype=x.dtype)
    outputs = []

    for t in range(seq_len):
        # State update: h = A[t] * h + B[t] * x[t]
        h = A[:, t] * h + B[:, t] * x[:, t, :, None]

        # Output: y = C[t] · h + D * x[t]
        y = (h * C[:, t, None, :]).sum(-1) + D * x[:, t]
        outputs.append(y)

    return torch.stack(outputs, dim=1)  # [batch, seq, d_model]

Parallel scan: the recurrence \(h_t = a_t h_{t-1} + b_t\) can be parallelized because composing two of its steps is an associative operation:

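import torch

def parallel_scan(a, b):
    """
    Parallel prefix scan for linear recurrence.

    h[t] = a[t] * h[t-1] + b[t], with h[-1] = 0; a, b: [..., seq_len].

    Uses O(log n) parallel steps instead of O(n) sequential.
    (Hillis-Steele form: O(n log n) work; Blelloch's variant does O(n) work.)
    """
    # Combine pairs
    # (a1, b1) ⊕ (a2, b2) = (a1*a2, a2*b1 + b2)
    # This is associative, enabling parallel reduction
    n = a.shape[-1]
    offset = 1
    while offset < n:
        # Combine each position with the partial result `offset` steps earlier;
        # positions with nothing before them combine with the identity (1, 0)
        a_prev = torch.ones_like(a)
        b_prev = torch.zeros_like(b)
        a_prev[..., offset:] = a[..., :-offset]
        b_prev[..., offset:] = b[..., :-offset]
        a, b = a_prev * a, a * b_prev + b
        offset *= 2
    return b  # b[t] == h[t]; production implementations use custom CUDA kernels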
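The two properties the scan relies on can be checked directly on random scalar steps (the helper name `combine` is ours). Grouping doesn’t change the result, and folding every step from \(h = 0\) reproduces the sequential recurrence:

```python
import random
from functools import reduce

def combine(left, right):
    """(a1, b1) ⊕ (a2, b2) = (a1*a2, a2*b1 + b2): apply the left step, then the right."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

random.seed(0)
steps = [(random.uniform(0.5, 1.0), random.uniform(-1.0, 1.0)) for _ in range(8)]

# Associativity: (p ⊕ q) ⊕ r == p ⊕ (q ⊕ r)
p, q, r = steps[:3]
lhs = combine(combine(p, q), r)
rhs = combine(p, combine(q, r))

# Folding all steps equals running h = a * h + b sequentially from h = 0
h = 0.0
for a, b in steps:
    h = a * h + b
_, h_scan = reduce(combine, steps)
print(lhs, rhs, h, h_scan)
```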

19.6 Mamba Architecture

19.6.1 The Full Block

import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    """
    Full Mamba block with gating (SelectiveSSM from Section 19.5.2).
    """
    def __init__(self, d_model, d_state=16, expand=2):
        super().__init__()
        d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, d_inner * 2)

        # Depthwise conv for local context; padding=3 plus truncation
        # in forward() makes it causal
        self.conv1d = nn.Conv1d(
            d_inner, d_inner,
            kernel_size=4,
            groups=d_inner,
            padding=3
        )

        # Selective SSM
        self.ssm = SelectiveSSM(d_inner, d_state)

        # Output projection
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):
        # x: [batch, seq, d_model]
        seq_len = x.shape[1]

        # Split into two paths
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SSM path
        x = x.transpose(1, 2)  # [batch, d_inner, seq]
        x = self.conv1d(x)[:, :, :seq_len]  # drop outputs that would see the future
        x = x.transpose(1, 2)
        x = F.silu(x)
        x = self.ssm(x)

        # Gating
        z = F.silu(z)
        x = x * z

        # Output
        return self.out_proj(x)
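The block pads the depthwise convolution by kernel_size - 1 = 3 and keeps only the first seq_len outputs. This sketch checks that this makes the convolution causal: changing the last input changes only the last output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels, seq_len, k = 3, 10, 4
conv = nn.Conv1d(channels, channels, kernel_size=k, groups=channels, padding=k - 1)

x = torch.randn(1, channels, seq_len)
y = conv(x)[:, :, :seq_len]           # output t sees inputs t-3 .. t only

# Perturb only the final time step
x2 = x.clone()
x2[:, :, -1] += 1.0
y2 = conv(x2)[:, :, :seq_len]

print(torch.allclose(y[:, :, :-1], y2[:, :, :-1]))  # True: earlier outputs unchanged
```

Without the truncation, the extra trailing outputs would mix in zero padding past the end of the sequence. Without the left padding, the first outputs would be lost.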

19.6.2 Mamba vs Transformer Comparison

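Mamba-3B vs Transformer-3B (approximate relative cost, lower is better; indicative figures that depend heavily on hardware, kernels, and batch size):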

                    Mamba       Transformer
──────────────────────────────────────────────
Inference (1K ctx)  1.0x        1.0x
Inference (8K ctx)  1.0x        2.5x (KV cache)
Inference (64K ctx) 1.0x        10x+ (KV cache explosion)

Training (1K)       0.9x        1.0x
Training (8K)       0.7x        1.0x
Training (64K)      0.5x        3.0x (memory bound)

Quality (avg benchmarks): ~95% of equivalent transformer

19.7 RWKV: RNN Reinvented

19.7.1 The RWKV Approach

RWKV (Receptance Weighted Key Value) is another linear-time architecture:

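import torch
import torch.nn as nn

class RWKVBlock(nn.Module):
    """
    RWKV time-mixing block (simplified RWKV-4 form; token shift omitted).

    Linear attention with learned decay: each output is a decayed,
    exp(k)-weighted average of past values.
    """
    def __init__(self, d_model):
        super().__init__()

        # Learnable per-channel decay w = exp(time_decay) > 0,
        # and a bonus u = time_first for the current token
        self.time_decay = nn.Parameter(torch.randn(d_model))
        self.time_first = nn.Parameter(torch.randn(d_model))

        # Projections
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, state=None):
        batch, seq_len, d = x.shape

        # State = (numerator, denominator) of the running weighted average
        if state is None:
            zeros = torch.zeros(batch, d, device=x.device, dtype=x.dtype)
            state = (zeros, zeros.clone())
        num, den = state

        k = self.key(x)
        v = self.value(x)
        r = torch.sigmoid(self.receptance(x))
        decay = torch.exp(-torch.exp(self.time_decay))  # e^{-w}, in (0, 1)

        outputs = []
        for t in range(seq_len):
            # WKV: decayed past plus the current token with bonus u.
            # Real kernels subtract a running max before exp() for stability.
            e_now = torch.exp(self.time_first + k[:, t])
            wkv = (num + e_now * v[:, t]) / (den + e_now)

            # Decay the past, then add the current token with weight e^{k_t}
            e_k = torch.exp(k[:, t])
            num = decay * num + e_k * v[:, t]
            den = decay * den + e_k

            # Gated output
            outputs.append(r[:, t] * wkv)

        return self.output(torch.stack(outputs, dim=1)), (num, den)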

19.7.2 RWKV Advantages

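  1. True RNN inference: constant-size state, so O(1) memory per generated token regardless of context length
  2. Parallel training: with a time-invariant decay, the WKV sums are causal convolutions, so training parallelizes across the sequence
  3. Competitive quality: reported to be on par with similarly sized transformers at scales up to 14B parameters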
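A sketch of the parallel view, under the simplified RWKV-4 recurrence used in the block above. The weight on token i at step t depends only on the lag t - i (\(e^{-(t-1-i)w}\) for i < t, and \(e^{u}\) for the current token), so the numerator and denominator are causal convolutions. For clarity this materializes the full T × T weight tensor rather than using an FFT or a custom kernel, and checks it against the recurrence:

```python
import torch

torch.manual_seed(0)
T, d = 12, 4
k = torch.randn(T, d, dtype=torch.float64)
v = torch.randn(T, d, dtype=torch.float64)
w = torch.exp(torch.randn(d, dtype=torch.float64))   # per-channel decay rate, w > 0
u = torch.randn(d, dtype=torch.float64)              # bonus for the current token

# Recurrent form
num = torch.zeros(d, dtype=torch.float64)
den = torch.zeros(d, dtype=torch.float64)
wkv_rec = []
for t in range(T):
    e_now = torch.exp(u + k[t])
    wkv_rec.append((num + e_now * v[t]) / (den + e_now))
    num = torch.exp(-w) * num + torch.exp(k[t]) * v[t]
    den = torch.exp(-w) * den + torch.exp(k[t])
wkv_rec = torch.stack(wkv_rec)

# Parallel form: weights depend only on the lag t - i
idx = torch.arange(T)
lag = (idx[:, None] - idx[None, :]).to(torch.float64)            # [T, T], entry t - i
past = (lag > 0).to(torch.float64)[..., None]                    # mask for i < t
Wt = past * torch.exp(-(lag[..., None] - 1).clamp(min=0) * w)    # e^{-(t-1-i) w}, [T, T, d]
Wt = Wt + torch.eye(T, dtype=torch.float64)[..., None] * torch.exp(u)   # current token
ek = torch.exp(k)
wkv_par = torch.einsum("tid,id->td", Wt, ek * v) / torch.einsum("tid,id->td", Wt, ek)

print(torch.max(torch.abs(wkv_rec - wkv_par)))   # round-off-sized difference
```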

19.8 Hybrid Architectures

19.8.1 The Best of Both Worlds

Pure SSMs sometimes underperform on tasks requiring precise retrieval. Solution: combine SSM and attention.

import torch.nn as nn

class JambaBlock(nn.Module):
    """
    Jamba-style hybrid: SSM + Attention + MoE.

    Pattern: 7 Mamba blocks, then 1 Attention block.
    MultiHeadAttention, RMSNorm and MoEFFN are placeholders assumed to be
    defined elsewhere; MambaBlock is from Section 19.6.1.
    """
    def __init__(self, d_model, use_attention=False):
        super().__init__()

        if use_attention:
            self.mixer = MultiHeadAttention(d_model)
        else:
            self.mixer = MambaBlock(d_model)

        # Separate pre-norms for the mixer and FFN sublayers
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.ffn = MoEFFN(d_model)  # Can also be dense

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class JambaModel(nn.Module):
    def __init__(self, d_model, num_layers=32):
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Every 8th layer is attention
            use_attention = (i % 8 == 7)
            self.layers.append(JambaBlock(d_model, use_attention))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
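The resulting layer schedule can be listed without building the model. In this sketch (the helper `hybrid_layout` is ours), `attn_every=8` reproduces the 1-in-8 attention pattern of `JambaModel`. Placing MoE on every other layer instead of every layer (`moe_every=2`) is an illustrative assumption, showing how the FFN choice can be scheduled independently of the mixer:

```python
def hybrid_layout(num_layers=32, attn_every=8, moe_every=2):
    """Return a (mixer, ffn) pair for each layer of a Jamba-style stack."""
    layout = []
    for i in range(num_layers):
        mixer = "attention" if i % attn_every == attn_every - 1 else "mamba"
        ffn = "moe" if i % moe_every == moe_every - 1 else "dense"
        layout.append((mixer, ffn))
    return layout

layout = hybrid_layout()
print([i for i, (mixer, _) in enumerate(layout) if mixer == "attention"])  # [7, 15, 23, 31]
```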

19.8.2 StripedHyena

Another hybrid approach: interleaved SSM and attention at finer granularity.

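class StripedHyenaBlock(nn.Module):
    """
    Alternating Hyena (SSM-like long convolution) and Attention.
    HyenaOperator and MultiHeadAttention are placeholders assumed defined elsewhere.
    """
    def __init__(self, d_model, layer_idx):
        super().__init__()

        # Alternate between Hyena and Attention
        if layer_idx % 2 == 0:
            self.mixer = HyenaOperator(d_model)
        else:
            self.mixer = MultiHeadAttention(d_model)

    def forward(self, x):
        # Residual connection around whichever mixer this layer uses
        return x + self.mixer(x)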

19.9 When to Use SSMs

19.9.1 SSM Advantages

  1. O(n) complexity: Linear scaling with sequence length
  2. O(1) inference memory: Fixed-size state per layer, independent of context length; no KV cache
  3. Long sequences: Naturally handles 100K+ tokens
  4. Continuous signals: Often strong on time series, audio, and video

19.9.2 SSM Limitations

  1. Retrieval tasks: Attention excels at “needle in haystack”
  2. Copying/recall: SSMs struggle with exact reproduction
  3. Maturity: Less tooling, fewer optimized kernels
  4. Quality ceiling: May not match best transformers (yet)

19.9.3 Decision Framework

Use SSM when:
  ✓ Very long sequences (>64K tokens)
  ✓ Streaming/real-time inference
  ✓ Memory-constrained deployment
  ✓ Continuous signals (audio, sensors)
  ✓ Latency matters more than peak quality

Use Transformer when:
  ✓ Retrieval-heavy tasks (RAG, QA)
  ✓ Maximum quality required
  ✓ Moderate sequence lengths (<32K)
  ✓ Existing infrastructure/tooling

Use Hybrid when:
  ✓ Long context + some retrieval needed
  ✓ Want SSM efficiency with attention accuracy
  ✓ Production systems where both matter

19.10 Implementation Tips

19.10.1 Efficient Mamba Kernels

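# Use the official Mamba package; its fused kernels require a CUDA GPU
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 2048
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba(
    d_model=dim,  # model dimension
    d_state=16,   # SSM state size
    d_conv=4,     # local convolution width
    expand=2      # block expansion factor
).to("cuda")
y = model(x)      # same shape as x: [batch, length, dim]

# Full language model (embedding, stacked Mamba blocks, LM head)
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel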

19.10.2 State Caching for Inference

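import torch

class MambaCache:
    """
    Cache per-layer recurrent state for efficient generation.

    Unlike a KV cache, each entry has a fixed size that does not grow with
    the number of tokens. Assumes each layer exposes d_inner, d_state and
    d_conv attributes (hypothetical names for this sketch).
    """
    def __init__(self, model, batch_size):
        self.states = []
        for layer in model.layers:
            # Sliding window of recent inputs for the causal depthwise conv
            conv_state = torch.zeros(batch_size, layer.d_inner, layer.d_conv)
            # SSM hidden state: one d_state vector per inner channel
            ssm_state = torch.zeros(batch_size, layer.d_inner, layer.d_state)
            self.states.append((conv_state, ssm_state))

    def update(self, layer_idx, new_state):
        self.states[layer_idx] = new_state

    def get(self, layer_idx):
        return self.states[layer_idx]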
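A back-of-the-envelope comparison of per-layer cache sizes at batch size 1. It assumes fp16 storage, standard multi-head attention with no grouped-query sharing, and the defaults of the mamba_ssm example above (d_state=16, d_conv=4, expand=2). The helper names are ours.

```python
def kv_cache_bytes(seq_len, d_model, bytes_per_elem=2):
    """Per-layer KV cache: one key and one value vector of width d_model per token."""
    return 2 * seq_len * d_model * bytes_per_elem

def mamba_state_bytes(d_model, d_state=16, d_conv=4, expand=2, bytes_per_elem=2):
    """Per-layer Mamba state: SSM state plus conv window; independent of seq_len."""
    d_inner = expand * d_model
    return d_inner * (d_state + d_conv) * bytes_per_elem

for n in [1_024, 65_536, 1_048_576]:
    print(f"{n:>9} tokens: KV cache {kv_cache_bytes(n, 2048) / 2**20:8.1f} MiB, "
          f"Mamba state {mamba_state_bytes(2048) / 2**10:.0f} KiB")
```

With d_model = 2048, the KV cache grows from 8 MiB at 1K tokens to 8 GiB at 1M tokens per layer, while the Mamba state stays at 160 KiB.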

19.11 Property Audit

Property        Applies?     How it’s exploited
──────────────────────────────────────────────────────────────────────────
Associativity   Primary      The linear recurrence \((A_t, b_t) \oplus (A_{t+1}, b_{t+1}) = (A_{t+1} A_t, A_{t+1} b_t + b_{t+1})\) is associative, enabling parallel scan
Separability    Secondary    The SSM kernel (convolution view) factors time and state dimensions; low state dimension d_state is a separability choice
Locality        Yes          Chunked scan keeps working state in SRAM
Sparsity        No           Dense state transitions
Redundancy      Partial      State caching during generation (analogous to KV cache)
Symmetry        No           Mamba’s selective mechanism breaks time-invariance

Dominant property: Associativity. The parallel scan turns O(N) sequential steps into O(log N) parallel steps, which is the fundamental algorithmic advance.

19.12 Key Takeaways

  1. SSMs are O(n): Linear complexity via state-space formulation.

  2. Training vs inference: Parallel training (convolution for S4-style SSMs, parallel scan for Mamba), recurrence for efficient inference.

  3. Mamba = selective SSM: Input-dependent parameters enable content-aware processing.

  4. Hybrids may be optimal: Combining SSM efficiency with attention precision.

  5. Different strengths: SSMs for length/efficiency, attention for retrieval/quality.

  6. Rapidly evolving: Mamba-2, RWKV-6, and new architectures emerging regularly.

Note: Try It Yourself

Once it is available, the accompanying notebook will let you:

  • Implement a basic SSM from scratch
  • Compare SSM vs attention complexity empirically
  • Experiment with Mamba on simple tasks
  • Build a hybrid architecture

Notebook support for this chapter is in progress. For now, run the examples locally and compare SSM vs attention behavior on your hardware.

19.13 Further Reading

  • Gu et al. (2022). “Efficiently Modeling Long Sequences with Structured State Spaces” (S4)
  • Gu & Dao (2023). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”
  • Peng et al. (2023). “RWKV: Reinventing RNNs for the Transformer Era”
  • Lieber et al. (2024). “Jamba: A Hybrid Transformer-Mamba Language Model”
  • Dao & Gu (2024). “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality” (Mamba-2)

Important: Part III Challenge Problems

These problems have no provided solutions. They require applying the framework to new algorithms.

Challenge 1: Derive Ring Attention. Given multiple GPUs, each holding a shard of the KV cache, the full attention matrix doesn’t fit on any single GPU. Using the property audit approach, identify which properties enable distributed attention computation. Design an algorithm (hint: it involves a ring topology for KV passing). What’s the communication cost?

Challenge 2: Speculative Decoding as Property Exploitation. Speculative decoding uses a draft model to generate K tokens, then verifies them with the target model. Frame this optimization in terms of the six properties. Which properties explain why it works? Why does the acceptance rate determine the speedup?

Challenge 3: Beyond Attention. Graph neural networks (GNNs) aggregate neighbor features: \(h_i' = \text{AGG}(\{h_j : j \in \mathcal{N}(i)\})\). Which properties does this operation have? How would you optimize it for power-law degree distributions (where a few nodes have millions of neighbors)?