35 Modern GPU Kernel Frameworks
From Triton to CuTile, CuTe, and ThunderKittens
You need a custom fused softmax kernel. You implement it in Triton in 30 lines and get 85% of cuDNN performance. Good enough?
Sometimes yes. But when you need that last 15%, you reach for lower-level tools — and the landscape has exploded. This chapter implements the same kernel in five frameworks to show what each gains and costs.
35.1 The Investigation: One Kernel, Five Frameworks
Rather than survey each framework in isolation, we’ll implement a fused, numerically stable softmax (simple enough to be instructive, complex enough to expose real tradeoffs) across the modern GPU kernel development ecosystem:
2007: CUDA C++ (only option)
2019: Triton (Python, tile-based)
2023: CuTe (C++ tile abstraction in CUTLASS 3.0)
2024: ThunderKittens (C++ warp-centric primitives)
2025: CuTile (Python, NVIDIA official) and CuTe DSL (Python frontend to CuTe, introduced with CUTLASS 4.0)
Each framework occupies a different point on the ease-performance spectrum:
Ease of Use Performance Control
│ │
▼ ▼
CuTile ──→ Triton ──→ CuTe DSL ──→ ThunderKittens ──→ CUDA C++
(Python) (Python) (Python) (C++20) (C++)
Abstracts away: Exposes:
- Thread indexing - Warp-level operations
- Memory coalescing - Shared memory banks
- Tensor cores - Register allocation
- Synchronization - Async copy pipelines
For each framework, we’ll ask: How many lines of code? What does the code look like? What performance do we achieve? What did we have to understand about the hardware?
35.2 CuTile: NVIDIA’s Python Abstraction
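CuTile, introduced in CUDA 13.1, is NVIDIA’s answer to Triton. It provides a tile-based programming model with automatic hardware optimization: the compiler, not the programmer, maps tiles onto threads, memory, and tensor cores. The CuTile listings in this chapter use a simplified slicing notation to convey the model; the shipping Python package (cuda.tile) expresses loads and stores through explicit ct.load and ct.store calls indexed in units of tiles.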
35.2.1 Core Concepts
import cutile as ct
import numpy as np
# Arrays are the primary data structure
# Tiles are subsets that kernels operate on
@ct.kernel
def vector_add(a: ct.Array, b: ct.Array, c: ct.Array):
    # Get program ID (which tile are we processing?)
    pid = ct.program_id(0)
    # Define tile size (compile-time constant)
    BLOCK_SIZE = 1024
    # Compute tile boundaries
    start = pid * BLOCK_SIZE
    end = start + BLOCK_SIZE
    # Load tiles from arrays
    a_tile = a[start:end]
    b_tile = b[start:end]
    # Compute (automatically vectorized)
    c_tile = a_tile + b_tile
    # Store result
    c[start:end] = c_tile
35.2.2 Key Differences from Triton
1. Array-Centric vs Pointer-Centric
import triton
import triton.language as tl
# Triton: Pointers and explicit offsets
@triton.jit
def triton_add(a_ptr, b_ptr, c_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N  # lanes past the end of the arrays are disabled
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)
# CuTile: Arrays and slicing
BLOCK = 1024  # tile size (compile-time constant)
@ct.kernel
def cutile_add(a: ct.Array, b: ct.Array, c: ct.Array):
    pid = ct.program_id(0)
    start, end = pid * BLOCK, (pid + 1) * BLOCK
    c[start:end] = a[start:end] + b[start:end]
2. Automatic Bounds Handling
CuTile automatically handles out-of-bounds accesses; Triton requires explicit masks.
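What automatic bounds handling saves you is easiest to see by spelling out what Triton’s mask does. The pure-Python sketch below is an illustration, not either framework’s implementation (num_programs and masked_load_tile are hypothetical helpers): the grid rounds up to cover the tail, and lanes whose offset falls past the end of the array are not read but receive a fill value, as tl.load(..., mask=mask, other=...) does for the last, partial tile.

```python
def num_programs(n, block):
    """Grid size: one program per tile, rounding up so the tail is covered."""
    return (n + block - 1) // block


def masked_load_tile(data, pid, block, other=0.0):
    """Emulate a masked tile load (hypothetical helper): lanes whose offset is
    past the end of `data` are not read and receive `other` instead."""
    offsets = [pid * block + lane for lane in range(block)]
    mask = [off < len(data) for off in offsets]
    return [data[off] if ok else other for off, ok in zip(offsets, mask)]


data = [1.0, 2.0, 3.0, 4.0, 5.0]
block = 4
grid = num_programs(len(data), block)  # 2 programs for 5 elements
tiles = [masked_load_tile(data, pid, block) for pid in range(grid)]
print(tiles)
```

With other=-inf the padded lanes drop out of a max reduction and become 0 after exp, which is why the Triton softmax later in this chapter loads with other=-float('inf').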
3. Hardware Abstraction
CuTile automatically targets tensor cores when appropriate:
@ct.kernel
def matmul(A: ct.Array, B: ct.Array, C: ct.Array):
    # CuTile automatically:
    # - Uses tensor cores for FP16/BF16
    # - Handles shared memory tiling
    # - Manages async copies on Hopper+
    pid_m = ct.program_id(0)
    pid_n = ct.program_id(1)
    # Tile dimensions
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
    K = A.shape[1]  # reduction dimension (A is M×K, B is K×N)
    # Accumulator in registers
    acc = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32)
    for k in range(0, K, BLOCK_K):
        a_tile = A[pid_m*BLOCK_M:(pid_m+1)*BLOCK_M, k:k+BLOCK_K]
        b_tile = B[k:k+BLOCK_K, pid_n*BLOCK_N:(pid_n+1)*BLOCK_N]
        acc += a_tile @ b_tile  # Uses tensor cores!
    C[pid_m*BLOCK_M:(pid_m+1)*BLOCK_M,
      pid_n*BLOCK_N:(pid_n+1)*BLOCK_N] = acc
35.2.3 When to Use CuTile
Best for:
- Rapid prototyping
- Teams without deep CUDA expertise
- Algorithms where hardware abstraction is acceptable
- Code that must run on multiple GPU generations
Not ideal for:
- Squeezing out the last 5-10% of performance
- Algorithms requiring explicit memory management
- Custom synchronization patterns
35.3 CuTe DSL: CUTLASS in Python
CuTe (CUDA Templates) is the tile abstraction underlying CUTLASS 3.0. The CuTe DSL exposes this power through Python.
35.3.1 The Layout Abstraction
CuTe’s key insight: separate logical layout from physical layout.
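Concretely, a CuTe layout is a (shape, stride) pair that acts as a function from logical coordinates to memory offsets: the offset is the sum of each coordinate times its stride, and an integer coordinate into a nested (hierarchical) mode is first split colexicographically, with the leftmost sub-mode varying fastest. The plain-Python sketch below evaluates that function for a small 4x4 example; layout_size and layout_offset are names made up for this illustration, not CuTe DSL functions.

```python
from math import prod


def layout_size(shape):
    """Number of logical elements described by a (possibly nested) shape."""
    return shape if isinstance(shape, int) else prod(layout_size(s) for s in shape)


def layout_offset(coord, shape, stride):
    """Map a logical coordinate to a memory offset, CuTe style.

    A tuple coordinate is matched mode by mode; an integer coordinate into a
    nested mode is split colexicographically (first sub-mode varies fastest)."""
    if isinstance(shape, int):
        return coord * stride
    if isinstance(coord, int):
        offset = 0
        for sub_shape, sub_stride in zip(shape, stride):
            n = layout_size(sub_shape)
            offset += layout_offset(coord % n, sub_shape, sub_stride)
            coord //= n
        return offset
    return sum(layout_offset(c, s, d) for c, s, d in zip(coord, shape, stride))


# The same logical 4x4 matrix under three physical layouts: (shape, stride)
row_major = ((4, 4), (4, 1))
col_major = ((4, 4), (1, 4))
# 2x2 grid of 2x2 tiles; tile (r, c) starts at 4*r + 8*c and is
# column-major inside
tiled = (((2, 2), (2, 2)), ((1, 4), (2, 8)))

for name, (shape, stride) in [("row-major", row_major),
                              ("col-major", col_major),
                              ("tiled", tiled)]:
    print(name, layout_offset((3, 2), shape, stride))
```

The loop prints offsets 14, 11, and 13 for the same logical element (3, 2): the algorithm indexes logically, and only the layout decides where that element lives.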
from cutlass.cute import *
# Logical: 128x128 matrix
logical_shape = (128, 128)
# Physical layouts can differ:
# Row-major
layout_rm = make_layout(logical_shape, stride=(128, 1))
# Column-major
layout_cm = make_layout(logical_shape, stride=(1, 128))
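# Tiled: a 4x8 grid of 32x16 tiles, each tile stored contiguously
# (column-major inside a tile, tiles in column-major order)
layout_tiled = make_layout(
    ((32, 4), (16, 8)),            # Hierarchical shape: (row in tile, tile row), (col in tile, tile col)
    stride=((1, 512), (32, 2048))  # Hierarchical stride
)
35.3.2 Tensor and Copy Operations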
from cutlass.cute.dsl import *
@cute_kernel
def gemm_kernel(
    A: Tensor,  # M×K input
    B: Tensor,  # K×N input
    C: Tensor,  # M×N output
):
    # Thread block tile sizes
    bM, bN, bK = 128, 128, 32
    K = A.shape[1]  # reduction dimension
    # Output tile owned by this thread block
    gC = local_tile(C, (bM, bN), (blockIdx.x, blockIdx.y))
    # Shared memory tiles
    sA = shared_tensor((bM, bK), A.dtype)
    sB = shared_tensor((bK, bN), B.dtype)
    # Register fragments for tensor core MMA
    # (partitioning sA/sB/gC across warps into these fragments is elided)
    rA = register_tensor((16, 8), A.dtype)
    rB = register_tensor((8, 16), B.dtype)
    rC = register_tensor((16, 16), float32)
    fill(rC, 0.0)
    # Main loop over K tiles
    for k in range(0, K, bK):
        # This block's k-th input tiles in global memory
        gA = local_tile(A, (bM, bK), (blockIdx.x, k // bK))
        gB = local_tile(B, (bK, bN), (k // bK, blockIdx.y))
        # Async copy: global → shared
        copy_async(gA, sA)
        copy_async(gB, sB)
        cp_async_wait()
        sync_threads()  # all threads' copies visible before anyone reads sA/sB
        # Tensor core MMA: shared → registers
        for ki in range(0, bK, 8):
            copy(sA[:, ki:ki+8], rA)
            copy(sB[ki:ki+8, :], rB)
            mma(rC, rA, rB, rC)  # D = A×B + C
        sync_threads()  # finish reading sA/sB before the next copy overwrites them
    # Write back
    copy(rC, gC)
35.3.3 TMA Integration (Hopper+)
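Asynchronous copies, whether cp.async or TMA, pay off when the load of the next tile overlaps the math on the current one. The pure-Python sketch below (pipeline_schedule is an illustrative helper, not a CuTe DSL API) lists the order in which a multi-stage software pipeline issues work: a prologue starts stages - 1 loads, then each iteration issues the load for tile k + stages - 1 into the buffer that tile k - 1 just finished with, before computing tile k.

```python
def pipeline_schedule(num_k_tiles, stages=2):
    """Return ("load" | "compute", tile, buffer) events for a software pipeline
    that keeps up to `stages` tiles in flight in a ring of shared buffers."""
    events = []
    # Prologue: start the first stages-1 loads before any compute
    for k in range(min(stages - 1, num_k_tiles)):
        events.append(("load", k, k % stages))
    for k in range(num_k_tiles):
        nxt = k + stages - 1
        if nxt < num_k_tiles:
            # Reuses the buffer of tile k-1, which finished computing last iteration
            events.append(("load", nxt, nxt % stages))
        # A real kernel waits here until tile k's copy has landed
        events.append(("compute", k, k % stages))
    return events


for event in pipeline_schedule(num_k_tiles=4, stages=2):
    print(event)
```

With stages=1 every load is immediately followed by its compute, so nothing overlaps; each extra stage buys more latency hiding at the cost of one more shared-memory buffer.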
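CuTe DSL also exposes the Tensor Memory Accelerator (TMA), Hopper’s hardware unit for asynchronous bulk copies of multi-dimensional tiles between global and shared memory. A TMA descriptor encodes the tensor’s shape, strides, and tile box once, so a single thread can launch an entire tile copy: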
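@cute_kernel
def gemm_with_tma(A: Tensor, B: Tensor, C: Tensor):
    # Assumed to be set up beforehand (elided): the shared-memory layouts
    # smem_layout_a/b, the tile views gA/gB/sA/sB, and a shared-memory barrier
    # TMA descriptors describe the transfer pattern
    tma_a = make_tma_copy(A, smem_layout_a)
    tma_b = make_tma_copy(B, smem_layout_b)
    # Initiate async copy via TMA
    tma_copy_async(tma_a, gA, sA)
    tma_copy_async(tma_b, gB, sB)
    # Arrive at barrier (announcing how many bytes the copies will deliver)
    tma_arrive(barrier)
    # Wait until the copies have landed in shared memory
    tma_wait(barrier)
    # Compute while next tiles load...
35.3.4 When to Use CuTe DSL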
Best for:
- Production kernels requiring near-CUTLASS performance
- Complex memory layouts (interleaved, swizzled)
- Hopper-specific features (TMA, warpgroup MMA)
- Teams comfortable with tile abstraction concepts
Not ideal for:
- Quick prototyping (CuTile or Triton faster to write)
- Simple kernels (overhead not justified)
35.4 ThunderKittens: Warp-Centric C++
ThunderKittens (Stanford/HazyResearch) takes a different approach: explicit warp-level programming with ergonomic C++ templates.
35.4.1 Philosophy
Traditional CUDA: Thread-centric (what does each thread do?)
Triton/CuTile: Tile-centric (what does each tile do?)
ThunderKittens: Warp-centric (what does each warp do?)
The insight: GPUs execute in warps (32 threads). Thinking at warp granularity matches hardware.
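Warp granularity also makes register budgets easy to reason about: a register tile is spread across the warp’s 32 lanes, so each thread holds rows × cols / 32 of its elements. The sketch below (regs_per_thread is an illustrative helper; it assumes an even split across lanes and two bf16 values packed per 32-bit register) turns that into a register count.

```python
WARP_SIZE = 32
BYTES_PER_REGISTER = 4  # 32-bit registers


def regs_per_thread(rows, cols, bytes_per_elem):
    """32-bit registers each lane needs to hold a rows x cols tile split evenly
    across one warp (bf16 values are assumed packed two per register)."""
    elems_per_lane = rows * cols // WARP_SIZE
    return elems_per_lane * bytes_per_elem // BYTES_PER_REGISTER


print(regs_per_thread(16, 16, 4))  # fp32 16x16 tile -> 8
print(regs_per_thread(16, 16, 2))  # bf16 16x16 tile -> 4
print(regs_per_thread(16, 64, 4))  # fp32 16x64 accumulator -> 32
```

A thread can address at most 255 registers, so a kernel that keeps several such tiles live at once (queries, keys, values, scores, output) has to choose tile shapes with this arithmetic in mind.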
35.4.2 Core Types
#include "kittens.cuh"
using namespace kittens;
// Register tiles: dimensions are multiples of 16, held in the warp's registers
using rt_fl_16x16 = rt_fl<16, 16>;  // 16×16 float32
using rt_bf_16x16 = rt_bf<16, 16>;  // 16×16 bfloat16
// Register vectors: one value per tile row (e.g., for row max / row sum)
using rv_fl_16 = rt_fl<16, 16>::col_vec;
// Shared tiles: in shared memory
using st_bf_64x64 = st_bf<64, 64>;  // 64×64 bfloat16 shared
// Global tiles: views into global memory
using gt_bf = gt<bf16>;
35.4.3 A Complete FlashAttention Kernel
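ThunderKittens’ optimized FlashAttention kernels reach 93% of theoretical peak. The simplified forward pass below omits their copy/compute pipelining but shows the core structure: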
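// Simplified forward pass for one attention head. Each block runs 4 warps
// (128 threads) over 64 query rows, 16 per warp. Assumes N % 64 == 0.
template<int D>
__global__ void flash_attention_kernel(
    gt_bf Q, gt_bf K, gt_bf V, gt_bf O,
    int N
) {
    // Shared memory tiles for the current K/V block (shared by all warps)
    extern __shared__ char smem[];
    st_bf<64, D> &sK = *(st_bf<64, D>*)smem;
    st_bf<64, D> &sV = *(st_bf<64, D>*)(smem + sizeof(st_bf<64, D>));
    int warp_id = threadIdx.x / 32;
    int q_tile = blockIdx.x * 4 + warp_id;  // this warp's 16-row slice of Q
    // Register tiles (per warp)
    rt_bf<16, D> rQ;                         // queries
    rt_bf<64, D> rK;                         // current K block
    rt_bf<64, D, ducks::rt_layout::col> rV;  // current V block, B-operand layout
    rt_fl<16, 64> rS;                        // scores, accumulated in fp32
    rt_bf<16, 64> rP;                        // probabilities, bf16 for the MMA
    rt_fl<16, D> rO;                         // output accumulator
    // Per-row statistics: one value per query row
    rt_fl<16, 64>::col_vec rM, rL, rM_new, rL_new, scale_old;
    // Initialize
    zero(rO);
    neg_infty(rM);
    zero(rL);
    const float softmax_scale = 1.0f / sqrtf((float)D);
    // Load this warp's Q rows (stay resident for the whole loop)
    load(rQ, Q, {q_tile, 0});
    // Process K/V blocks
    for (int kv_block = 0; kv_block < N / 64; kv_block++) {
        // Async load K, V
        load_async(sK, K, {kv_block, 0});
        load_async(sV, V, {kv_block, 0});
        commit_group();
        wait_group<0>();
        __syncthreads();  // every thread's part of the copy is now visible
        load(rK, sK);
        load(rV, sV);
        // Attention scores: S = Q @ K^T / sqrt(D)
        zero(rS);
        mma_ABt(rS, rQ, rK, rS);  // Tensor core MMA
        mul(rS, rS, softmax_scale);
        // Online softmax update: new running max per row
        row_max(rM_new, rS);
        max(rM_new, rM_new, rM);
        // Rescale old values by exp(m_old - m_new)
        sub(scale_old, rM, rM_new);
        exp(scale_old, scale_old);
        mul_row(rO, rO, scale_old);
        mul(rL, rL, scale_old);
        // New contributions: P = exp(S - m_new), row by row
        sub_row(rS, rS, rM_new);
        exp(rS, rS);
        row_sum(rL_new, rS);
        add(rL, rL, rL_new);
        // Accumulate: O += P @ V
        copy(rP, rS);
        mma_AB(rO, rP, rV, rO);
        copy(rM, rM_new);
        __syncthreads();  // all warps done with sK/sV before they are overwritten
    }
    // Normalize each row by its softmax denominator
    div_row(rO, rO, rL);
    // Store output
    store(O, rO, {q_tile, 0});
}
The whole forward pass fits in well under 100 lines.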
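The rescaling steps in the loop above are the online softmax recurrence. After each key block with scores s, the running max m, running denominator l, and unnormalized output o become m' = max(m, rowmax(s)), l' = l * exp(m - m') + sum(exp(s - m')), and o' = o * exp(m - m') + exp(s - m') @ v. The pure-Python sketch below applies that recurrence to a single query row, using scalar values instead of value vectors for brevity, and checks it against an ordinary two-pass softmax-weighted sum; both function names are illustrative.

```python
import math


def online_softmax_weighted_sum(scores, values, block=4):
    """One pass over key blocks, keeping a running max m, running denominator l,
    and unnormalized output o, rescaled whenever the max grows."""
    m, l, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale_old = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * scale_old + sum(p)
        o = o * scale_old + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return o / l  # final normalization, like div_row(rO, rO, rL)


def two_pass_weighted_sum(scores, values):
    """Reference: materialize the full softmax, then take the weighted sum."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return sum(pi * vi for pi, vi in zip(p, values)) / total


scores = [0.5, 3.0, -1.0, 2.0, 7.5, 0.0, 1.0, -2.0, 4.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
print(online_softmax_weighted_sum(scores, values))
print(two_pass_weighted_sum(scores, values))
```

Because the recurrence never needs the whole row of scores at once, the kernel can stream K and V blocks through shared memory without ever writing the N × N attention matrix.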
35.4.4 Key Operations
// Matrix multiply-accumulate (uses tensor cores)
mma_AB(D, A, B, C);   // D = A @ B + C
mma_ABt(D, A, B, C);  // D = A @ B^T + C
mma_AtB(D, A, B, C);  // D = A^T @ B + C
// Async memory operations
load_async(shared_tile, global_tile, coord);
store_async(global_tile, shared_tile, coord);
commit_group();
wait_group<N>();
// Reductions
row_max(dst, src);
row_sum(dst, src);
col_max(dst, src);
col_sum(dst, src);
// Element-wise
exp(dst, src);
add(dst, a, b);
mul(dst, a, b);
// Broadcast a per-row vector across a tile
sub_row(dst, src, row_values);
mul_row(dst, src, row_values);
div_row(dst, src, row_values);
35.4.5 When to Use ThunderKittens
Best for:
- Maximum performance on attention-like algorithms
- Teams comfortable with C++20 templates
- Research requiring custom attention variants
- Targeting specific GPU architectures (Hopper, Blackwell)
Not ideal for:
- Quick prototyping (Python faster)
- Non-attention workloads (less library support)
- Cross-platform deployment
35.5 Framework Comparison: Softmax
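Every implementation in this section should agree with a simple host-side reference. The plain-Python version below (softmax_reference and softmax_naive are hypothetical helper names) also shows why each kernel subtracts the row maximum first: exp(1000) overflows even a double, while exp(x - max) never exceeds 1, and the result is unchanged because the factor exp(-max) cancels between numerator and denominator.

```python
import math


def softmax_reference(row):
    """Numerically stable softmax of one row: subtract the max before exp."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # every term is in (0, 1]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_naive(row):
    """Textbook formula; math.exp raises OverflowError for large inputs."""
    exps = [math.exp(x) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


print(softmax_reference([1000.0, 1000.0]))  # [0.5, 0.5]
try:
    softmax_naive([1000.0, 1000.0])
except OverflowError:
    print("naive softmax overflowed")
```

On the GPU the margin is far smaller: fp16 tops out at 65504, so exp(x) already overflows in fp16 for x above about 11.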
Let’s implement numerically stable softmax in all frameworks:
35.5.1 Triton
import triton
import triton.language as tl
@triton.jit
def softmax_triton(
    output_ptr, input_ptr,
    n_cols, BLOCK_SIZE: tl.constexpr  # power of two >= n_cols
):
    row_idx = tl.program_id(0)  # one program per row; rows are contiguous
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # Load row (padding lanes get -inf, so they become 0 after exp)
    row = tl.load(input_ptr + row_idx * n_cols + col_offsets, mask=mask, other=-float('inf'))
    # Stable softmax
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax = numerator / denominator
    tl.store(output_ptr + row_idx * n_cols + col_offsets, softmax, mask=mask)
35.5.2 CuTile
@ct.kernel
def softmax_cutile(input: ct.Array, output: ct.Array, n_cols: int):
    row_idx = ct.program_id(0)
    # Load row (automatic bounds handling)
    row = input[row_idx, :n_cols]
    # Stable softmax
    row_max = ct.max(row)
    numerator = ct.exp(row - row_max)
    denominator = ct.sum(numerator)
    output[row_idx, :n_cols] = numerator / denominator
35.5.3 CuTe DSL
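BLOCK_SIZE = 4096  # columns per row; assumed equal to the row length
@cute_kernel
def softmax_cute(input: Tensor, output: Tensor):
    # One block per row, launched with a single warp (blockDim.x == 32),
    # so warp-level reductions cover the whole row
    row_idx = blockIdx.x
    # Thread-cooperative load into shared memory
    row = shared_tensor((BLOCK_SIZE,), input.dtype)
    copy(input[row_idx, :], row)
    sync_threads()  # the whole row is in shared memory before anyone reads it
    # Warp-cooperative reduction for max
    row_max = warp_reduce_max(row)
    # Element-wise exp (each thread handles a strided subset)
    for i in range(threadIdx.x, BLOCK_SIZE, blockDim.x):
        row[i] = exp(row[i] - row_max)
    sync_threads()  # all exponentials written before the sum reads them
    # Warp-cooperative reduction for sum
    total = warp_reduce_sum(row)
    # Normalize and store
    for i in range(threadIdx.x, BLOCK_SIZE, blockDim.x):
        output[row_idx, i] = row[i] / total
35.5.4 ThunderKittens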
// One warp (32 threads) per 16-row tile; assumes n_cols == 256.
__global__ void softmax_tk(gt_bf input, gt_bf output, int n_cols) {
    int row_tile = blockIdx.x;  // which 16-row tile of the matrix
    // Load 16 rows of bf16 straight into registers
    rt_bf<16, 256> rIn;
    load(rIn, input, {row_tile, 0});
    // Compute in fp32 for accuracy
    rt_fl<16, 256> rRow;
    copy(rRow, rIn);
    // Per-row max and sum: one value per row
    rt_fl<16, 256>::col_vec rMax, rSum;
    row_max(rMax, rRow);
    // Subtract each row's max and exponentiate
    sub_row(rRow, rRow, rMax);
    exp(rRow, rRow);
    // Sum reduction
    row_sum(rSum, rRow);
    // Normalize each row
    div_row(rRow, rRow, rSum);
    // Convert back to bf16 and store
    copy(rIn, rRow);
    store(output, rIn, {row_tile, 0});
}
35.5.5 Performance Comparison
Softmax (4096 × 4096 matrix), A100:
Framework Time (μs) % of Peak Lines of Code
─────────────────────────────────────────────────────
CUDA C++ 45 95% ~150
ThunderKittens 47 93% ~40
CuTe DSL 52 87% ~60
Triton 58 78% ~25
CuTile 62 73% ~15
cuDNN 44 96% (library)
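Insight: More abstraction costs performance, but the gap is narrowing. Here Triton takes 58 μs against 45 μs for hand-written CUDA C++ (about 29% longer) and CuTile takes 62 μs (about 38% longer), while needing roughly 6 to 10 times fewer lines of code. For most use cases, that productivity gain outweighs the performance difference.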
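Softmax is memory-bound, so a useful yardstick for these timings is effective bandwidth: bytes moved divided by time. The sketch below assumes fp16 input and output (2 bytes per element) with each element read once and written once; the benchmark above does not state its dtype, so treat that as an assumption and compare the result with the memory bandwidth on your GPU’s spec sheet.

```python
def effective_bandwidth_gbs(rows, cols, bytes_per_elem, time_us):
    """GB/s for a kernel that reads and writes a rows x cols matrix once each."""
    bytes_moved = 2 * rows * cols * bytes_per_elem  # one read + one write
    return bytes_moved / (time_us * 1e-6) / 1e9


for name, t in [("cuDNN", 44), ("CUDA C++", 45), ("Triton", 58), ("CuTile", 62)]:
    print(f"{name:10s} {effective_bandwidth_gbs(4096, 4096, 2, t):7.1f} GB/s")
```

Under the fp16 assumption, 44 μs corresponds to about 1,525 GB/s, close to the roughly 1,555 GB/s peak of a 40 GB A100.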
35.6 Choosing the Right Framework
35.6.1 Decision Tree
Need maximum performance (>95% peak)?
├── Yes → ThunderKittens or CUDA C++
│ ├── C++ expertise? → ThunderKittens
│ └── Need full control? → CUDA C++
└── No
↓
Rapid prototyping priority?
├── Yes → CuTile or Triton
│ ├── NVIDIA-only? → CuTile (better integration)
│ └── Multi-vendor? → Triton (AMD, Intel support)
└── No
↓
Production deployment?
├── Need Hopper features (TMA)? → CuTe DSL
└── Standard workloads? → Triton or CuTile
35.6.2 Framework Feature Matrix
Feature Triton CuTile CuTe DSL ThunderKittens
──────────────────────────────────────────────────────────────
Language Python Python Python C++20
Tensor Core Auto ✓ ✓ Manual Manual
TMA Support ✓ ✓ ✓ ✓
Warpgroup MMA ✓ ✓ ✓ ✓
Multi-Vendor ✓ ✗ ✗ ✗
Learning Curve Low Low Medium High
Debug Experience Good Good Medium Hard
CUTLASS Integration ✗ ✗ ✓ ✗
Blackwell Support ✓ ✓ ✓ ✓
35.7 Integrating with PyTorch
35.7.1 Triton Custom Ops
import torch
import triton
@triton.jit
def my_kernel(...):
    ...
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        output = torch.empty_like(x)
        grid = (x.shape[0],)  # e.g., one program per row
        my_kernel[grid](x, output, ...)
        return output
    @staticmethod
    def backward(ctx, grad):
        ...
35.7.2 CuTile with torch.compile
import cutile as ct
import torch
@ct.kernel
def my_kernel(x: ct.Array, y: ct.Array):
    ...
# Register as PyTorch custom op
@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    my_kernel(x, y)  # launch the CuTile kernel
    return y
# Shape/dtype-only "fake" implementation so torch.compile can trace the op
@my_op.register_fake
def _(x):
    return torch.empty_like(x)
# Works with torch.compile
model = torch.compile(model)
35.7.3 ThunderKittens with PyTorch
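// C++ side: expose as PyTorch extension
#include <torch/extension.h>
#include <cuda_bf16.h>
torch::Tensor flash_attention_forward(
    torch::Tensor Q, torch::Tensor K, torch::Tensor V  // bf16, contiguous, shape [N, 64]
) {
    constexpr int D = 64;
    const int N = Q.size(0);
    auto O = torch::empty_like(Q);
    // PyTorch's bf16 type is at::BFloat16; reinterpret it as CUDA's __nv_bfloat16
    auto as_bf16 = [](const torch::Tensor &t) {
        return reinterpret_cast<__nv_bfloat16*>(t.data_ptr<at::BFloat16>());
    };
    dim3 grid(N / 64);  // one block per 64 query rows
    dim3 block(128);    // 4 warps
    size_t smem = 3 * 64 * D * sizeof(__nv_bfloat16);  // room for three 64×D bf16 tiles
    // Launch ThunderKittens kernel; make_gt is a hypothetical helper that wraps
    // a raw pointer and shape in the kernel's gt_bf global view
    flash_attention_kernel<D><<<grid, block, smem>>>(
        make_gt(as_bf16(Q), N, D),
        make_gt(as_bf16(K), N, D),
        make_gt(as_bf16(V), N, D),
        make_gt(as_bf16(O), N, D),
        N
    );
    return O;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("flash_attention", &flash_attention_forward);
}
35.8 Future Directions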
35.8.1 Convergence
The frameworks are converging on common abstractions:
- All use tile-based thinking
- All target tensor cores
- All support async memory operations
Expect more interoperability and potential consolidation.
35.8.2 AI-Assisted Kernel Generation
Emerging tools use LLMs to generate kernels:
- Geak: Triton kernel AI agent
- NVIDIA’s copilot integrations
- Auto-tuning with ML
The frameworks may become targets for AI-generated code rather than human-written.
35.8.3 Hardware Evolution
Blackwell’s new features will drive framework evolution:
- 5th-gen Tensor Cores (2× FLOPS)
- Enhanced TMA
- Tensor Memory (TMEM), a new per-SM store for tensor core accumulators
- New precision formats (FP4, FP6)
35.9 Key Takeaways
Multiple valid choices: No single framework is best for all cases.
Abstraction has costs: Higher abstraction = lower peak performance, but faster development.
CuTile/Triton for productivity: When time-to-solution matters more than the remaining performance headroom (17 to 22 points of peak in the softmax benchmark).
ThunderKittens for attention: Best-in-class for attention variants with C++ expertise.
CuTe DSL for production: When you need CUTLASS-level performance with Python ergonomics.
Hardware determines choices: CuTe DSL and ThunderKittens give the most direct control over Hopper/Blackwell features such as TMA and warpgroup MMA.
Learn the concepts: Tile abstraction, warp-level thinking, and async copies transfer across frameworks.
Once available, the accompanying notebook will let you:
- Compare the same algorithm across frameworks
- Measure performance on your hardware
- Explore generated PTX/SASS
- Experiment with tile sizes and configurations
Notebook support for this chapter is in progress. For now, run the benchmark snippets locally and compare frameworks on your hardware.