Dissecting ThunderKittens: Anatomy of a Compact DSL for High-Performance AI Kernels

Introduction

Modern ML workloads depend heavily on custom GPU kernels. Even when a model is expressed as clean tensor operations, the performance almost always comes from specialized implementations underneath. Good examples of this are the many different attention mechanisms, GEMMs across different precisions, and MoE-style grouped GEMMs, which have become a fairly common architectural choice in state-of-the-art models.

This matters a lot if we look at it from the perspective of scaling laws. Better models have generally come from some mix of better algorithms, more data, and more compute. If we want to keep pushing that forward, we care not just about algorithmic quality, but also about how efficiently those algorithms actually run on hardware. One clean way to frame it is:

$$\frac{\text{Intelligence}}{\text{Dollar}} = \frac{\text{Intelligence}}{\text{FLOPS}} \times \frac{\text{FLOPS}}{\text{Dollar}}$$

We want to improve both terms. On the algorithm side, researchers need to iterate quickly on new architectures and new training/inference recipes. On the hardware side, those ideas must translate into code that uses the GPU efficiently. There is a persistent tension here: we want programming environments productive enough for research, but close enough to the metal to get serious performance.

ThunderKittens is a C++ embedded DSL from Stanford’s Hazy Research group designed to square this circle. It wraps CUDA in a tile-centric API that maps directly to hardware primitives, hiding most of the boilerplate while sacrificing almost none of the performance.


The GPU Memory Hierarchy

Before diving into ThunderKittens’ abstractions, it helps to build intuition for the hardware they sit on. The H100 SXM5 has a pronounced memory hierarchy with very different bandwidth and latency characteristics at each level. Misaligning data movement with this hierarchy is one of the most common causes of underperforming kernels.

| Level | Description | Capacity / throughput | Approx. latency |
| --- | --- | --- | --- |
| Global Memory (HBM) | Off-chip high-bandwidth memory, shared across all SMs. ~3.35 TB/s on H100 SXM5. | 80 GB | ~400 cycles |
| Shared Memory (SMEM) | On-chip scratchpad per SM; manually managed tiling buffer. ~12 TB/s. | 228 KB/SM | ~20 cycles |
| Register File | Per-thread private registers. Fastest on-chip storage; compiler-allocated. | up to 255 regs/thread | ~1 cycle |
| Tensor Cores (WGMMA) | Warp-group matrix units. Execute 64×N×16 (fp16) MMA in 4 cycles. | ~1979 TFLOPS (fp16, with sparsity; ~989 dense) | 4 cycles/tile |

Figure: GPU memory hierarchy for NVIDIA Hopper (H100). Each level has different bandwidth, latency, and capacity tradeoffs.

The central challenge in kernel design is keeping the Tensor Cores fed. They are by far the highest-throughput compute unit in the chip, but they consume data faster than any single memory subsystem can supply it in isolation. The only way to maintain high MMA utilization is to stage data cleverly through the hierarchy.


Bandwidth Is the Bottleneck

The bandwidth gap between SMEM and HBM is large: roughly 3.7× on the H100 by peak figures (about 12.3 TB/s versus 3.35 TB/s). This is not a new insight, but it is easy to underestimate how much it drives kernel design.

Figure: Memory subsystem bandwidth on the H100 SXM5, in GB/s: SMEM 12,288; L1/L2 6,000; HBM 3,350; NVLink 900. Values are approximate and represent peak theoretical bandwidth. Note the roughly 3.7× gap between SMEM and HBM.

When a GEMM kernel achieves high efficiency, it is usually because it spends most of its time issuing MMA instructions on data that is already resident in SMEM or registers, touching HBM as infrequently as possible. The key metric is arithmetic intensity: how many FLOPs does the kernel perform per byte of HBM traffic?

$$I = \frac{\text{FLOPs}}{\text{Bytes from HBM}}$$

For a block-tiled GEMM with tile size $B_s \times B_s$:

$$I_{\text{tile}} = \frac{2 B_s^3}{2 B_s^2 \cdot \text{sizeof(element)}} = \frac{B_s}{\text{sizeof(element)}}$$

So a bf16 GEMM with $B_s = 128$ has arithmetic intensity $I = 64$ FLOP/byte. The roofline ridge point is peak compute divided by peak HBM bandwidth: $1979 / 3.35 \approx 590$ FLOP/byte for fp16 MMA against HBM (NVIDIA quotes the 1979 TFLOPS figure with sparsity; the dense peak is about half, which puts the ridge near 295 FLOP/byte). By this simple model a single 128×128 tile sits well below the ridge, so a well-written kernel only becomes compute-bound rather than memory-bound when effective tiles are large enough, or when operands are reused out of L2 and SMEM instead of being refetched from HBM.
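The roofline arithmetic is easy to check mechanically. The following is a minimal C++ sketch, not part of ThunderKittens: the function names are illustrative, and the hardware peaks are passed in as parameters so the same check can be rerun with dense or sparse tensor-core figures.

```cpp
// Arithmetic intensity (FLOP per byte of HBM traffic) of one B x B x B tile step:
// 2*B^3 FLOPs against two B x B operand tiles fetched from HBM.
constexpr double tile_intensity(double b, double elem_bytes) {
    return (2.0 * b * b * b) / (2.0 * b * b * elem_bytes);
}

// Roofline ridge point: peak compute over peak memory bandwidth.
// With TFLOP/s and TB/s as inputs, the ratio is in FLOP/byte.
constexpr double ridge_point(double peak_tflops, double peak_tb_per_s) {
    return peak_tflops / peak_tb_per_s;
}

// Under this simple model a tile is compute-bound once its intensity
// reaches the ridge point.
constexpr bool compute_bound(double b, double elem_bytes,
                             double peak_tflops, double peak_tb_per_s) {
    return tile_intensity(b, elem_bytes) >= ridge_point(peak_tflops, peak_tb_per_s);
}

// bf16 (2 bytes) with B_s = 128 gives 64 FLOP/byte.
static_assert(tile_intensity(128, 2) == 64.0, "bf16 tile, B_s = 128");
```

Calling `compute_bound(128, 2, 1979.0, 3.35)` returns false, since 64 FLOP/byte is far below the roughly 590 FLOP/byte ridge. The model counts only HBM traffic for a single tile, so it ignores L2 hits and cross-block reuse, both of which raise the effective intensity in real kernels.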


ThunderKittens’ Tile System

ThunderKittens is built around the idea that everything should be expressed in terms of tiles that map cleanly onto the GPU hierarchy. At its most basic, TK uses a base tile with a fixed height of 16, while the width depends on the datatype:

| Datatype | Base tile | Hardware unit |
| --- | --- | --- |
| fp16 / bf16 | 16×16 | 1 MMA fragment |
| fp8 | 16×32 | 1 MMA fragment |
| fp32 | 16×8 | 1 accumulator fragment |

These dimensions are not arbitrary — they match the warp-level MMA fragment sizes exposed by wmma / PTX mma.sync / Hopper’s new wgmma.mma_async. Encoding them in the type system at compile time means the compiler can verify layout compatibility, bank-conflict freedom, and register allocation before a single line of CUDA runs.
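To make the idea of encoding tile shapes in the type system concrete, here is a small C++ sketch. Every type in it (`bf16_t`, `fp8_t`, `base_tile`, `tile`, `can_multiply`) is a hypothetical stand-in written for this illustration, not a ThunderKittens type; the base tile widths simply follow the table above.

```cpp
#include <cstdint>

// Hypothetical element types for illustration; these are NOT ThunderKittens types.
struct bf16_t { std::uint16_t bits; };
struct fp8_t  { std::uint8_t  bits; };

// Base tile shape per datatype, following the table above; height is always 16.
template <typename T> struct base_tile;
template <> struct base_tile<bf16_t> { static constexpr int rows = 16, cols = 16; };
template <> struct base_tile<fp8_t>  { static constexpr int rows = 16, cols = 32; };
template <> struct base_tile<float>  { static constexpr int rows = 16, cols = 8;  };

// A tile built from H x W base tiles; its element shape is fixed at compile time.
template <typename T, int H, int W>
struct tile {
    static constexpr int rows = H * base_tile<T>::rows;
    static constexpr int cols = W * base_tile<T>::cols;
};

// A * B is only well-formed when the inner dimensions agree.
template <typename A, typename B>
struct can_multiply {
    static constexpr bool value = (A::cols == B::rows);
};

// A mismatched shape is rejected by the compiler, before any kernel is launched.
static_assert(can_multiply<tile<bf16_t, 4, 1>, tile<bf16_t, 1, 4>>::value,
              "64x16 times 16x64 is a valid product");
```

The point of the design is that a shape mismatch becomes a compile error rather than a wrong result at runtime; a real DSL would carry layout and memory-space information in the same way.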

Register Tiles

Register tiles (rt_<dtype><rows, cols>, with both dimensions counted in base tiles) live in the register file and are operated on directly by MMA instructions:

// Declare a 64×64 bf16 register tile (4×4 base tiles)
rt_bf16<4, 4> reg_A, reg_B;

// Declare a 64×64 fp32 accumulator tile
rt_fl<4, 4> accum;
rt_zero(accum);   // zero-initialize via CUDA register writes

Shared Memory Tiles

Shared memory tiles (st_<dtype><rows, cols>) wrap pointers into SMEM with compile-time layout information. TK applies swizzled layouts so that the access patterns it generates avoid bank conflicts:

extern __shared__ char smem[];

// Allocate SMEM tiles back-to-back
st_bf16<4, 4> &As = *(st_bf16<4, 4>*)smem;
st_bf16<4, 4> &Bs = *(st_bf16<4, 4>*)(smem + sizeof(st_bf16<4, 4>));
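To see why swizzling matters, consider a 64-column bf16 tile whose rows are 128 bytes, so every row starts on the same bank. The sketch below (C++14 or later) models a generic XOR swizzle on 16-byte chunks. It is an assumption-laden illustration: the 32-bank, 4-byte-per-bank model is standard for NVIDIA shared memory, but the helper names are made up here and the permutation is not claimed to be TK's exact layout.

```cpp
#include <cstdint>

// Assumed model: 32 banks of 4 bytes, so one 128-byte row spans every bank once
// and a 16-byte chunk covers 4 adjacent banks.
constexpr int kRowBytes = 128;                          // 64 bf16 elements per row
constexpr int kChunkBytes = 16;                         // 8 bf16 elements
constexpr int kChunksPerRow = kRowBytes / kChunkBytes;  // 8

// Byte offset of a 16-byte chunk in a plain row-major tile.
constexpr int linear_offset(int row, int chunk) {
    return row * kRowBytes + chunk * kChunkBytes;
}

// Same tile with an XOR swizzle: the chunk index is permuted by the row index.
constexpr int swizzled_offset(int row, int chunk) {
    return row * kRowBytes + ((chunk ^ (row % kChunksPerRow)) * kChunkBytes);
}

// First bank touched by the chunk that starts at a byte offset.
constexpr int first_bank(int byte_offset) {
    return (byte_offset / 4) % 32;
}

// Number of distinct 4-bank groups hit when 8 threads read the same logical
// chunk column from rows 0..7 (8 = conflict-free, 1 = fully serialized).
constexpr int distinct_bank_groups(bool swizzle, int chunk) {
    std::uint32_t seen = 0;  // bit b is set when bank b starts one of the reads
    for (int row = 0; row < 8; ++row) {
        const int off = swizzle ? swizzled_offset(row, chunk) : linear_offset(row, chunk);
        seen |= (1u << first_bank(off));
    }
    int count = 0;
    for (int b = 0; b < 32; ++b) count += static_cast<int>((seen >> b) & 1u);
    return count;
}
```

In this model `distinct_bank_groups(false, 3)` is 1, meaning all eight reads collide on the same banks, while `distinct_bank_groups(true, 3)` is 8, meaning the reads spread across all 32 banks. Baking the permutation into the tile type is what lets the library apply it consistently on both the write and the read side.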

Global Layouts

Global memory layouts (gl<dtype, ...>) describe how a tensor in HBM is structured. TK uses them to compute the correct byte offsets and issue coalesced loads/stores via cp.async (Ampere) or the Hopper TMA engine:

// Bind a 2D bf16 matrix to a global layout descriptor
gl<bf16, 1, 1, -1, -1> A_layout(A_ptr, nullptr, nullptr, M, K);
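The offset arithmetic that a layout descriptor hides is simple but easy to get wrong by hand. As a rough sketch, assuming a dense row-major matrix, the byte offset of a tile follows from its tile coordinates and the matrix width. `matrix_view`, `tile_byte_offset`, and `row_pitch_bytes` are hypothetical names for this example and do not correspond to the TK gl<> implementation.

```cpp
#include <cstddef>

// Hypothetical descriptor for a dense row-major 2D matrix in HBM.
struct matrix_view {
    std::size_t rows, cols;   // runtime dimensions (the -1 slots in the gl<> example)
    std::size_t elem_bytes;   // 2 for bf16, 4 for fp32
};

// Byte offset of the first element of tile (tile_row, tile_col),
// where each tile is tile_m x tile_n elements.
constexpr std::size_t tile_byte_offset(matrix_view m,
                                       std::size_t tile_row, std::size_t tile_col,
                                       std::size_t tile_m, std::size_t tile_n) {
    return ((tile_row * tile_m) * m.cols + tile_col * tile_n) * m.elem_bytes;
}

// Bytes between consecutive rows of the same tile (the row pitch of the matrix).
constexpr std::size_t row_pitch_bytes(matrix_view m) {
    return m.cols * m.elem_bytes;
}
```

For a 4096×1024 bf16 matrix and 64×16 tiles, tile (1, 2) starts at byte (64·1024 + 32)·2 = 131,136, and each subsequent tile row is 2,048 bytes further on. A descriptor that carries the dimensions lets the library derive both numbers instead of the kernel author.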

A Complete Tiled GEMM

Putting it all together, here is a minimal block-tiled GEMM using ThunderKittens. The kernel is deliberately unpipelined to keep the structure visible, and the only synchronization it needs is a pair of block-level barriers per iteration, with no explicit warp-level boilerplate:

#include "thunderkittens.cuh"
using namespace kittens;

constexpr int BM = 64, BN = 64, BK = 16;

__global__ void tk_gemm_kernel(
    gl<bf16, 1, 1, -1, -1> A,
    gl<bf16, 1, 1, -1, -1> B,
    gl<float, 1, 1, -1, -1> C,
    int M, int N, int K
) {
    // ── Shared memory staging buffers ──────────────────────
    extern __shared__ char smem[];
    st_bf16<BM/16, BK/16> &As = *(st_bf16<BM/16, BK/16>*)smem;
    st_bf16<BK/16, BN/16> &Bs = *(st_bf16<BK/16, BN/16>*)(smem + sizeof(As));

    // ── Register accumulator (fp32 for numerical stability) ─
    rt_fl<BM/16, BN/16> acc;
    rt_zero(acc);

    // ── Output tile coordinates ─────────────────────────────
    int bm = blockIdx.y, bn = blockIdx.x;

    // ── Main reduction loop (assumes K is a multiple of BK) ──
    for (int bk = 0; bk < K / BK; bk++) {
        // Copy the next A and B tiles from HBM → SMEM (cp.async or TMA under the hood)
        load(As, A, {bm, bk});
        load(Bs, B, {bk, bn});
        __syncthreads();   // tiles must be fully written before any warp reads them

        // Matrix multiply-accumulate (WGMMA on Hopper)
        mma(acc, As, Bs, acc);
        __syncthreads();   // all reads must finish before the next load overwrites SMEM
    }

    // ── Write result back to HBM ────────────────────────────
    store(C, acc, {bm, bn});
}

The comments show which hardware operation TK abstracts: load() → cp.async / TMA, mma() → wgmma.mma_async, store() → normal global stores. On Hopper, TK can pipeline the load() and mma() calls using the asynchronous commit_group / wait_group instructions, overlapping data fetching with computation. When load and compute times are roughly balanced, this overlap can approach a 2× speedup over the serialized loop; when one side dominates, the gain is smaller.
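How much overlapping loads with compute can buy is worth quantifying. The following is a back-of-the-envelope C++ model, not a measurement: it assumes every iteration has the same load time and the same MMA time, and that a double-buffered loop can fully overlap the load for step k+1 with the MMA for step k. The function names are illustrative.

```cpp
// Total time for n iterations when each load must finish before its MMA starts
// and the next load waits for that MMA (the structure of the loop above).
constexpr double serial_time(int n, double load, double mma) {
    return n * (load + mma);
}

// Total time with double buffering: the steady state costs max(load, mma) per
// iteration, plus one exposed load at the start and one exposed MMA at the end.
constexpr double pipelined_time(int n, double load, double mma) {
    return load + (n - 1) * (load > mma ? load : mma) + mma;
}

// Speedup of the pipelined loop over the serialized one.
constexpr double overlap_speedup(int n, double load, double mma) {
    return serial_time(n, load, mma) / pipelined_time(n, load, mma);
}

// With balanced load and MMA times, 100 iterations take 101 units instead of 200.
static_assert(pipelined_time(100, 1.0, 1.0) == 101.0, "balanced pipeline");
```

With balanced times the speedup is 200/101, just under 2×. With MMA three times slower than the load, `overlap_speedup(100, 1.0, 3.0)` is 400/301, about 1.33×, because only the shorter phase can be hidden.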


Kernel Occupancy and Register Pressure

One subtle cost of TK’s approach is that large register tiles reduce warp occupancy — the number of warps that can co-reside on an SM. A 64×64 fp32 accumulator tile uses:

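$$\text{registers per warp} = 64 \times 64 = 4096 \text{ (one 32-bit register per fp32 element)}, \qquad 4096 \times 4 \text{ bytes} = 16 \text{ KB/warp}$$

The H100 SM has 65,536 32-bit registers total. At 4,096 registers per warp (128 per thread, about half of the 255-register per-thread limit) for the accumulator alone, at most 16 warps fit per SM in theory. In practice, operand fragments, SMEM allocations, and other state constrain occupancy further. The halving by two applies only to 16-bit types, where two elements pack into one register.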
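The register budget can be worked out with a few lines of arithmetic. This C++ sketch assumes a register tile is owned by one warp of 32 threads and counts one 32-bit register per 4 bytes of tile data, which gives 4,096 registers for a 64×64 fp32 tile; the helper names are made up for the example.

```cpp
// 32-bit registers a warp needs to hold a rows x cols tile.
// Two bf16 elements pack into one register; one fp32 element fills one.
constexpr int regs_per_warp(int rows, int cols, int elem_bytes) {
    return rows * cols * elem_bytes / 4;
}

// Registers per thread, with the tile spread evenly over 32 lanes.
constexpr int regs_per_thread(int rows, int cols, int elem_bytes) {
    return regs_per_warp(rows, cols, elem_bytes) / 32;
}

// Upper bound on resident warps per SM from this one tile alone.
constexpr int max_warps(int regs_per_sm, int rows, int cols, int elem_bytes) {
    return regs_per_sm / regs_per_warp(rows, cols, elem_bytes);
}

// 64x64 fp32 accumulator on an SM with 65,536 registers.
static_assert(regs_per_thread(64, 64, 4) == 128, "fp32 accumulator per thread");
static_assert(max_warps(65536, 64, 64, 4) == 16, "fp32 accumulator warp bound");
```

The same tile in bf16 needs `regs_per_warp(64, 64, 2)`, which is 2,048 registers, so accumulating in fp32 is what doubles the cost. The bound ignores operand fragments, address registers, and the hardware limit on resident warps, so real occupancy is lower.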

graph LR
    O[Occupancy] -->|limits| H[Latency Hiding]
    H -->|limits| U[MMA Utilization]
    R[Register Pressure] -->|reduces| O
    S[SMEM Footprint] -->|reduces| O
    T[Tile Size] -->|increases| R
    T -->|increases| S
    T -->|increases| U
    style O fill:#AD2111,stroke:#333,stroke-width:1px,color:#fff
    style U fill:#8ec07c,stroke:#333,stroke-width:1px,color:#000
    style T fill:#fabd2f,stroke:#333,stroke-width:1px,color:#000

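The optimal tile size balances these forces: larger tiles raise MMA utilization directly, but they also raise register pressure and SMEM footprint, which lower occupancy and with it the latency hiding that utilization depends on. ThunderKittens’ type system makes it easy to iterate over tile shapes quickly without rewriting index arithmetic by hand each time.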


Conclusion

ThunderKittens sits in a sweet spot between raw CUDA and higher-level tools such as cuBLAS or Triton. It gives you direct control over the memory hierarchy and hardware primitives while turning the most error-prone parts (layout calculations, swizzling, MMA fragment packing) into a compile-time concern rather than a runtime debugging session.

For researchers who need to write custom kernels but do not want to spend a week getting bank conflicts and warp synchronization right, it is a compelling option. The 16×16 tile abstraction is restrictive by design, but that restriction buys you correctness and performance that is hard to achieve with ad-hoc CUDA.

The source code is open and the DSL is actively developed. The original paper by Spector et al. (2024) is worth reading for the formal treatment of the type system and benchmarks across attention, GEMM, and convolution workloads.

Citation

Please cite this work as:

Achille, "Dissecting ThunderKittens: Anatomy of a Compact DSL for High-Performance AI Kernels",
Achille Triomphe, May 2026.

Or use the BibTeX citation:

@article{achille2026dissectingthunderkittens,
  author    = {Achille},
  title     = {Dissecting ThunderKittens: Anatomy of a Compact DSL for High-Performance AI Kernels},
  journal   = {Achille Triomphe},
  year      = {2026},
  month     = {May},
  note      = {https://www.achilletriomphe.com/blog/dissecting-thunderkittens/},
}