33 Writing Fast Kernels with Triton
From Algorithm to Implementation
FlashAttention is brilliant. But how do you actually implement it?
This chapter bridges the gap between algorithm and code—using Triton, the Python-based GPU kernel language that’s revolutionizing custom op development.
33.1 Why Triton?
Traditional approach: Write CUDA C++
- Control everything (registers, shared memory, thread indices)
- Very fast when optimized
- Very complex (hundreds of lines for simple operations)
- Architecture-specific (A100 code ≠ H100 code)
Triton approach: Write Python-like code
- Compiler handles low-level details
- Surprisingly fast (within 10% of hand-tuned CUDA)
- Much simpler (10× less code)
- Portable across GPU architectures
Example: Fused element-wise operations
CUDA (~100 lines):
__global__ void fused_kernel(float* out, float* a, float* b, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) {
float x = a[idx];
float y = b[idx];
out[idx] = exp(x + y) / (1.0f + exp(x + y)); // fused ops
}
}
// Plus host code to launch...
Triton (~15 lines):
@triton.jit
def fused_kernel(out_ptr, a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
result = tl.exp(a + b) / (1.0 + tl.exp(a + b))
tl.store(out_ptr + offsets, result, mask=mask)
Triton wins on developer productivity. Let’s learn it.
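First, one more contrast with the CUDA listing’s “plus host code to launch”: the Triton host side is also just a few lines of Python. A minimal wrapper sketch (the wrapper name is ours):
import torch
import triton

def fused_sigmoid_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(a)
    N = out.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)  # one program per block of 1024 elements
    fused_kernel[grid](out, a, b, N, BLOCK_SIZE=BLOCK_SIZE)
    return out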
33.2 Triton Basics: The Execution Model
Triton uses a block-based execution model, similar to CUDA but higher-level.
Key concepts:
- Program: one instance of the kernel in the launch grid (analogous to a CUDA thread block)
- Program ID: which program instance is executing
- Block: a tile of data processed together by one program
- Mask: handles boundary conditions
33.2.1 Your First Triton Kernel: Vector Add
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # Pointer to first input
y_ptr, # Pointer to second input
out_ptr, # Pointer to output
N, # Size of vectors
BLOCK_SIZE: tl.constexpr, # Compile-time constant
):
# Get program ID (which block are we?)
pid = tl.program_id(0)
# Compute offsets for this block
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# Create mask for bounds checking
mask = offsets < N
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# Compute
output = x + y
# Store result
tl.store(out_ptr + offsets, output, mask=mask)
Launching the kernel:
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
N = output.numel()
# Calculate grid size (how many blocks?)
BLOCK_SIZE = 1024
grid = (triton.cdiv(N, BLOCK_SIZE),)
# Launch kernel
add_kernel[grid](
x, y, output, N,
BLOCK_SIZE=BLOCK_SIZE
)
return output
33.2.2 Understanding the Block Model
Input array: [0, 1, 2, ..., 4095] (N = 4096)
BLOCK_SIZE = 1024
Program 0: processes elements [0:1024]
Program 1: processes elements [1024:2048]
Program 2: processes elements [2048:3072]
Program 3: processes elements [3072:4096]
Each program loads a block, computes, stores.
Key difference from CUDA: You think in blocks of data, not individual threads.
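A quick sanity check of the launcher above (assuming CUDA tensors and the imports from earlier):
x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
assert torch.allclose(add(x, y), x + y)  # verify against eager PyTorch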
33.3 Memory Operations: Load and Store
33.3.1 Basic Loads
# Contiguous load
data = tl.load(ptr + offsets)
# Load with mask (for boundaries)
data = tl.load(ptr + offsets, mask=mask, other=0.0)
# Elements where mask=False get 0.0
33.3.2 Stores
# Contiguous store
tl.store(ptr + offsets, data)
# Store with mask
tl.store(ptr + offsets, data, mask=mask)
# Only stores where mask=True
33.3.3 2D Memory Access
# Load a 2D block
BLOCK_M, BLOCK_N = 32, 32
row_offsets = tl.arange(0, BLOCK_M)[:, None] # Column vector
col_offsets = tl.arange(0, BLOCK_N)[None, :] # Row vector
offsets_2d = row_offsets * stride + col_offsets # 2D grid
data = tl.load(ptr + offsets_2d)
33.4 Worked Example 1: Softmax
Let’s implement a numerically stable softmax.
Algorithm:
1. Find the max (for numerical stability)
2. Compute exp(x - max)
3. Sum the exponentials
4. Divide
@triton.jit
def softmax_kernel(
output_ptr, input_ptr,
input_row_stride,
output_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr
):
# Each program handles one row
row_idx = tl.program_id(0)
# Compute offsets for this row
row_start = row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = input_ptr + row_start + col_offsets
# Mask for this row
mask = col_offsets < n_cols
# Load input row
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Step 1: Find max
row_max = tl.max(row, axis=0)
# Step 2: Subtract max and exponentiate
numerator = tl.exp(row - row_max)
# Step 3: Sum
denominator = tl.sum(numerator, axis=0)
# Step 4: Divide
softmax_output = numerator / denominator
# Store result
output_row_start = row_idx * output_row_stride
output_ptrs = output_ptr + output_row_start + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
Launching:
def softmax(x):
n_rows, n_cols = x.shape
output = torch.empty_like(x)
BLOCK_SIZE = triton.next_power_of_2(n_cols)
grid = (n_rows,)
softmax_kernel[grid](
output, x,
x.stride(0), output.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE
)
return output
Performance: Within 5% of cuDNN softmax!
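A quick correctness check against PyTorch; any row width works because the kernel masks the columns beyond n_cols:
x = torch.randn(1823, 781, device='cuda')  # deliberately not a power of 2
torch.testing.assert_close(softmax(x), torch.softmax(x, dim=-1))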
33.5 Worked Example 2: Fused Layer Norm
Layer norm: normalize each row to mean=0, std=1, then scale/shift.
Formula: \[\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\]
Naive approach: 4 kernels (mean, variance, normalize, scale)
Fused approach: 1 kernel
@triton.jit
def layer_norm_kernel(
output_ptr, input_ptr, gamma_ptr, beta_ptr,
input_row_stride, output_row_stride,
n_cols, eps,
BLOCK_SIZE: tl.constexpr
):
row_idx = tl.program_id(0)
# Offsets
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# Load input
input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
row = tl.load(input_ptrs, mask=mask, other=0.0)
# Compute mean
mean = tl.sum(row, axis=0) / n_cols
# Compute variance (zero out the padded lanes so they don't contribute (0 - mean)^2)
diff = tl.where(mask, row - mean, 0.0)
variance = tl.sum(diff * diff, axis=0) / n_cols
# Normalize
rstd = 1.0 / tl.sqrt(variance + eps)
normalized = (row - mean) * rstd
# Load gamma and beta
gamma = tl.load(gamma_ptr + col_offsets, mask=mask)
beta = tl.load(beta_ptr + col_offsets, mask=mask)
# Scale and shift
output = gamma * normalized + beta
# Store
output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
tl.store(output_ptrs, output, mask=mask)
Speedup over PyTorch: 2-3× (eliminates memory round-trips)
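The kernel is shown without its host wrapper; a minimal launcher sketch, mirroring the softmax one (the function name and default eps are ours):
def layer_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5):
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)  # one program covers an entire row
    grid = (n_rows,)
    layer_norm_kernel[grid](
        output, x, gamma, beta,
        x.stride(0), output.stride(0),
        n_cols, eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output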
33.6 Worked Example 3: Fused Attention (Simplified)
Let’s implement a simplified FlashAttention (single-head, no masking).
Key ideas:
1. Tile Q into blocks (rows)
2. Tile K/V into blocks (columns)
3. For each Q block, iterate over K/V blocks
4. Accumulate the attention output with an online softmax
@triton.jit
def flash_attention_kernel(
Q, K, V, Out,
stride_qm, stride_qk,
stride_km, stride_kk,
stride_vm, stride_vk,
stride_om, stride_ok,
M, N, D,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
# Each program handles one block of Q rows
pid = tl.program_id(0)
# Q block row offsets
offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_D)
# Load Q block
Q_block = tl.load(Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk)
# Initialize output accumulators
out_block = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
max_score = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
sum_exp = tl.zeros([BLOCK_M], dtype=tl.float32)
# Iterate over K/V blocks
for start_n in range(0, N, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
# Load K block
K_block = tl.load(K + offs_n[:, None] * stride_km + offs_d[None, :] * stride_kk)
# Compute attention scores: Q @ K.T
scores = tl.dot(Q_block, tl.trans(K_block)) # [BLOCK_M, BLOCK_N]
# Update running max
block_max = tl.max(scores, axis=1)
new_max = tl.maximum(max_score, block_max)
# Rescale old statistics to the new max
scale_old = tl.exp(max_score - new_max)
# Exponentiate scores relative to the new max
exp_scores = tl.exp(scores - new_max[:, None])
# Update softmax denominator (the new terms are already on the new-max scale)
sum_exp = sum_exp * scale_old + tl.sum(exp_scores, axis=1)
# Load V block
V_block = tl.load(V + offs_n[:, None] * stride_vm + offs_d[None, :] * stride_vk)
# Update output: rescale the old accumulator, then add the new weighted values
out_block = out_block * scale_old[:, None] + tl.dot(exp_scores, V_block)
max_score = new_max
# Normalize
out_block = out_block / sum_exp[:, None]
# Store output
tl.store(Out + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok, out_block)
This is the core of FlashAttention!
Key differences from standard attention:
- Never materializes the full O(n²) attention matrix
- Uses online softmax updates (running max and sum rescaling)
- Processes tiles that fit in shared memory
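A host-side launcher sketch for this simplified kernel. To stay faithful to the kernel (which loads without masks and omits the 1/√D scaling), assume a single head, float32 tensors, M and N divisible by the block sizes, and a power-of-2 head dimension D:
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    M, D = q.shape
    N, _ = k.shape
    out = torch.empty_like(q)
    BLOCK_M, BLOCK_N, BLOCK_D = 64, 64, D
    grid = (triton.cdiv(M, BLOCK_M),)  # one program per block of query rows
    flash_attention_kernel[grid](
        q, k, v, out,
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1),
        v.stride(0), v.stride(1),
        out.stride(0), out.stride(1),
        M, N, D,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D,
    )
    return out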
33.7 Matrix Multiply in Triton
Triton has a special tl.dot instruction that maps to tensor cores.
@triton.jit
def matmul_kernel(
A_ptr, B_ptr, C_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# Compute offsets
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# Initialize accumulator
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# Iterate over K dimension in blocks
for k in range(0, K, BLOCK_K):
# Load A block
a = tl.load(A_ptr + offs_m[:, None] * stride_am + (offs_k[None, :] + k) * stride_ak)
# Load B block
b = tl.load(B_ptr + (offs_k[:, None] + k) * stride_bk + offs_n[None, :] * stride_bn)
# Multiply-accumulate (uses tensor cores!)
acc += tl.dot(a, b)
# Store result
tl.store(C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc)
Performance: 80-90% of cuBLAS on A100!
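A launcher sketch for this kernel. Note the 2D grid, one program per output tile; since the kernel loads without masks, assume M, N, K are multiples of the block sizes:
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))  # 2D grid of C tiles
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c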
33.8 Optimization Techniques
33.8.1 1. Choose Block Sizes Wisely
Rules of thumb:
- Power of 2: always (16, 32, 64, 128, 256)
- Multiple of the warp size (32): always
- Tiles fit in shared memory: check the limits
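As a quick sanity check, here is a small host-side helper (ours, not part of Triton) that estimates the on-chip footprint of a matmul tile configuration; the worked numbers below apply the same arithmetic to an A100-style budget:
def tile_bytes(block_m: int, block_n: int, block_k: int,
               in_bytes: int = 2, acc_bytes: int = 4) -> int:
    """Rough per-tile footprint: A and B input tiles plus the output accumulator."""
    a_tile = block_m * block_k * in_bytes   # e.g. FP16 inputs
    b_tile = block_k * block_n * in_bytes
    c_tile = block_m * block_n * acc_bytes  # FP32 accumulator
    return a_tile + b_tile + c_tile

tile_bytes(128, 128, 32)  # 81920 bytes (~80 KB), within a 164 KB A100 budget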
# A100: 164 KB shared memory per SM
BLOCK_M = 128
BLOCK_N = 128
BLOCK_K = 32
# FP16 data:
# A tile: 128 × 32 × 2 = 8 KB
# B tile: 32 × 128 × 2 = 8 KB
# C tile: 128 × 128 × 4 = 64 KB (FP32 accumulator)
# Total: 80 KB ✓ Fits!
33.8.2 2. Use Compile-Time Constants
# BAD: runtime value -- tl.arange requires compile-time constants,
# and the compiler cannot specialize the kernel for it
def kernel(..., block_size):
offsets = tl.arange(0, block_size)
# GOOD: Compile-time constant
def kernel(..., BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE) # Fast!
Triton generates optimized code for each BLOCK_SIZE value.
33.8.3 3. Memory Coalescing
# BAD: Strided access
for i in range(N):
data = tl.load(ptr + i * large_stride)
# GOOD: Contiguous access
offsets = tl.arange(0, BLOCK)
data = tl.load(ptr + offsets)
33.8.4 4. Reduce Synchronization
Triton handles synchronization automatically for tl.sum, tl.max, etc. But minimize reductions:
# BAD: Multiple reductions
max_val = tl.max(x)
sum_val = tl.sum(x)
# BETTER: Combined when possible
# (depends on algorithm)
33.8.5 5. Profile with Nsight
# Generate profile
ncu --set full -o profile python script.py
# Look for:
# - SM efficiency (target: >80%)
# - Memory throughput (target: >70%)
# - Warp execution efficiency (target: >90%)
33.8.6 6. Benchmark with triton.testing
Triton includes a robust benchmarking module that follows GPU benchmarking best practices:
import triton
# Simple timing (returns a runtime in ms; the default aggregation has varied across Triton versions)
ms = triton.testing.do_bench(lambda: kernel[grid](a, b, c))
# With all statistics
ms = triton.testing.do_bench(
lambda: kernel[grid](a, b, c),
warmup=25, # Target warmup time in ms (not an iteration count)
rep=100, # Target measurement time in ms
return_mode='median', # 'min', 'max', 'mean', 'median', 'all'
)
For systematic benchmarking across problem sizes:
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N'],
x_vals=[(512, 512), (1024, 1024), (2048, 2048)],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'PyTorch'],
xlabel='Matrix Size',
ylabel='TFLOPS',
plot_name='matmul-performance',  # required; also names the saved plot
args={},                         # required; fixed kwargs forwarded to the benchmarked function
)
)
def benchmark_matmul(M, N, provider):
K = 512
a = torch.randn(M, K, device='cuda')
b = torch.randn(K, N, device='cuda')
if provider == 'triton':
fn = lambda: triton_matmul(a, b)
else:
fn = lambda: torch.mm(a, b)
ms = triton.testing.do_bench(fn)
flops = 2 * M * N * K
return flops / ms / 1e9 # TFLOPS
benchmark_matmul.run(print_data=True, save_path='.')  # save_path is a directory; the plot is named after plot_name
This automatically generates comparison plots. See Chapter 15: The Art of Measurement for a deeper discussion of GPU benchmarking methodology.
33.9 Debugging Triton Kernels
33.9.1 Print Debugging
@triton.jit
def kernel(...):
pid = tl.program_id(0)
# Debug print (only program 0)
if pid == 0:
tl.device_print("offsets:", offsets)33.9.2 Check for NaNs
result = tl.load(...)
tl.device_assert(tl.sum(tl.where(result != result, 1, 0)) == 0, "NaN detected!")
33.9.3 Visualize Generated Code
# Inspect the source Triton traces
kernel.src  # the Python source of the @triton.jit function
# After at least one launch, the compiled kernel caches its IR and assembly;
# on NVIDIA, its .asm dict holds entries such as 'ttir', 'llir', and 'ptx'
# (the exact attribute layout varies across Triton versions)
33.10 When to Use Triton
Use Triton when:
- You need fusion of operations (avoiding memory round-trips)
- You’re implementing a custom algorithm (novel attention variants)
- PyTorch native ops are too slow
- You want cross-GPU portability
Don’t use Triton when:
- PyTorch native ops are already fast enough
- cuBLAS/cuDNN solve your problem (they’re heavily optimized)
- You need the last 5-10% of performance (hand-tuned CUDA wins)
33.11 Performance Comparison
Softmax (1024 × 1024):
- PyTorch native: 0.12 ms
- Triton fused: 0.11 ms
- Speedup: 1.1× (minimal overhead)
Layer Norm (4096 × 4096):
- PyTorch native: 0.45 ms
- Triton fused: 0.18 ms
- Speedup: 2.5× (memory round-trip elimination)
FlashAttention (2048 seq, 64 dim):
- PyTorch native: 8.2 ms
- Triton FlashAttention: 1.1 ms
- Speedup: 7.5× (tiling + fusion)
Matrix Multiply (4096 × 4096 × 4096):
- cuBLAS: 3.2 ms
- Triton: 3.8 ms
- Slowdown: 1.2× (cuBLAS is better, use it!)
33.12 Common Pitfalls
33.12.1 1. Forgetting Masks
# BAD: No mask, reads beyond bounds
data = tl.load(ptr + offsets)
# GOOD: Mask for safety
mask = offsets < N
data = tl.load(ptr + offsets, mask=mask)
33.12.2 2. Wrong Stride Computation
# If tensor is [M, N] row-major:
# stride(0) = N (number of elements to next row)
# stride(1) = 1 (number of elements to next column)
# BAD
offset = row * 1 + col * N
# GOOD
offset = row * stride_m + col * stride_n
33.12.3 3. Block Size Mismatch
# Grid computed assuming BLOCK_SIZE = 1024...
grid = (triton.cdiv(N, 1024),)
# ...but kernel launched with BLOCK_SIZE = 512: half the elements are never processed
kernel[grid](..., BLOCK_SIZE=512)
Derive the grid and the BLOCK_SIZE launch argument (a tl.constexpr) from the same constant so the two cannot drift apart.
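A minimal sketch of the safe pattern: one constant drives both the grid and the launch argument, and an optional static assert inside the kernel rejects bad block shapes at compile time:
@triton.jit
def scale_kernel(x_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    tl.static_assert(BLOCK_SIZE % 32 == 0)  # must be a multiple of the warp size
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)

def scale(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    BLOCK_SIZE = 1024                                 # single source of truth
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)      # grid derived from it...
    scale_kernel[grid](x, out, x.numel(), BLOCK_SIZE=BLOCK_SIZE)  # ...and passed to the kernel
    return out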
33.13 Triton 3.0 and Beyond
33.13.1 Multi-Backend Support
Triton now supports multiple GPU backends, not just NVIDIA:
# AMD GPUs: use a Triton build with the ROCm/HIP backend (bundled with ROCm builds of PyTorch)
# Intel GPUs: use the XPU backend (intel-xpu-backend-for-triton)
# The backend is chosen by the device your input tensors live on -- the kernel itself does not change
@triton.jit
def my_kernel(...):
# Identical code on every backend
pass
Supported backends (Triton 3.x):
Backend Status Notes
─────────────────────────────────────────────────
NVIDIA (CUDA) Stable Primary, best optimized
AMD (ROCm) Stable MI300X, MI250X support
Intel (XPU) Experimental Max/Flex GPUs
CPU Experimental For debugging
33.13.2 Autotuning Improvements
Triton 3.0 has better autotuning infrastructure:
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': matmul_config_pruner, # Custom pruning
'perf_model': estimate_matmul_time, # Performance model
},
warmup=25,
rep=100,
)
@triton.jit
def matmul_kernel(...):
...
New autotuning features:
- num_stages: pipeline depth (more overlap between memory and compute)
- num_warps: warps per program (occupancy tuning)
- Custom pruning to skip obviously bad configs
- Performance models to guide the search
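One practical difference from the plain launcher in 33.7 (a sketch, assuming the kernel keeps the earlier matmul signature): you no longer pass the tuned meta-parameters yourself. The autotuner picks a config per (M, N, K) key, and the grid is written as a function of the chosen config:
def matmul_autotuned(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    # Grid is a callable so it can read the META chosen by the autotuner
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        # no BLOCK_M / BLOCK_N / BLOCK_K here -- the autotuner supplies them
    )
    return c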
33.13.3 Descriptor-Based Loads (Experimental)
Triton is adding TMA-like descriptor loads for Hopper:
@triton.jit
def kernel_with_descriptors(
A_ptr, B_ptr, C_ptr,
...
# Descriptor enables hardware-managed (TMA) loads
# (the type and call names below are illustrative; the descriptor API is still experimental)
A_desc: tl.tensor_descriptor,
):
# Hardware manages the load (illustrative call)
a_tile = tl.load_tile(A_desc, [block_m_idx, block_k_idx])
# vs. traditional pointer-based
a_ptr_offset = A_ptr + offsets_m[:, None] * stride_am + offsets_k[None, :] * stride_ak
a_tile_ptr = tl.load(a_ptr_offset)
Benefits:
- Hardware handles address calculation
- Better memory access patterns
- Reduced register pressure
33.13.4 Persistent Kernels
For small operations, kernel launch overhead dominates. Persistent kernels help:
@triton.jit
def persistent_kernel(
A_ptr, B_ptr, C_ptr, D_ptr,
num_tasks,
num_tiles_n,  # tiles per row of the output grid (used to decode task_id below)
BLOCK_SIZE: tl.constexpr,
):
# Get program ID
pid = tl.program_id(0)
# Persistent: each program processes multiple tiles
num_programs = tl.num_programs(0)
for task_id in range(pid, num_tasks, num_programs):
# Compute which tile this is
tile_m = task_id // num_tiles_n
tile_n = task_id % num_tiles_n
# Process tile
offs_m = tile_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = tile_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# Load, compute, store...
Use case: Many small operations (MoE expert computations, sparse ops)
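A host-side launch sketch (helper name ours): size the grid to roughly the number of SMs so each program stays resident and loops over many tiles instead of paying one launch per tile:
def launch_persistent(A, B, C, D, num_tiles_m: int, num_tiles_n: int, BLOCK_SIZE: int = 128):
    num_tasks = num_tiles_m * num_tiles_n
    num_sms = torch.cuda.get_device_properties(A.device).multi_processor_count
    grid = (min(num_sms, num_tasks),)  # at most one resident program per SM
    persistent_kernel[grid](
        A, B, C, D,
        num_tasks, num_tiles_n,
        BLOCK_SIZE=BLOCK_SIZE,
    )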
33.13.5 Integration with torch.compile
Triton kernels now integrate seamlessly with torch.compile:
import torch
import triton
import triton.language as tl
@triton.jit
def my_triton_kernel(...):
...
# Register a wrapper that launches the Triton kernel as a custom op
# (in recent PyTorch versions, torch.library.custom_op is used as a decorator)
@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
...  # allocate the output, launch my_triton_kernel, return the result
# torch.compile will use it
@torch.compile
def my_model(x):
return torch.ops.mylib.my_op(x)
# Or use Triton indirectly via Inductor:
# torch.compile generates Triton for fused ops automatically!
TorchInductor generates Triton:
# This PyTorch code...
y = torch.relu(x @ weight + bias)
# ...becomes this Triton (automatically!)
@triton.jit
def fused_matmul_bias_relu(...):
...
33.13.6 Triton Distributed (Experimental)
For multi-GPU kernels with overlapped communication:
# NOTE: illustrative sketch -- the distributed API is experimental and its exact names are still in flux
from triton.distributed import (
DistributedTensor,
all_gather,
reduce_scatter,
)
@triton.jit
def distributed_matmul_kernel(
A: DistributedTensor, # Sharded across GPUs
B: DistributedTensor,
C: DistributedTensor,
...
):
# All-gather A across GPUs
A_full = all_gather(A, dim=0)
# Local matmul
C_local = tl.dot(A_full, B)
# Reduce-scatter result
reduce_scatter(C_local, C, dim=0)
Use case: Custom distributed ops that overlap compute and communication.
33.13.7 Best Practices for Triton 3.0
# 1. Use autotuning for performance-critical kernels
@triton.autotune(configs=[...], key=['M', 'N'])
@triton.jit
def kernel(...):
...
# 2. Specify precision explicitly
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # FP32 accumulator
result = acc.to(tl.float16) # Convert at end
# 3. Use tl.dot for matrix ops (uses tensor cores)
c = tl.dot(a, b) # Good
c = a @ b # operator form; support varies by Triton version -- prefer the explicit tl.dot
# 4. Profile with triton.testing
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N'],
x_vals=[128 * i for i in range(2, 33)],
...
)
)
def benchmark(M, N):
...
33.14 Connections
Chapter 4 (GPU Architecture): Triton abstracts GPU details but understanding grids/blocks/shared memory helps.
Chapter 8 (Fusion): Triton enables fusion by writing multi-op kernels.
Chapter 10 (FlashAttention): Now you can implement it!
Chapter 19 (Profiling): Use Nsight to verify Triton kernel performance.
33.15 Key Takeaways
Triton makes GPU programming accessible: Python-like syntax, automatic optimization.
Think in blocks: Load a block, compute, store—not individual threads.
Fusion is the main benefit: Eliminate memory round-trips for multi-op sequences.
Use compile-time constants: tl.constexpr enables optimization.
Multi-backend support: Same kernel runs on NVIDIA, AMD, and Intel GPUs.
Autotuning is essential: Use @triton.autotune for production kernels.
torch.compile generates Triton: TorchInductor uses Triton under the hood.
Profile before claiming victory: Verify performance with Nsight.
Don’t replace optimized libraries: cuBLAS and cuDNN are still king for basic ops.
A suggested learning path:
- Week 1: Vector operations (add, mul, ReLU)
- Week 2: Reductions (softmax, layer norm)
- Week 3: Matrix multiply basics
- Week 4: Fused operations (bias + activation, etc.)
- Week 5: FlashAttention implementation
By week 5, you’ll be writing production-quality custom kernels.
33.16 Further Reading
- Triton Documentation
- OpenAI Triton Introduction
- Triton Tutorials
- Triton GitHub (backends, examples)
- FlashAttention Triton Implementation
- Triton Distributed Paper - Multi-GPU kernel programming