
Section 5.8: Implementation — Building Attention from Scratch

Reading time: 25 minutes | Difficulty: ★★★★☆

This section brings together everything we've learned into a complete, working implementation of multi-head causal self-attention. We'll build each component from first principles.

Complete Architecture Overview

Input Embeddings + Positional Encoding
                  │
                  ▼
      ┌───────────────────┐
      │   Multi-Head      │
      │   Self-Attention  │◄─── Causal Mask
      └───────────────────┘
                  │  Layer Norm + Residual
                  ▼
      ┌───────────────────┐
      │   Feed-Forward    │
      │   Network (FFN)   │
      └───────────────────┘
                  │  Layer Norm + Residual
                  ▼
               Output

We'll implement each component, then assemble them into a complete causal Transformer language model.

Component 1: Scaled Dot-Product Attention

The core attention mechanism:

import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.

    Args:
        Q: Queries [..., seq_len, d_k]
        K: Keys [..., seq_len, d_k]
        V: Values [..., seq_len, d_v]
        mask: Optional mask [..., seq_len, seq_len]
              0 for allowed, -inf for masked

    Returns:
        output: Attention output [..., seq_len, d_v]
        weights: Attention weights [..., seq_len, seq_len]
    """
    d_k = Q.shape[-1]

    # Compute attention scores
    # Q: [..., n, d_k], K.T: [..., d_k, n] -> [..., n, n]
    scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(d_k)

    # Apply mask if provided
    if mask is not None:
        scores = scores + mask

    # Softmax over keys (last dimension)
    weights = softmax(scores, axis=-1)

    # Weighted sum of values
    output = weights @ V

    return output, weights


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    # Handle -inf from masking
    x_max = np.max(np.where(np.isinf(x), -1e9, x), axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    exp_x = np.where(np.isinf(x), 0, exp_x)
    return exp_x / (np.sum(exp_x, axis=axis, keepdims=True) + 1e-10)
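
A quick sanity check (random inputs; shapes chosen only for illustration) confirms that the output carries the value dimension and that each row of the attention weights is a probability distribution over keys:

rng = np.random.default_rng(0)
Q = rng.normal(size=(2, 5, 8))    # [batch, seq_len, d_k]
K = rng.normal(size=(2, 5, 8))
V = rng.normal(size=(2, 5, 16))   # [batch, seq_len, d_v]

out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)                             # (2, 5, 16)
print(np.allclose(weights.sum(axis=-1), 1))  # True: each row sums to 1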

Component 2: Causal Mask

def create_causal_mask(seq_len):
    """
    Create causal (autoregressive) attention mask.

    Returns:
        mask: [seq_len, seq_len] with 0 for allowed, -inf for masked
    """
    # Positions above the diagonal (future tokens) get -inf, the rest 0.
    # Built with np.where: multiplying a 0/1 matrix by -inf would turn the
    # zeros into NaN (0 * -inf) and break the softmax.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    mask = np.where(future, float('-inf'), 0.0)
    return mask
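
Printing a small mask makes the pattern concrete: row i (the query position) may attend only to columns 0..i.

print(create_causal_mask(4))
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]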

Component 3: Multi-Head Attention

class MultiHeadAttention:
    """
    Multi-head attention mechanism.

    Splits input into multiple heads, applies attention independently,
    then concatenates and projects.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        """
        Initialize multi-head attention.

        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
            dropout: Dropout probability (for training)
        """
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout = dropout

        # Initialize projection matrices with Xavier/Glorot initialization,
        # scaled by each matrix's fan-in + fan-out
        qkv_scale = np.sqrt(2.0 / (d_model + 3 * d_model))
        out_scale = np.sqrt(2.0 / (d_model + d_model))

        # Combined QKV projection for efficiency
        self.W_qkv = np.random.randn(d_model, 3 * d_model) * qkv_scale

        # Output projection
        self.W_o = np.random.randn(d_model, d_model) * out_scale

        # For storing attention weights (useful for visualization)
        self.attention_weights = None

    def forward(self, x, mask=None):
        """
        Forward pass.

        Args:
            x: Input tensor [batch_size, seq_len, d_model] or [seq_len, d_model]
            mask: Optional attention mask

        Returns:
            output: [batch_size, seq_len, d_model] or [seq_len, d_model]
        """
        # Handle both batched and unbatched inputs
        if x.ndim == 2:
            x = x[np.newaxis, ...]  # Add batch dimension
            squeeze_batch = True
        else:
            squeeze_batch = False

        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        qkv = x @ self.W_qkv  # [batch, seq_len, 3*d_model]

        # Split into Q, K, V
        Q, K, V = np.split(qkv, 3, axis=-1)  # Each [batch, seq_len, d_model]

        # Reshape for multi-head: [batch, seq_len, n_heads, d_k]
        Q = Q.reshape(batch_size, seq_len, self.n_heads, self.d_k)
        K = K.reshape(batch_size, seq_len, self.n_heads, self.d_k)
        V = V.reshape(batch_size, seq_len, self.n_heads, self.d_k)

        # Transpose to [batch, n_heads, seq_len, d_k]
        Q = np.transpose(Q, (0, 2, 1, 3))
        K = np.transpose(K, (0, 2, 1, 3))
        V = np.transpose(V, (0, 2, 1, 3))

        # Apply scaled dot-product attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Store for visualization
        self.attention_weights = attn_weights

        # Transpose back: [batch, seq_len, n_heads, d_k]
        attn_output = np.transpose(attn_output, (0, 2, 1, 3))

        # Concatenate heads: [batch, seq_len, d_model]
        attn_output = attn_output.reshape(batch_size, seq_len, self.d_model)

        # Output projection
        output = attn_output @ self.W_o

        if squeeze_batch:
            output = output[0]

        return output

    def parameters(self):
        """Return all parameters."""
        return [self.W_qkv, self.W_o]
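
A small check (illustrative sizes, not part of the model built later) that the layer preserves shape and that, with the causal mask, a change to a later token cannot affect earlier outputs:

rng = np.random.default_rng(0)
mha = MultiHeadAttention(d_model=32, n_heads=4)
x = rng.normal(size=(6, 32))              # [seq_len, d_model]
mask = create_causal_mask(6)

out = mha.forward(x, mask)
print(out.shape)                          # (6, 32)

x_perturbed = x.copy()
x_perturbed[-1] += 1.0                    # change only the last token
out_perturbed = mha.forward(x_perturbed, mask)
print(np.allclose(out[:-1], out_perturbed[:-1]))  # True: earlier positions unchanged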

Component 4: Positional Encoding

class SinusoidalPositionalEncoding:
    """
    Sinusoidal positional encoding from the original Transformer.
    """

    def __init__(self, max_len, d_model):
        """
        Initialize positional encoding.

        Args:
            max_len: Maximum sequence length
            d_model: Model dimension
        """
        self.d_model = d_model

        # Create encoding matrix
        self.encoding = self._create_encoding(max_len, d_model)

    def _create_encoding(self, max_len, d_model):
        """Generate sinusoidal positional encodings."""
        pe = np.zeros((max_len, d_model))

        position = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))

        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)

        return pe

    def forward(self, seq_len):
        """Get positional encoding for sequence length."""
        return self.encoding[:seq_len]


class LearnedPositionalEncoding:
    """
    Learned positional embeddings.
    """

    def __init__(self, max_len, d_model):
        """
        Initialize learned positional encoding.

        Args:
            max_len: Maximum sequence length
            d_model: Model dimension
        """
        self.max_len = max_len
        self.d_model = d_model

        # Learnable embeddings
        self.encoding = np.random.randn(max_len, d_model) * 0.02

    def forward(self, seq_len):
        """Get positional encoding for sequence length."""
        return self.encoding[:seq_len]

    def parameters(self):
        """Return parameters for learning."""
        return [self.encoding]
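
Both variants return a [seq_len, d_model] slice. As a quick check (sizes chosen only for illustration), the sinusoidal encoding is bounded in [-1, 1] by construction:

pe = SinusoidalPositionalEncoding(max_len=128, d_model=64)
enc = pe.forward(10)
print(enc.shape)                            # (10, 64)
print(enc.min() >= -1.0, enc.max() <= 1.0)  # True True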

Component 5: Layer Normalization

class LayerNorm:
    """
    Layer normalization.

    Normalizes across the feature dimension, then applies learnable
    scale (gamma) and shift (beta).
    """

    def __init__(self, d_model, eps=1e-6):
        """
        Initialize layer normalization.

        Args:
            d_model: Model dimension
            eps: Small constant for numerical stability
        """
        self.d_model = d_model
        self.eps = eps

        # Learnable parameters
        self.gamma = np.ones(d_model)  # Scale
        self.beta = np.zeros(d_model)  # Shift

    def forward(self, x):
        """
        Apply layer normalization.

        Args:
            x: Input [..., d_model]

        Returns:
            Normalized output [..., d_model]
        """
        # Compute mean and variance along last dimension
        mean = np.mean(x, axis=-1, keepdims=True)
        var = np.var(x, axis=-1, keepdims=True)

        # Normalize
        x_norm = (x - mean) / np.sqrt(var + self.eps)

        # Scale and shift
        return self.gamma * x_norm + self.beta

    def parameters(self):
        """Return learnable parameters."""
        return [self.gamma, self.beta]
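
A quick check that each feature vector comes out with (approximately) zero mean and unit variance, which the default gamma=1, beta=0 leave untouched (values chosen only for illustration):

ln = LayerNorm(d_model=16)
x = np.random.default_rng(0).normal(loc=3.0, scale=5.0, size=(4, 16))
y = ln.forward(x)
print(np.allclose(y.mean(axis=-1), 0.0, atol=1e-6))  # True
print(np.allclose(y.var(axis=-1), 1.0, atol=1e-3))   # True (within eps)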

Component 6: Feed-Forward Network

class FeedForward:
    """
    Position-wise feed-forward network.

    Two linear layers with activation in between.
    FFN(x) = W2 * activation(W1 * x + b1) + b2
    """

    def __init__(self, d_model, d_ff=None, activation='gelu'):
        """
        Initialize feed-forward network.

        Args:
            d_model: Model dimension
            d_ff: Hidden dimension (default: 4 * d_model)
            activation: Activation function ('relu' or 'gelu')
        """
        self.d_model = d_model
        self.d_ff = d_ff or 4 * d_model
        self.activation = activation

        # Initialize weights
        scale1 = np.sqrt(2.0 / (d_model + self.d_ff))
        scale2 = np.sqrt(2.0 / (self.d_ff + d_model))

        self.W1 = np.random.randn(d_model, self.d_ff) * scale1
        self.b1 = np.zeros(self.d_ff)
        self.W2 = np.random.randn(self.d_ff, d_model) * scale2
        self.b2 = np.zeros(d_model)

    def _activation(self, x):
        """Apply activation function."""
        if self.activation == 'relu':
            return np.maximum(0, x)
        elif self.activation == 'gelu':
            # Approximate GELU
            return 0.5 * x * (1 + np.tanh(
                np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)
            ))
        else:
            raise ValueError(f"Unknown activation: {self.activation}")

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input [..., d_model]

        Returns:
            Output [..., d_model]
        """
        # First linear layer
        hidden = x @ self.W1 + self.b1

        # Activation
        hidden = self._activation(hidden)

        # Second linear layer
        output = hidden @ self.W2 + self.b2

        return output

    def parameters(self):
        """Return all parameters."""
        return [self.W1, self.b1, self.W2, self.b2]

Complete Transformer Block

class TransformerBlock:
    """
    Single Transformer block with:
    - Multi-head self-attention
    - Layer normalization
    - Feed-forward network
    - Residual connections
    """

    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.0,
                 pre_norm=True):
        """
        Initialize Transformer block.

        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
            d_ff: FFN hidden dimension
            dropout: Dropout probability
            pre_norm: If True, apply LayerNorm before sublayers (modern style)
        """
        self.d_model = d_model
        self.n_heads = n_heads
        self.pre_norm = pre_norm

        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Feed-forward network
        self.ffn = FeedForward(d_model, d_ff)

        # Layer normalization
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Forward pass.

        Args:
            x: Input [batch_size, seq_len, d_model]
            mask: Attention mask

        Returns:
            Output [batch_size, seq_len, d_model]
        """
        if self.pre_norm:
            # Pre-norm (modern, better training dynamics)
            # x = x + Attention(LayerNorm(x))
            normed = self.norm1.forward(x)
            attn_out = self.attention.forward(normed, mask)
            x = x + attn_out

            # x = x + FFN(LayerNorm(x))
            normed = self.norm2.forward(x)
            ffn_out = self.ffn.forward(normed)
            x = x + ffn_out
        else:
            # Post-norm (original Transformer)
            # x = LayerNorm(x + Attention(x))
            attn_out = self.attention.forward(x, mask)
            x = self.norm1.forward(x + attn_out)

            # x = LayerNorm(x + FFN(x))
            ffn_out = self.ffn.forward(x)
            x = self.norm2.forward(x + ffn_out)

        return x

    def parameters(self):
        """Return all parameters."""
        params = []
        params.extend(self.attention.parameters())
        params.extend(self.ffn.parameters())
        params.extend(self.norm1.parameters())
        params.extend(self.norm2.parameters())
        return params

Complete Causal Language Model

class CausalTransformer:
    """
    Causal Transformer language model.

    Combines:
    - Token embeddings
    - Positional encodings
    - Stack of Transformer blocks
    - Output projection
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers,
                 max_len=512, d_ff=None, dropout=0.0):
        """
        Initialize causal Transformer.

        Args:
            vocab_size: Vocabulary size
            d_model: Model dimension
            n_heads: Number of attention heads per layer
            n_layers: Number of Transformer blocks
            max_len: Maximum sequence length
            d_ff: FFN hidden dimension
            dropout: Dropout probability
        """
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_len = max_len

        # Token embeddings
        self.token_embedding = np.random.randn(vocab_size, d_model) * 0.02

        # Positional encoding
        self.pos_encoding = SinusoidalPositionalEncoding(max_len, d_model)

        # Transformer blocks
        self.blocks = [
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ]

        # Final layer norm
        self.final_norm = LayerNorm(d_model)

        # Output projection, tied to the token embeddings
        # (.T is a NumPy view, so both share the same underlying weights)
        self.output_projection = self.token_embedding.T  # [d_model, vocab_size]

    def forward(self, token_ids):
        """
        Forward pass.

        Args:
            token_ids: Input token IDs [batch_size, seq_len] or [seq_len]

        Returns:
            logits: [batch_size, seq_len, vocab_size] or [seq_len, vocab_size]
        """
        # Handle unbatched input
        if token_ids.ndim == 1:
            token_ids = token_ids[np.newaxis, :]
            squeeze_batch = True
        else:
            squeeze_batch = False

        batch_size, seq_len = token_ids.shape

        # Token embeddings
        x = self.token_embedding[token_ids]  # [batch, seq_len, d_model]

        # Add positional encoding
        pos_enc = self.pos_encoding.forward(seq_len)
        x = x + pos_enc

        # Create causal mask
        mask = create_causal_mask(seq_len)

        # Apply Transformer blocks
        for block in self.blocks:
            x = block.forward(x, mask)

        # Final layer norm
        x = self.final_norm.forward(x)

        # Project to vocabulary
        logits = x @ self.output_projection  # [batch, seq_len, vocab_size]

        if squeeze_batch:
            logits = logits[0]

        return logits

    def generate(self, prompt_ids, max_new_tokens, temperature=1.0):
        """
        Generate tokens autoregressively.

        Args:
            prompt_ids: Initial token IDs [seq_len]
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated token IDs
        """
        tokens = list(prompt_ids)

        for _ in range(max_new_tokens):
            # Get current context (up to max_len)
            context = np.array(tokens[-self.max_len:])

            # Forward pass
            logits = self.forward(context)

            # Get logits for last position
            next_logits = logits[-1] / temperature

            # Sample from distribution
            probs = softmax(next_logits)
            next_token = np.random.choice(len(probs), p=probs)

            tokens.append(next_token)

        return tokens

    def parameters(self):
        """Return all parameters."""
        params = [self.token_embedding]
        for block in self.blocks:
            params.extend(block.parameters())
        params.extend(self.final_norm.parameters())
        return params

Usage Example

# Create a small model
model = CausalTransformer(
    vocab_size=1000,
    d_model=64,
    n_heads=4,
    n_layers=2,
    max_len=128
)

# Example input
token_ids = np.array([1, 5, 23, 7, 42])

# Forward pass
logits = model.forward(token_ids)
print(f"Input shape: {token_ids.shape}")
print(f"Output shape: {logits.shape}")  # [5, 1000]

# Generation
generated = model.generate(token_ids, max_new_tokens=10, temperature=0.8)
print(f"Generated {len(generated)} tokens")

# Visualize attention weights
block = model.blocks[0]
attention_weights = block.attention.attention_weights
print(f"Attention weights shape: {attention_weights.shape}")
# [batch=1, n_heads=4, seq_len=5, seq_len=5]

Visualizing Attention

import matplotlib.pyplot as plt

def visualize_attention(model, token_ids, tokens_str, layer=0, head=0):
    """
    Visualize attention patterns.

    Args:
        model: CausalTransformer
        token_ids: Input token IDs
        tokens_str: String representation of tokens
        layer: Which layer to visualize
        head: Which head to visualize
    """
    # Forward pass to populate attention weights
    _ = model.forward(token_ids)

    # Get attention weights from specified layer
    weights = model.blocks[layer].attention.attention_weights

    # Select specific head
    if weights.ndim == 4:  # [batch, heads, seq, seq]
        weights = weights[0, head]  # [seq, seq]
    elif weights.ndim == 3:  # [heads, seq, seq]
        weights = weights[head]

    # Plot
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(weights, cmap='Blues')

    ax.set_xticks(range(len(tokens_str)))
    ax.set_yticks(range(len(tokens_str)))
    ax.set_xticklabels(tokens_str, rotation=45, ha='right')
    ax.set_yticklabels(tokens_str)

    ax.set_xlabel('Key positions')
    ax.set_ylabel('Query positions')
    ax.set_title(f'Attention weights (Layer {layer}, Head {head})')

    plt.colorbar(im)
    plt.tight_layout()
    plt.show()

Performance Considerations

Memory Usage

Model parameters:

- Token embeddings: vocab_size × d_model
- Per block (assuming d_ff = 4·d_model; bias terms add only O(d_model)):
  - QKV projection: d_model × 3d_model = 3d_model²
  - Output projection: d_model²
  - FFN: d_model × 4d_model × 2 = 8d_model²
  - Layer norms: 4d_model
  - Total per block: ~12d_model² (tallied in the sketch below)
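
As a quick sanity check, here is a rough tally (ignoring FFN bias terms) for the small model from the usage example above (vocab_size=1000, d_model=64, n_layers=2):

vocab_size, d_model, n_layers = 1000, 64, 2

embedding_params = vocab_size * d_model          # token embeddings (output head is tied)
per_block = 3 * d_model**2 + d_model**2 + 8 * d_model**2 + 4 * d_model
total = embedding_params + n_layers * per_block + 2 * d_model  # + final layer norm

print(f"Per block: {per_block:,}")   # 49,408
print(f"Total:     {total:,}")       # 162,944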

Activation memory (during forward):

- Attention scores: batch × n_heads × seq_len²
- Attention weights: batch × n_heads × seq_len²
- This is the O(n²) that limits context length!
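
To make the quadratic growth concrete, here is a rough estimate of a single layer's attention-weight memory, assuming float32 (4 bytes per element) and an illustrative n_heads=12 (not the small model above):

def attn_weight_bytes(batch, n_heads, seq_len, bytes_per_float=4):
    """Memory for one layer's attention weights: batch × n_heads × seq_len² floats."""
    return batch * n_heads * seq_len**2 * bytes_per_float

for seq_len in (128, 1024, 8192):
    mb = attn_weight_bytes(batch=1, n_heads=12, seq_len=seq_len) / 1e6
    print(f"seq_len={seq_len:>5}: {mb:8.1f} MB per layer")
# seq_len=  128:      0.8 MB per layer
# seq_len= 1024:     50.3 MB per layer
# seq_len= 8192:   3221.2 MB per layer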

Computational Cost

Per attention layer:

- QKV projection: O(seq_len × d_model²)
- Attention scores: O(seq_len² × d_model)
- Softmax: O(seq_len²)
- Attention @ V: O(seq_len² × d_model)
- Output projection: O(seq_len × d_model²)

Per FFN:

- First linear: O(seq_len × d_model × d_ff)
- Second linear: O(seq_len × d_ff × d_model)

Total per layer: O(seq_len × d_model² + seq_len² × d_model)
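
A worked comparison shows where the quadratic term takes over: the seq_len × d_model² and seq_len² × d_model terms are equal when seq_len = d_model, so attention dominates for longer contexts. The sketch below uses d_model = 768 purely as an illustrative value (not the small model above):

d_model = 768  # illustrative value, not the usage-example model
for seq_len in (256, 768, 4096):
    proj_cost = seq_len * d_model**2   # QKV/output projections (and FFN, up to a constant)
    attn_cost = seq_len**2 * d_model   # attention scores + weighted sum
    print(f"seq_len={seq_len:>4}: projections ~{proj_cost:.2e}, attention ~{attn_cost:.2e}")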

Summary

We've built a complete attention implementation with:

Component                     Purpose                      Key details
Scaled dot-product attention  Core attention computation   softmax(QK^T / √d_k) × V
Multi-head attention          Multiple attention patterns  Split into h heads, concatenate
Positional encoding           Position information         Sinusoidal or learned
Layer normalization           Training stability           Normalize, then scale and shift
Feed-forward network          Feature processing           Two linear layers with activation
Transformer block             Combine components           Attention + FFN + residuals
Causal Transformer            Complete model               Embeddings + blocks + output projection

Key takeaway: Building attention from scratch reveals how each component contributes to the whole. The core idea—computing relevance via dot products, normalizing with softmax, and aggregating values—is elegant and powerful. Understanding this implementation provides the foundation for working with any modern language model.

Continue to: Stage 6: The Complete Transformer (coming soon)