33  Writing Fast Kernels with Triton

From Algorithm to Implementation

FlashAttention is brilliant. But how do you actually implement it?

This chapter bridges the gap between algorithm and code—using Triton, the Python-based GPU kernel language that’s revolutionizing custom op development.

33.1 Why Triton?

Traditional approach: Write CUDA C++

  - Control everything (registers, shared memory, thread indices)
  - Very fast when optimized
  - Very complex (hundreds of lines for simple operations)
  - Architecture-specific (A100 code ≠ H100 code)

Triton approach: Write Python-like code

  - Compiler handles low-level details
  - Surprisingly fast (within 10% of hand-tuned CUDA)
  - Much simpler (10× less code)
  - Portable across GPU architectures

Example: Fused element-wise operations

CUDA (~100 lines):

__global__ void fused_kernel(float* out, float* a, float* b, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        float x = a[idx];
        float y = b[idx];
        out[idx] = exp(x + y) / (1.0f + exp(x + y));  // fused ops
    }
}
// Plus host code to launch...

Triton (~15 lines):

@triton.jit
def fused_kernel(out_ptr, a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)

    result = tl.exp(a + b) / (1.0 + tl.exp(a + b))

    tl.store(out_ptr + offsets, result, mask=mask)

Triton wins on developer productivity. Let’s learn it.

33.2 Triton Basics: The Execution Model

Triton uses a block-based execution model, similar to CUDA but higher-level.

Key concepts:

  - Grid: the set of program instances launched for a kernel (like a CUDA grid)
  - Program: one running instance of the kernel, roughly analogous to a CUDA thread block
  - Program ID: which program instance is executing (tl.program_id)
  - Block: the tile of data a single program processes
  - Mask: a boolean tensor used to handle boundary conditions

33.2.1 Your First Triton Kernel: Vector Add

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,  # Pointer to first input
    y_ptr,  # Pointer to second input
    out_ptr,  # Pointer to output
    N,  # Size of vectors
    BLOCK_SIZE: tl.constexpr,  # Compile-time constant
):
    # Get program ID (which block are we?)
    pid = tl.program_id(0)

    # Compute offsets for this block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Create mask for bounds checking
    mask = offsets < N

    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Compute
    output = x + y

    # Store result
    tl.store(out_ptr + offsets, output, mask=mask)

Launching the kernel:

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    N = output.numel()

    # Calculate grid size (how many blocks?)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    # Launch kernel
    add_kernel[grid](
        x, y, output, N,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return output
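
A quick way to sanity-check the wrapper (a minimal usage sketch; the size is deliberately not a multiple of BLOCK_SIZE so the mask path is exercised):

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
out = add(x, y)
print(torch.allclose(out, x + y))  # True: the mask handles the ragged final block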

33.2.2 Understanding the Block Model

Input array: [0, 1, 2, ..., 4095]  (N = 4096)
BLOCK_SIZE = 1024

Program 0: processes elements [0:1024]
Program 1: processes elements [1024:2048]
Program 2: processes elements [2048:3072]
Program 3: processes elements [3072:4096]

Each program loads a block, computes, stores.

Key difference from CUDA: You think in blocks of data, not individual threads.

33.3 Memory Operations: Load and Store

33.3.1 Basic Loads

# Contiguous load
data = tl.load(ptr + offsets)

# Load with mask (for boundaries)
data = tl.load(ptr + offsets, mask=mask, other=0.0)
# Elements where mask=False get 0.0

33.3.2 Stores

# Contiguous store
tl.store(ptr + offsets, data)

# Store with mask
tl.store(ptr + offsets, data, mask=mask)
# Only stores where mask=True

33.3.3 2D Memory Access

# Load a 2D block
BLOCK_M, BLOCK_N = 32, 32

row_offsets = tl.arange(0, BLOCK_M)[:, None]  # Column vector
col_offsets = tl.arange(0, BLOCK_N)[None, :]  # Row vector

offsets_2d = row_offsets * stride + col_offsets  # [BLOCK_M, BLOCK_N] offset matrix (stride = elements per row)

data = tl.load(ptr + offsets_2d)
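
Putting these pieces together, here is a minimal sketch of a 2D tile-copy kernel that uses the row/column offset trick above. It assumes a row-major matrix; the kernel name is illustrative.

@triton.jit
def copy_2d_kernel(in_ptr, out_ptr, M, N, stride_m,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]   # [BLOCK_M, 1]
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]   # [1, BLOCK_N]

    offsets = rows * stride_m + cols          # [BLOCK_M, BLOCK_N] offsets
    mask = (rows < M) & (cols < N)            # 2D mask for the edge tiles

    tile = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, tile, mask=mask)

# Launch with a 2D grid, one program per tile:
#   grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))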

33.4 Worked Example 1: Softmax

Let’s implement a numerically stable softmax.

Algorithm:

  1. Find the row max (for numerical stability)
  2. Compute exp(x - max)
  3. Sum the exponentials
  4. Divide by the sum

@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # Each program handles one row
    row_idx = tl.program_id(0)

    # Compute offsets for this row
    row_start = row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = input_ptr + row_start + col_offsets

    # Mask for this row
    mask = col_offsets < n_cols

    # Load input row
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

    # Step 1: Find max
    row_max = tl.max(row, axis=0)

    # Step 2: Subtract max and exponentiate
    numerator = tl.exp(row - row_max)

    # Step 3: Sum
    denominator = tl.sum(numerator, axis=0)

    # Step 4: Divide
    softmax_output = numerator / denominator

    # Store result
    output_row_start = row_idx * output_row_stride
    output_ptrs = output_ptr + output_row_start + col_offsets
    tl.store(output_ptrs, softmax_output, mask=mask)

Launching:

def softmax(x):
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    grid = (n_rows,)

    softmax_kernel[grid](
        output, x,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return output
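
A quick correctness check against PyTorch before any benchmarking (a small sketch):

x = torch.randn(1823, 781, device='cuda')
print(torch.allclose(softmax(x), torch.softmax(x, dim=-1), atol=1e-6))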

Performance: Within 5% of cuDNN softmax!

33.5 Worked Example 2: Fused Layer Norm

Layer norm: normalize each row to mean=0, std=1, then scale/shift.

Formula: \[\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\]

Naive approach: 4 kernels (mean, variance, normalize, scale)

Fused approach: 1 kernel

@triton.jit
def layer_norm_kernel(
    output_ptr, input_ptr, gamma_ptr, beta_ptr,
    input_row_stride, output_row_stride,
    n_cols, eps,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)

    # Offsets
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Load input
    input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
    row = tl.load(input_ptrs, mask=mask, other=0.0)

    # Compute mean
    mean = tl.sum(row, axis=0) / n_cols

    # Compute variance (zero out the padded lanes so they don't bias the sum)
    diff = tl.where(mask, row - mean, 0.0)
    variance = tl.sum(diff * diff, axis=0) / n_cols

    # Normalize
    rstd = 1.0 / tl.sqrt(variance + eps)
    normalized = (row - mean) * rstd

    # Load gamma and beta
    gamma = tl.load(gamma_ptr + col_offsets, mask=mask)
    beta = tl.load(beta_ptr + col_offsets, mask=mask)

    # Scale and shift
    output = gamma * normalized + beta

    # Store
    output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
    tl.store(output_ptrs, output, mask=mask)
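
The kernel needs a small host-side wrapper, analogous to the softmax launcher (a sketch, assuming a 2D row-major input and 1D gamma/beta of length n_cols):

def layer_norm(x, gamma, beta, eps=1e-5):
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    grid = (n_rows,)

    layer_norm_kernel[grid](
        output, x, gamma, beta,
        x.stride(0), output.stride(0),
        n_cols, eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output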

Speedup over PyTorch: 2-3× (eliminates memory round-trips)

33.6 Worked Example 3: Fused Attention (Simplified)

Let’s implement a simplified FlashAttention (single-head, no masking).

Key ideas:

  1. Tile Q into blocks (rows)
  2. Tile K/V into blocks (columns)
  3. For each Q block, iterate over the K/V blocks
  4. Accumulate the attention output with an online softmax

@triton.jit
def flash_attention_kernel(
    Q, K, V, Out,
    stride_qm, stride_qk,
    stride_km, stride_kk,
    stride_vm, stride_vk,
    stride_om, stride_ok,
    M, N, D,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Each program handles one block of Q rows
    pid = tl.program_id(0)

    # Q block row offsets
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)

    # Load Q block
    Q_block = tl.load(Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk)

    # Initialize output accumulators
    out_block = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
    max_score = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    sum_exp = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Iterate over K/V blocks
    for start_n in range(0, N, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)

        # Load K block
        K_block = tl.load(K + offs_n[:, None] * stride_km + offs_d[None, :] * stride_kk)

        # Compute attention scores: Q @ K.T
        scores = tl.dot(Q_block, tl.trans(K_block))  # [BLOCK_M, BLOCK_N]

        # Update running max
        block_max = tl.max(scores, axis=1)
        new_max = tl.maximum(max_score, block_max)

        # Rescale the running statistics to the new max
        scale_old = tl.exp(max_score - new_max)

        # Update softmax denominator (exp_scores is already relative to new_max,
        # so the new block needs no extra rescaling)
        exp_scores = tl.exp(scores - new_max[:, None])
        sum_exp = sum_exp * scale_old + tl.sum(exp_scores, axis=1)

        # Load V block
        V_block = tl.load(V + offs_n[:, None] * stride_vm + offs_d[None, :] * stride_vk)

        # Update output: rescale the old accumulator, then add this block's contribution
        out_block = out_block * scale_old[:, None] + tl.dot(exp_scores, V_block)
        max_score = new_max

    # Normalize
    out_block = out_block / sum_exp[:, None]

    # Store output
    tl.store(Out + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok, out_block)

This is the core of FlashAttention!

Key differences from standard attention:

  - Never materializes the full O(n²) attention matrix
  - Uses online softmax updates (max and sum rescaling)
  - Processes tiles that fit in shared memory
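
A host-side wrapper for the simplified kernel might look like the sketch below. It assumes float32 inputs of shape [seq, dim], a head dimension that is a power of two (so BLOCK_D equals D), and sequence lengths divisible by the block sizes, since the kernel above loads without masks. It also omits the 1/sqrt(D) score scaling, matching the simplified kernel.

def flash_attention(q, k, v):
    M, D = q.shape          # query length, head dimension
    N, _ = k.shape          # key/value length

    out = torch.empty_like(q)
    BLOCK_M, BLOCK_N = 64, 64
    BLOCK_D = triton.next_power_of_2(D)   # must equal D for this simplified kernel

    grid = (triton.cdiv(M, BLOCK_M),)     # one program per block of query rows
    flash_attention_kernel[grid](
        q, k, v, out,
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1),
        v.stride(0), v.stride(1),
        out.stride(0), out.stride(1),
        M, N, D,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D,
    )
    return out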

33.7 Matrix Multiply in Triton

Triton has a special tl.dot instruction that maps to tensor cores.

@triton.jit
def matmul_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Initialize accumulator
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # Iterate over K dimension in blocks
    for k in range(0, K, BLOCK_K):
        # Load A block
        a = tl.load(A_ptr + offs_m[:, None] * stride_am + (offs_k[None, :] + k) * stride_ak)

        # Load B block
        b = tl.load(B_ptr + (offs_k[:, None] + k) * stride_bk + offs_n[None, :] * stride_bn)

        # Multiply-accumulate (uses tensor cores!)
        acc += tl.dot(a, b)

    # Store result
    tl.store(C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc)
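
A matching host-side launcher (a sketch; because the kernel loads without masks, M, N, and K should be multiples of the block sizes). The name triton_matmul is the one the benchmarking example later in this chapter calls:

def triton_matmul(a, b):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"

    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32

    # 2D grid: one program per output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c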

Performance: 80-90% of cuBLAS on A100!

33.8 Optimization Techniques

33.8.1 1. Choose Block Sizes Wisely

Rule of thumb:

  - Power of 2: always (16, 32, 64, 128, 256)
  - Multiple of the warp size (32): always
  - Tile fits in shared memory: check the limits

# A100: 164 KB shared memory per SM
BLOCK_M = 128
BLOCK_N = 128
BLOCK_K = 32

# FP16 data:
# A tile: 128 × 32 × 2 = 8 KB
# B tile: 32 × 128 × 2 = 8 KB
# C tile: 128 × 128 × 4 = 64 KB (FP32 accumulator; typically held in registers, but budget it anyway)
# Total: 80 KB ✓ Fits!
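
This arithmetic is easy to script when exploring configurations. A small helper (a sketch that, like the estimate above, counts the FP32 accumulator against the on-chip budget):

def tile_budget_kb(block_m, block_n, block_k, in_bytes=2, acc_bytes=4):
    a_tile = block_m * block_k * in_bytes      # A tile (e.g. FP16)
    b_tile = block_k * block_n * in_bytes      # B tile
    c_tile = block_m * block_n * acc_bytes     # FP32 accumulator
    return (a_tile + b_tile + c_tile) / 1024

print(tile_budget_kb(128, 128, 32))  # 80.0 KB, matching the estimate above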

33.8.2 2. Use Compile-Time Constants

# BAD: runtime value
@triton.jit
def kernel(..., block_size):
    offsets = tl.arange(0, block_size)  # Error: tl.arange needs compile-time bounds

# GOOD: compile-time constant
@triton.jit
def kernel(..., BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)  # Specialized and optimized at compile time

Triton generates optimized code for each BLOCK_SIZE value.

33.8.3 3. Memory Coalescing

# BAD: Strided access
for i in range(N):
    data = tl.load(ptr + i * large_stride)

# GOOD: Contiguous access
offsets = tl.arange(0, BLOCK)
data = tl.load(ptr + offsets)

33.8.4 4. Reduce Synchronization

Triton handles synchronization automatically for tl.sum, tl.max, etc. But minimize reductions:

# BAD: Multiple reductions
max_val = tl.max(x)
sum_val = tl.sum(x)

# BETTER: Combined when possible
# (depends on algorithm)

33.8.5 5. Profile with Nsight

# Generate profile
ncu --set full -o profile python script.py

# Look for:
# - SM efficiency (target: >80%)
# - Memory throughput (target: >70%)
# - Warp execution efficiency (target: >90%)

33.8.6 6. Benchmark with triton.testing

Triton includes a robust benchmarking module that follows GPU benchmarking best practices:

import triton

# Simple timing (returns min time in ms by default)
ms = triton.testing.do_bench(lambda: kernel[grid](a, b, c))  # time in ms; aggregation set by return_mode

# With all statistics
ms = triton.testing.do_bench(
    lambda: kernel[grid](a, b, c),
    warmup=25,       # Warmup iterations
    rep=100,         # Timed repetitions
    return_mode='median',  # 'min', 'max', 'mean', 'median', 'all'
)

For systematic benchmarking across problem sizes:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N'],
        x_vals=[(512, 512), (1024, 1024), (2048, 2048)],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'PyTorch'],
        xlabel='Matrix Size',
        ylabel='TFLOPS',
    )
)
def benchmark_matmul(M, N, provider):
    K = 512
    a = torch.randn(M, K, device='cuda')
    b = torch.randn(K, N, device='cuda')

    if provider == 'triton':
        fn = lambda: triton_matmul(a, b)
    else:
        fn = lambda: torch.mm(a, b)

    ms = triton.testing.do_bench(fn)
    flops = 2 * M * N * K
    return flops / ms / 1e9  # TFLOPS

benchmark_matmul.run(print_data=True, save_path='benchmark.png')

This automatically generates comparison plots. See Chapter 15: The Art of Measurement for a deeper discussion of GPU benchmarking methodology.

33.9 Debugging Triton Kernels

33.9.1 Check for NaNs

result = tl.load(...)
tl.device_assert(tl.sum(tl.where(result != result, 1, 0)) == 0, "NaN detected!")

33.9.2 Visualize Generated Code

# See the Triton/Python source of the jitted function
kernel.src

# Inspect the lowered output after a launch. The exact attribute layout varies
# by Triton version; the compiled kernel exposes a dict of compilation artifacts.
kernel.asm  # e.g. keys such as 'ttir', 'llir', 'ptx'

33.10 When to Use Triton

Use Triton when:

  - You need fusion of operations (avoiding memory round-trips)
  - You’re implementing a custom algorithm (novel attention variants)
  - PyTorch native ops are too slow
  - You want cross-GPU portability

Don’t use Triton when:

  - PyTorch native ops are already fast enough
  - cuBLAS/cuDNN solve your problem (they’re heavily optimized)
  - You need the last 5-10% of performance (hand-tuned CUDA wins)

33.11 Performance Comparison

Operation                               Baseline             Triton     Result
───────────────────────────────────────────────────────────────────────────────────────
Softmax (1024 × 1024)                   0.12 ms (PyTorch)    0.11 ms    1.1× faster (minimal overhead)
Layer Norm (4096 × 4096)                0.45 ms (PyTorch)    0.18 ms    2.5× faster (memory round-trip elimination)
FlashAttention (2048 seq, 64 dim)       8.2 ms (PyTorch)     1.1 ms     7.5× faster (tiling + fusion)
Matrix Multiply (4096 × 4096 × 4096)    3.2 ms (cuBLAS)      3.8 ms     1.2× slower (cuBLAS is better, use it!)

33.12 Common Pitfalls

33.12.1 1. Forgetting Masks

# BAD: No mask, reads beyond bounds
data = tl.load(ptr + offsets)

# GOOD: Mask for safety
mask = offsets < N
data = tl.load(ptr + offsets, mask=mask)

33.12.2 2. Wrong Stride Computation

# If tensor is [M, N] row-major:
# stride(0) = N  (number of elements to next row)
# stride(1) = 1  (number of elements to next column)

# BAD
offset = row * 1 + col * N

# GOOD
offset = row * stride_m + col * stride_n

33.12.3 3. Block Size Mismatch

# Grid computed assuming one block size...
grid = (triton.cdiv(N, 1024),)

# ...but the kernel launched with another: half the elements are never processed
add_kernel[grid](x, y, out, N, BLOCK_SIZE=512)

Derive the grid and the kernel’s BLOCK_SIZE from the same constant so they cannot drift apart, as sketched below.
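
A minimal pattern that avoids the mismatch, reusing add_kernel from earlier in the chapter:

BLOCK_SIZE = 1024                                        # single source of truth
grid = (triton.cdiv(N, BLOCK_SIZE),)                     # grid derived from it...
add_kernel[grid](x, y, out, N, BLOCK_SIZE=BLOCK_SIZE)    # ...and the same value passed to the kernel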

33.13 Triton 3.0 and Beyond

33.13.1 Multi-Backend Support

Triton now supports multiple GPU backends, not just NVIDIA:

# Backend selection follows the Triton build you have installed and the device
# your tensors live on (for example, the ROCm build of PyTorch/Triton for AMD
# GPUs, or the Intel XPU backend for Intel GPUs); no kernel changes are needed.

# Same kernel works on all!
@triton.jit
def my_kernel(...):
    # Identical code
    pass

Supported backends (Triton 3.x):

Backend         Status          Notes
─────────────────────────────────────────────────
NVIDIA (CUDA)   Stable          Primary, best optimized
AMD (ROCm)      Stable          MI300X, MI250X support
Intel (XPU)     Experimental    Max/Flex GPUs
CPU             Experimental    For debugging

33.13.2 Autotuning Improvements

Triton 3.0 has better autotuning infrastructure:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': matmul_config_pruner,  # Custom pruning
        'perf_model': estimate_matmul_time,  # Performance model
    },
    warmup=25,
    rep=100,
)
@triton.jit
def matmul_kernel(...):
    ...

New autotuning features:

  - num_stages: pipeline depth (more overlap between loads and compute)
  - num_warps: warps per program (occupancy tuning)
  - Custom pruning to skip obviously bad configs
  - Performance models to guide the search
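
On the host side, an autotuned kernel is launched without the block-size arguments: the autotuner injects the winning config, and the grid can be written as a function of it. A sketch, assuming the autotuned kernel keeps the same argument list as the matmul kernel shown earlier in this chapter:

def matmul(a, b):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # The grid is a callable so it can depend on the config the autotuner picks
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        # BLOCK_M / BLOCK_N / BLOCK_K come from the selected triton.Config
    )
    return c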

33.13.3 Descriptor-Based Loads (Experimental)

Triton is adding TMA-style descriptor loads for Hopper-class GPUs (the API below is illustrative and still evolving):

@triton.jit
def kernel_with_descriptors(
    A_ptr, B_ptr, C_ptr,
    ...
    # Descriptor enables hardware-managed loads
    A_desc: tl.tensor_descriptor,
):
    # Hardware manages the load
    a_tile = tl.load_tile(A_desc, [block_m_idx, block_k_idx])

    # vs. traditional pointer-based
    a_ptr_offset = A_ptr + offsets_m[:, None] * stride_am + offsets_k[None, :] * stride_ak
    a_tile_ptr = tl.load(a_ptr_offset)

Benefits:

  - Hardware handles address calculation
  - Better memory access patterns
  - Reduced register pressure

33.13.4 Persistent Kernels

For small operations, kernel launch overhead dominates. Persistent kernels help:

@triton.jit
def persistent_kernel(
    A_ptr, B_ptr, C_ptr, D_ptr,
    num_tasks, num_tiles_n,  # num_tiles_n lets each program decode task_id into a 2D tile index
    BLOCK_SIZE: tl.constexpr,
):
    # Get program ID
    pid = tl.program_id(0)

    # Persistent: each program processes multiple tiles
    num_programs = tl.num_programs(0)

    for task_id in range(pid, num_tasks, num_programs):
        # Compute which tile this is
        tile_m = task_id // num_tiles_n
        tile_n = task_id % num_tiles_n

        # Process tile
        offs_m = tile_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offs_n = tile_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

        # Load, compute, store...

Use case: Many small operations (MoE expert computations, sparse ops)
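
The launch side is what makes the kernel “persistent”: the grid is sized to the hardware (roughly one program per SM) rather than to the problem. A sketch; M, N and the tensors a, b, c, d stand in for your actual workload:

BLOCK_SIZE = 128
num_tiles_m = triton.cdiv(M, BLOCK_SIZE)
num_tiles_n = triton.cdiv(N, BLOCK_SIZE)
num_tasks = num_tiles_m * num_tiles_n

# Size the grid to the hardware, not the problem: ~one persistent program per SM
props = torch.cuda.get_device_properties(torch.cuda.current_device())
grid = (props.multi_processor_count,)

persistent_kernel[grid](a, b, c, d, num_tasks, num_tiles_n, BLOCK_SIZE=BLOCK_SIZE)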

33.13.5 Integration with torch.compile

Triton kernels now integrate seamlessly with torch.compile:

import torch
import triton
import triton.language as tl

@triton.jit
def my_triton_kernel(...):
    ...

# Register as a custom op (torch.library.custom_op requires mutates_args)
@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # ... launch my_triton_kernel to fill out ...
    return out

# Torch.compile will use it
@torch.compile
def my_model(x):
    return torch.ops.mylib.my_op(x)

# Or use Triton directly via inductor
# torch.compile generates Triton for fused ops automatically!

TorchInductor generates Triton:

# This PyTorch code...
y = torch.relu(x @ weight + bias)

# ...becomes this Triton (automatically!)
@triton.jit
def fused_matmul_bias_relu(...):
    ...
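
To see the Triton that TorchInductor actually emits for your model, PyTorch’s logging flags can print it (my_script.py stands in for your own entry point):

# Print the generated Triton kernels while compiling
TORCH_LOGS="output_code" python my_script.py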

33.13.6 Triton Distributed (Experimental)

For multi-GPU kernels that overlap compute with communication. This is an experimental, out-of-tree effort, and the API sketched below is illustrative:

from triton.distributed import (
    DistributedTensor,
    all_gather,
    reduce_scatter,
)

@triton.jit
def distributed_matmul_kernel(
    A: DistributedTensor,  # Sharded across GPUs
    B: DistributedTensor,
    C: DistributedTensor,
    ...
):
    # All-gather A across GPUs
    A_full = all_gather(A, dim=0)

    # Local matmul
    C_local = tl.dot(A_full, B)

    # Reduce-scatter result
    reduce_scatter(C_local, C, dim=0)

Use case: Custom distributed ops that overlap compute and communication.

33.13.7 Best Practices for Triton 3.0

# 1. Use autotuning for performance-critical kernels
@triton.autotune(configs=[...], key=['M', 'N'])
@triton.jit
def kernel(...):
    ...

# 2. Specify precision explicitly
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # FP32 accumulator
result = acc.to(tl.float16)  # Convert at end

# 3. Use tl.dot for matrix ops (maps to tensor cores)
c = tl.dot(a, b)  # The documented path for matrix multiplies inside kernels

# 4. Profile with triton.testing
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N'],
        x_vals=[128 * i for i in range(2, 33)],
        ...
    )
)
def benchmark(M, N):
    ...

33.14 Connections

Chapter 4 (GPU Architecture): Triton abstracts GPU details but understanding grids/blocks/shared memory helps.

Chapter 8 (Fusion): Triton enables fusion by writing multi-op kernels.

Chapter 10 (FlashAttention): Now you can implement it!

Chapter 19 (Profiling): Use Nsight to verify Triton kernel performance.

33.15 Key Takeaways

  1. Triton makes GPU programming accessible: Python-like syntax, automatic optimization.

  2. Think in blocks: Load a block, compute, store—not individual threads.

  3. Fusion is the main benefit: Eliminate memory round-trips for multi-op sequences.

  4. Use compile-time constants: tl.constexpr enables optimization.

  5. Multi-backend support: Same kernel runs on NVIDIA, AMD, and Intel GPUs.

  6. Autotuning is essential: Use @triton.autotune for production kernels.

  7. torch.compile generates Triton: TorchInductor uses Triton under the hood.

  8. Profile before claiming victory: Verify performance with Nsight.

  9. Don’t replace optimized libraries: cuBLAS and cuDNN are still king for basic ops.


Tip: Learning Path

Week 1: Vector operations (add, mul, ReLU)
Week 2: Reductions (softmax, layer norm)
Week 3: Matrix multiply basics
Week 4: Fused operations (bias + activation, etc.)
Week 5: FlashAttention implementation

By week 5, you’ll be writing production-quality custom kernels.


33.16 Further Reading