19  State-Space Models

Linear-Time Alternatives to Attention


Attention is O(n²). For a million tokens, that’s a trillion pairwise interactions per layer.

State-space models achieve O(n) complexity. What’s the catch?

19.1 The Attention Problem

Transformers dominate language modeling, but attention has fundamental scaling issues:

Attention complexity per layer:

Sequence    Memory      FLOPs
───────────────────────────────
1K          4 MB        2 billion
4K          64 MB       32 billion
16K         1 GB        512 billion
64K         16 GB       8 trillion
256K        256 GB      128 trillion

Even with FlashAttention, which avoids materializing the full n × n score matrix and brings memory back to O(n), compute is still O(n²).
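A quick back-of-the-envelope script reproduces the table above, under the assumptions that the hidden size is d = 1024, scores are stored in fp32, and we count the QKᵀ score matmul (the AV matmul adds a similar amount). The memory column matches exactly; the FLOPs column matches up to the table’s rounding.

def attention_cost(n, d=1024, bytes_per_score=4):
    """Per-layer attention cost for a sequence of n tokens."""
    memory = n * n * bytes_per_score   # materialized n x n score matrix
    flops = 2 * n * n * d              # QK^T matmul; AV roughly doubles this
    return memory, flops

for n in [1024, 4096, 16384, 65536, 262144]:
    memory, flops = attention_cost(n)
    print(f"{n:>7} tokens: {memory / 2**20:10,.0f} MiB  {flops / 1e9:14,.0f} GFLOPs")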

State-space models (SSMs) promise O(n) complexity with competitive quality.

19.2 The RNN Connection

RNNs are naturally O(n)—each token updates a fixed-size hidden state:

import torch

def rnn_forward(x, h0, W_h, W_x):
    """
    Simple RNN: O(n) in sequence length.

    x: [seq_len, d_in], h0: [d], W_h: [d, d], W_x: [d, d_in]
    """
    h = h0
    outputs = []

    for t in range(len(x)):
        h = torch.tanh(W_h @ h + W_x @ x[t])  # O(d²) per step
        outputs.append(h)

    return outputs  # Total: O(n · d²)

Problem: RNNs are hard to train (vanishing gradients) and slow (sequential).

Insight: What if we could have RNN-like O(n) inference with transformer-like parallel training?

19.3 State-Space Models Fundamentals

19.3.1 The Continuous-Time View

SSMs come from control theory. A linear time-invariant system:

\[ \begin{aligned} h'(t) &= Ah(t) + Bx(t) \\ y(t) &= Ch(t) + Dx(t) \end{aligned} \]

Where:

  • \(x(t)\) is the input signal
  • \(h(t)\) is the hidden state
  • \(y(t)\) is the output
  • \(A, B, C, D\) are learnable matrices

19.3.2 Discretization

For discrete sequences, we discretize with step size \(\Delta\):

\[ \begin{aligned} \bar{A} &= \exp(\Delta A) \\ \bar{B} &= (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B \\ h_t &= \bar{A} h_{t-1} + \bar{B} x_t \\ y_t &= C h_t + D x_t \end{aligned} \]

This is now a recurrence—can be computed sequentially like an RNN.
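As a concrete sketch, and as one possible implementation of the discretize_A / discretize_B helpers used in the next subsection, the zero-order-hold formulas translate directly into PyTorch, assuming a dense \(A\) and a scalar step \(\Delta\):

import torch

def discretize_A(A, delta):
    """A_bar = exp(delta * A) for a dense [d_state, d_state] matrix A."""
    return torch.matrix_exp(delta * A)

def discretize_B(A, B, delta):
    """B_bar = (delta A)^{-1} (exp(delta A) - I) · delta B, with B of shape [d_state]."""
    dA = delta * A
    I = torch.eye(A.shape[0], dtype=A.dtype)
    return torch.linalg.solve(dA, torch.matrix_exp(dA) - I) @ (delta * B)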

19.3.3 The Parallel Training Trick

The recurrence can be unrolled as a convolution:

\[ y = x * \bar{K} \]

where \(\bar{K}\) is the SSM kernel:

\[ \bar{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \ldots, C\bar{A}^{n-1}\bar{B}) \]

Key insight: During training, compute via FFT convolution. During inference, compute via recurrence.

def ssm_forward(x, A, B, C, D, delta, training=True):
    """
    SSM forward pass for a single input channel (x: [batch, seq_len]).

    Training: use convolution (parallel).
    Inference: use recurrence (sequential, but O(1) work per step).
    """
    # Discretize
    A_bar = discretize_A(A, delta)       # [d_state, d_state]
    B_bar = discretize_B(A, B, delta)    # [d_state]

    if training:
        # Compute kernel, then convolve (parallel via FFT); D is the skip term
        K = compute_kernel(A_bar, B_bar, C, length=x.shape[-1])
        y = fft_conv(x, K) + D * x
    else:
        # Recurrence (sequential)
        batch, d_state = x.shape[0], A_bar.shape[0]
        h = torch.zeros(batch, d_state)
        y = []
        for t in range(x.shape[-1]):
            h = h @ A_bar.T + x[:, t, None] * B_bar
            y.append(h @ C + D * x[:, t])
        y = torch.stack(y, dim=-1)

    return y

def compute_kernel(A_bar, B_bar, C, length):
    """Compute SSM convolution kernel."""
    K = []
    A_power = torch.eye(A_bar.shape[0])

    for _ in range(length):
        K.append(C @ A_power @ B_bar)
        A_power = A_power @ A_bar

    return torch.stack(K)
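The fft_conv helper above is not spelled out; a minimal single-channel version (an assumption of this sketch; real implementations batch this over channels) zero-pads so the circular FFT convolution becomes an ordinary causal one:

def fft_conv(x, K):
    """
    Causal convolution y[t] = sum_k K[k] * x[t - k], computed via FFT.

    x: [batch, seq_len] single-channel input; K: [seq_len] SSM kernel.
    """
    seq_len = x.shape[-1]
    n = 2 * seq_len  # pad to avoid circular wrap-around
    y = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(K, n=n), n=n)
    return y[..., :seq_len]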

19.4 S4: Structured State Spaces

19.4.1 The HiPPO Matrix

S4 (Structured State Space) uses a special initialization for \(A\) called HiPPO:

\[ A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]

Why HiPPO? It optimally compresses history into fixed-size state through Legendre polynomial projections.

def make_hippo_matrix(N):
    """
    Create HiPPO matrix for optimal history compression.
    """
    A = torch.zeros(N, N)

    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -((2*n + 1) ** 0.5) * ((2*k + 1) ** 0.5)
            elif n == k:
                A[n, k] = -(n + 1)

    return A

19.4.2 Diagonal Approximation

A full (dense) matrix \(A\) is expensive: each recurrence step costs O(d²) in the state size d. S4 diagonalizes it,

\[ A = V \Lambda V^{-1} \]

(in practice via a diagonal-plus-low-rank parameterization of the HiPPO matrix; later variants such as S4D use a purely diagonal \(A\)). The recurrence then becomes element-wise:

\[ h_t = \Lambda \odot h_{t-1} + \tilde{B} x_t \]

Complexity: O(nd) instead of O(nd²) per layer.
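A minimal sketch of one step of this diagonal recurrence (real-valued for simplicity; the names and shapes are illustrative, and in practice the eigenvalues are complex):

def diagonal_ssm_step(h, x_t, Lambda_bar, B_bar, C):
    """
    h:          [batch, d_state] state
    x_t:        [batch] scalar input at step t
    Lambda_bar: [d_state] discretized diagonal of A
    B_bar, C:   [d_state]
    """
    h = Lambda_bar * h + B_bar * x_t[:, None]  # element-wise: O(d_state), not O(d_state²)
    y_t = (C * h).sum(-1)
    return h, y_t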

19.5 Mamba: Selective State Spaces

19.5.1 The Selectivity Problem

Standard SSMs have fixed \(A, B, C\)—they can’t adapt to input content.

Attention advantage: Query-key matching is content-dependent.

Mamba insight: Make SSM parameters input-dependent.

19.5.2 Input-Dependent Parameters

class SelectiveSSM(nn.Module):
    """
    Mamba's selective state space layer.

    Key innovation: B, C, and delta depend on input.
    """
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # A is still fixed (for stability)
        self.A_log = nn.Parameter(torch.randn(d_state))

        # B, C, delta are input-dependent
        self.B_proj = nn.Linear(d_model, d_state)
        self.C_proj = nn.Linear(d_model, d_state)
        self.delta_proj = nn.Linear(d_model, 1)  # simplified: real Mamba uses a per-channel delta

        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        batch, seq_len, d = x.shape

        # Input-dependent parameters
        B = self.B_proj(x)  # [batch, seq, d_state]
        C = self.C_proj(x)  # [batch, seq, d_state]
        delta = F.softplus(self.delta_proj(x))  # [batch, seq, 1]

        # Fixed A (log-parameterized for stability)
        A = -torch.exp(self.A_log)  # [d_state]

        # Discretize with input-dependent delta
        A_bar = torch.exp(delta * A)  # [batch, seq, d_state]
        B_bar = delta * B  # Simplified discretization

        # Selective scan (parallel algorithm)
        y = selective_scan(x, A_bar, B_bar, C, self.D)

        return y

19.5.3 The Selective Scan

With input-dependent parameters, we can’t use FFT convolution. Mamba uses a custom parallel scan:

def selective_scan(x, A, B, C, D):
    """
    Selective scan: h_t = A_t ⊙ h_{t-1} + B_t x_t,  y_t = C_t · h_t + D x_t.

    Written sequentially for clarity; the real implementation uses an
    associative parallel scan for O(n) work with O(log n) depth.
    """
    batch, seq_len, d_model = x.shape
    d_state = A.shape[-1]

    # Each of the d_model channels keeps its own d_state-dimensional state
    h = torch.zeros(batch, d_model, d_state, device=x.device, dtype=x.dtype)
    outputs = []

    for t in range(seq_len):
        # State update: h = A[t] * h + B[t] * x[t]  (broadcast over channels)
        h = A[:, t, None, :] * h + B[:, t, None, :] * x[:, t, :, None]

        # Output: y = C[t] · h + D * x[t]
        y = (C[:, t, None, :] * h).sum(-1) + D * x[:, t]
        outputs.append(y)

    return torch.stack(outputs, dim=1)
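With SelectiveSSM and selective_scan defined, a quick shape check (illustrative sizes; torch, nn, and F imported as usual):

layer = SelectiveSSM(d_model=64, d_state=16)
x = torch.randn(2, 128, 64)   # [batch, seq, d_model]
y = layer(x)
print(y.shape)                # torch.Size([2, 128, 64])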

Parallel scan: The recurrence \(h_t = a_t h_{t-1} + b_t\) is associative and can be parallelized:

def parallel_scan(a, b):
    """
    Parallel prefix scan for the linear recurrence h[t] = a[t] * h[t-1] + b[t].

    Combine rule: (a1, b1) ⊕ (a2, b2) = (a1 * a2, a2 * b1 + b2).
    The rule is associative, so the scan needs only O(log n) parallel steps
    instead of n sequential ones. Mamba uses custom CUDA kernels; this is a
    pure-PyTorch Hillis-Steele version over the last dim, assuming h[-1] = 0.
    """
    seq_len = a.shape[-1]
    offset = 1
    while offset < seq_len:
        a_prev, b_prev = a[..., :-offset], b[..., :-offset]
        # New (a, b) at position t combines position t - offset with position t
        b = torch.cat([b[..., :offset], a[..., offset:] * b_prev + b[..., offset:]], dim=-1)
        a = torch.cat([a[..., :offset], a[..., offset:] * a_prev], dim=-1)
        offset *= 2
    return b  # b[t] now holds h[t]

19.6 Mamba Architecture

19.6.1 The Full Block

class MambaBlock(nn.Module):
    """
    Full Mamba block with gating.
    """
    def __init__(self, d_model, d_state=16, expand=2):
        super().__init__()
        d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, d_inner * 2)

        # Conv for local context
        self.conv1d = nn.Conv1d(
            d_inner, d_inner,
            kernel_size=4,
            groups=d_inner,
            padding=3
        )

        # Selective SSM
        self.ssm = SelectiveSSM(d_inner, d_state)

        # Output projection
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):
        # x: [batch, seq, d_model]

        # Split into two paths
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Conv + SSM path
        x = x.transpose(1, 2)  # [batch, d_inner, seq]
        x = self.conv1d(x)[:, :, :x.shape[-1]]  # causal depthwise conv: keep only the first seq_len outputs
        x = x.transpose(1, 2)
        x = F.silu(x)
        x = self.ssm(x)

        # Gating
        z = F.silu(z)
        x = x * z

        # Output
        return self.out_proj(x)

19.6.2 Mamba vs Transformer Comparison

Mamba-3B vs Transformer-3B (illustrative relative cost per token; lower is better):

                    Mamba       Transformer
──────────────────────────────────────────────
Inference (1K ctx)  1.0x        1.0x
Inference (8K ctx)  1.0x        2.5x (KV cache)
Inference (64K ctx) 1.0x        10x+ (KV cache explosion)

Training (1K)       0.9x        1.0x
Training (8K)       0.7x        1.0x
Training (64K)      0.5x        3.0x (memory bound)

Quality (avg benchmarks): ~95% of equivalent transformer

19.7 RWKV: RNN Reinvented

19.7.1 The RWKV Approach

RWKV (Receptance Weighted Key Value) is another linear-time architecture:

class RWKVBlock(nn.Module):
    """
    RWKV time-mixing block.

    Linear attention with learned decay.
    """
    def __init__(self, d_model):
        super().__init__()

        # Learnable decay
        self.time_decay = nn.Parameter(torch.randn(d_model))
        self.time_first = nn.Parameter(torch.randn(d_model))

        # Projections
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.receptance = nn.Linear(d_model, d_model)
        self.output = nn.Linear(d_model, d_model)

    def forward(self, x, state=None):
        batch, seq_len, d = x.shape

        # Recurrent state carried across tokens (the real RWKV also applies
        # token-shift mixing of x_t with x_{t-1}, omitted here)
        if state is None:
            state = torch.zeros(batch, d, device=x.device)

        k = self.key(x)
        v = self.value(x)
        r = torch.sigmoid(self.receptance(x))

        outputs = []
        for t in range(seq_len):
            # WKV computation with decay
            wkv = self.time_first * k[:, t] * v[:, t] + state
            state = torch.exp(-torch.exp(self.time_decay)) * state + k[:, t] * v[:, t]

            # Gated output
            out = r[:, t] * wkv
            outputs.append(out)

        return self.output(torch.stack(outputs, dim=1)), state

19.7.2 RWKV Advantages

  1. True RNN inference: O(1) state memory per token (see the decoding sketch after this list)
  2. Parallel training: the decay-weighted sum can be computed in parallel across the sequence, convolution-style
  3. Competitive quality: matches similarly sized transformers up to the 14B scale
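A minimal decoding sketch with the RWKVBlock above makes the first point concrete: the carried state is a single [batch, d_model] tensor no matter how long generation runs. Random embeddings stand in for a real tokenizer and sampler here.

import torch

block = RWKVBlock(d_model=64)
state = None
x_t = torch.randn(1, 1, 64)        # embedding of the current token (illustrative)
for _ in range(10):                # generate 10 steps
    y, state = block(x_t, state)   # state stays [1, 64]; it never grows
    x_t = y                        # stand-in for embedding the sampled next token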

19.8 Hybrid Architectures

19.8.1 The Best of Both Worlds

Pure SSMs sometimes underperform on tasks requiring precise retrieval. Solution: combine SSM and attention.

class JambaBlock(nn.Module):
    """
    Jamba-style hybrid: SSM + Attention + MoE.

    Pattern: 7 Mamba blocks, then 1 Attention block.
    """
    def __init__(self, d_model, use_attention=False):
        super().__init__()

        if use_attention:
            self.mixer = MultiHeadAttention(d_model)
        else:
            self.mixer = MambaBlock(d_model)

        # Separate pre-norms for the mixer and FFN sub-layers
        self.mixer_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = MoEFFN(d_model)  # can also be a dense FFN

    def forward(self, x):
        x = x + self.mixer(self.mixer_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x

class JambaModel(nn.Module):
    def __init__(self, d_model, num_layers=32):
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Every 8th layer is attention
            use_attention = (i % 8 == 7)
            self.layers.append(JambaBlock(d_model, use_attention))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

19.8.2 StripedHyena

Another hybrid approach: interleaved SSM and attention at finer granularity.

class StripedHyenaBlock(nn.Module):
    """
    Alternating Hyena (SSM-like) and Attention.
    """
    def __init__(self, d_model, layer_idx):
        super().__init__()

        # Alternate between Hyena and Attention
        if layer_idx % 2 == 0:
            self.mixer = HyenaOperator(d_model)
        else:
            self.mixer = MultiHeadAttention(d_model)

19.9 When to Use SSMs

19.9.1 SSM Advantages

  1. O(n) complexity: Linear scaling with sequence length
  2. O(1) inference memory: Fixed state size, no growing KV cache (see the comparison after this list)
  3. Long sequences: Naturally handles 100K+ tokens
  4. Continuous signals: Better suited to time series, audio, and video
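To make the second point concrete, a back-of-the-envelope comparison under assumed (hypothetical) model sizes: a 32-layer transformer with 8 KV heads of dimension 128 versus a 32-layer SSM with d_inner = 8192 and d_state = 16, both in fp16.

def kv_cache_bytes(seq_len, layers=32, kv_heads=8, head_dim=128, bytes_per=2):
    return 2 * layers * seq_len * kv_heads * head_dim * bytes_per  # K and V per token

def ssm_state_bytes(layers=32, d_inner=8192, d_state=16, bytes_per=2):
    return layers * d_inner * d_state * bytes_per                  # fixed, length-independent

print(f"KV cache at 128K tokens:  {kv_cache_bytes(128_000) / 2**30:.1f} GiB")
print(f"SSM state at any length:  {ssm_state_bytes() / 2**20:.1f} MiB")

With these numbers the KV cache runs to roughly 16 GiB, while the recurrent state stays around 8 MiB regardless of context length.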

19.9.2 SSM Limitations

  1. Retrieval tasks: Attention excels at “needle in haystack”
  2. Copying/recall: SSMs struggle with exact reproduction
  3. Maturity: Less tooling, fewer optimized kernels
  4. Quality ceiling: May not match best transformers (yet)

19.9.3 Decision Framework

Use SSM when:
  ✓ Very long sequences (>64K tokens)
  ✓ Streaming/real-time inference
  ✓ Memory-constrained deployment
  ✓ Continuous signals (audio, sensors)
  ✓ Latency matters more than peak quality

Use Transformer when:
  ✓ Retrieval-heavy tasks (RAG, QA)
  ✓ Maximum quality required
  ✓ Moderate sequence lengths (<32K)
  ✓ Existing infrastructure/tooling

Use Hybrid when:
  ✓ Long context + some retrieval needed
  ✓ Want SSM efficiency with attention accuracy
  ✓ Production systems where both matter

19.10 Implementation Tips

19.10.1 Efficient Mamba Kernels

# Use the official Mamba CUDA kernels
from mamba_ssm import Mamba

model = Mamba(
    d_model=2048,
    d_state=16,
    d_conv=4,
    expand=2
)

# Or use the full language-model wrapper (backbone + LM head)
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

19.10.2 State Caching for Inference

class MambaCache:
    """
    Cache the recurrent SSM state for efficient generation.

    Simplified sketch: the real mamba_ssm inference cache also stores a small
    causal-conv state per layer, and the SSM state is kept per channel.
    """
    def __init__(self, model, batch_size):
        self.states = []
        for layer in model.layers:
            # Fixed-size state per layer, independent of how much has been generated
            state = torch.zeros(batch_size, layer.d_state)
            self.states.append(state)

    def update(self, layer_idx, new_state):
        self.states[layer_idx] = new_state

    def get(self, layer_idx):
        return self.states[layer_idx]

19.11 Key Takeaways

  1. SSMs are O(n): Linear complexity via state-space formulation.

  2. Training vs inference: Convolution for parallel training, recurrence for efficient inference.

  3. Mamba = selective SSM: Input-dependent parameters enable content-aware processing.

  4. Hybrids may be optimal: Combining SSM efficiency with attention precision.

  5. Different strengths: SSMs for length/efficiency, attention for retrieval/quality.

  6. Rapidly evolving: Mamba-2, RWKV-6, and new architectures emerging regularly.

Note: Try It Yourself

The accompanying notebook lets you:

  • Implement a basic SSM from scratch
  • Compare SSM vs attention complexity empirically
  • Experiment with Mamba on simple tasks
  • Build a hybrid architecture


19.12 Further Reading

  • Gu et al. (2022). “Efficiently Modeling Long Sequences with Structured State Spaces” (S4)
  • Gu & Dao (2023). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”
  • Peng et al. (2023). “RWKV: Reinventing RNNs for the Transformer Era”
  • Lieber et al. (2024). “Jamba: A Hybrid Transformer-Mamba Language Model”
  • Dao & Gu (2024). “Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality” (Mamba-2)