Attention at the sequence lengths people actually care about is not compute-bound; it is bandwidth-bound. The two matmuls in $O = \text{softmax}(QK^\top / \sqrt{d})\,V$ are cheap relative to a modern GPU’s tensor-core throughput, but the standard implementation writes the full $N \times N$ score matrix to high-bandwidth memory (HBM), reads it back to softmax it, and reads it again to multiply by $V$. FlashAttention’s contribution is an I/O accounting result, not a new approximation: it computes the same output while never materializing that matrix in HBM. The cost is real, and naming it is the point of this piece: you pay extra arithmetic (online rescaling on the forward, a full recompute on the backward) and a steep jump in kernel-engineering effort to keep the working set on-chip.

We follow the original paper — Dao, Fu, Ermon, Rudra, and Ré, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arXiv:2205.14135) — and the reference CUDA implementation at Dao-AILab/flash-attention (BSD-3-Clause). Our recreation is a minimal Triton forward plus a recompute backward, run on a single A100-class GPU. It is illustrative, not a benchmark sweep, and we are explicit below about which claims that scope can and cannot support.

The accounting trick: attention is I/O-bound, not FLOP-bound

Count the work. The two matmuls cost roughly $4N^2 d$ FLOPs; on an A100 the tensor cores chew through that at hundreds of TFLOP/s. The problem is everything around them. Standard attention reads and writes a matrix of $N^2$ entries several times: once to store $S = QK^\top$, again for the softmax pass, again as the probabilities $P$ feed the second matmul. For a single head at $N = 8192$, that matrix has about 67M entries, roughly 134 MB per head in fp16, far past what fits in on-chip SRAM, so every touch is an HBM round-trip.
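A back-of-the-envelope check of these numbers, and of why the score traffic rather than the matmuls sets the pace. The sketch assumes one head, d = 64, fp16 storage, the A100’s 312 TFLOP/s dense fp16 tensor-core peak, the low end of the 1.5 to 2.0 TB/s HBM range, and that the score matrix crosses HBM four times (write S, read S, write P, read P). Treat that last figure as a rough model, not a measurement of any particular implementation.

```python
N = 8192                  # sequence length from the text
d = 64                    # head dimension (assumed for illustration)
BYTES_FP16 = 2
A100_FP16_PEAK = 312e12   # dense fp16 tensor-core FLOP/s
HBM_BW = 1.5e12           # bytes/s, low end of the 1.5 to 2.0 TB/s range
TOUCHES = 4               # assumed HBM crossings: write S, read S, write P, read P

score_entries = N * N
score_bytes = score_entries * BYTES_FP16
matmul_flops = 4 * N * N * d                          # QK^T and PV at 2*N*N*d each

intensity = matmul_flops / (TOUCHES * score_bytes)    # FLOPs per byte of score traffic
ridge = A100_FP16_PEAK / HBM_BW                       # intensity needed to be compute-bound

print(f"{score_entries:,} entries, {score_bytes / 1e6:.0f} MB per head in fp16")
print(f"{matmul_flops / 1e9:.1f} GFLOP per head")
print(f"{intensity:.0f} FLOP/byte vs ~{ridge:.0f} needed to stay compute-bound")
```

Under this model the intensity is d/2 regardless of N: about 32 FLOP per byte of score traffic, roughly six times short of what the A100 needs to keep its tensor cores busy.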

That round-trip is the bottleneck because the operations gating it have low arithmetic intensity. HBM bandwidth on an A100 is roughly 1.5 to 2.0 TB/s; on-chip SRAM bandwidth is about an order of magnitude higher, but capacity is only about 192 KB per SM. Softmax and the score reads and writes are memory movement, not matmul work, so the kernel stalls waiting on HBM while the tensor cores idle. The paper formalizes this: standard attention performs $\Theta(Nd + N^2)$ HBM accesses, dominated by the $N^2$ term, while a tiled, fused kernel performs $\Theta(N^2 d^2 / M)$ accesses, where $M$ is the SRAM size. Because $d^2 / M$ is well below 1 for typical head dimensions (64 to 128) and SRAM sizes, that is a large reduction in traffic at essentially the same matmul FLOPs. The speedup is bought in bytes moved, not operations performed.
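Plugging numbers into the two bounds shows the size of the reduction. This sketch drops the constants hidden by Θ, so it compares orders of magnitude only, and it takes M as the roughly 192 KB of SRAM per A100 SM counted in fp16 elements (our assumption about how to map bytes to M):

```python
def standard_accesses(N, d):
    return N * d + N * N              # Theta(N d + N^2), constants dropped

def flash_accesses(N, d, M):
    return N * N * d * d / M          # Theta(N^2 d^2 / M), constants dropped

M = 192 * 1024 // 2                   # ~192 KB of SRAM, in fp16 elements (assumption)
N = 8192
for d in (64, 128):
    ratio = flash_accesses(N, d, M) / standard_accesses(N, d)
    print(f"d={d:>3}: d^2/M = {d * d / M:.3f}, access ratio ~ {ratio:.3f}")
```

Even at d = 128 the fused kernel touches HBM several times less than the standard one; the hidden constants move the exact ratio but not the conclusion.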

The memory footprint is the other half. By never storing $S$ or $P$, FlashAttention keeps the materialized state linear in $N$: you hold the query tile, a streaming key/value tile, an output accumulator, and two scalars per row. That is what lets attention run at 64K context where the quadratic version simply runs out of memory.
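To put numbers on the 64K claim, a sketch comparing per-head bytes for a materialized fp16 score matrix against what the forward actually leaves in HBM: the output O in fp16 plus one log-sum-exp per row. Head dimension 64 and fp32 storage for the log-sum-exp are our assumptions.

```python
def quadratic_bytes(N, bytes_per=2):
    return N * N * bytes_per                      # materialized S (or P) in fp16

def linear_bytes(N, d, bytes_out=2, bytes_lse=4):
    return N * d * bytes_out + N * bytes_lse      # O in fp16 + one fp32 logsumexp per row

d = 64  # head dimension (assumption)
for N in (1024, 8192, 65536):
    quad, lin = quadratic_bytes(N), linear_bytes(N, d)
    print(f"N={N:>6}: N x N scores {quad / 2**20:9.1f} MiB   O + L {lin / 2**20:6.2f} MiB")
```

Doubling N quadruples the first column and only doubles the second; at 64K the score matrix alone is 8 GiB per head, before multiplying by batch size and head count.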

Streaming softmax, computed exactly

The obstacle to tiling is the softmax denominator. A naive tiled loop over key blocks cannot normalize each block independently, because softmax needs the maximum and the sum over the entire row for both numerical stability and correctness. The resolution is the online-softmax recurrence: compute the row max and normalizer in a single streaming pass, retroactively correcting the partial results whenever a new block raises the running max.

Maintain three running quantities per query row: the max m, the denominator l, and the output accumulator O. For each key/value tile, the update is:

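streaming_update.py (one key/value tile, per query row)

```python
import numpy as np

def streaming_update(q, k, v, m, l, O, sm_scale):
    """Fold one key/value tile into the running state (m, l, O) of a query tile."""
    s = (q @ k.T) * sm_scale                 # block scores, [BLOCK_M, BLOCK_N]
    m_new = np.maximum(m, s.max(axis=1))     # new running max per row
    p = np.exp(s - m_new[:, None])           # rescaled to the new max
    alpha = np.exp(m - m_new)                # correction factor for old state
    l = l * alpha + p.sum(axis=1)            # fix the denominator
    O = O * alpha[:, None] + p @ v           # fix the accumulator, then add this tile
    return m_new, l, O

# Start from m = -inf, l = 0, O = 0; after the last tile:
# O = O / l[:, None]
```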

The alpha factor is the whole trick. When a later tile contains a larger score, everything accumulated so far was exponentiated against a stale, smaller max; multiplying the old denominator and the old output by alpha = exp(m_old - m_new) corrects them exactly. This point deserves emphasis for skeptical readers: there is no approximation here. In exact arithmetic the final normalized O equals what a full-row softmax followed by one matmul produces. The only deviation in practice is floating-point reassociation (the same operations happen in a different order), which lands well inside fp16 rounding noise. FlashAttention is exact attention; it belongs in the same conceptual bucket as a fused kernel, not with linear-attention or low-rank approximations.
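The exactness claim is cheap to test without a GPU. A plain-Python sketch for a single query row, in float64 so reassociation error sits near 1e-16: it compares a full-row softmax-weighted sum against the tiled recurrence for several tile sizes, including one key per tile and one tile covering the whole row. The function names are ours, for illustration.

```python
import math
import random

def full_softmax_row(scores, values):
    """Reference: softmax over the whole row, then a weighted sum of value vectors."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    d = len(values[0])
    return [sum(w[j] * values[j][c] for j in range(len(scores))) / z for c in range(d)]

def streaming_softmax_row(scores, values, block):
    """Online softmax: visit the row in tiles of `block` keys, keeping (m, l, acc)."""
    d = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(scores), block):
        s_tile = scores[start:start + block]
        v_tile = values[start:start + block]
        m_new = max(m, max(s_tile))
        p = [math.exp(s - m_new) for s in s_tile]
        alpha = math.exp(m - m_new)          # rescales state built against the stale max
        l = l * alpha + sum(p)
        acc = [acc[c] * alpha + sum(p[j] * v_tile[j][c] for j in range(len(p)))
               for c in range(d)]
        m = m_new
    return [a / l for a in acc]              # normalize once, after the last tile

rng = random.Random(0)
N, d = 37, 4                                 # 37 is deliberately not a multiple of the tile size
scores = [rng.gauss(0.0, 3.0) for _ in range(N)]
values = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)]
ref = full_softmax_row(scores, values)
for block in (1, 5, 8, 64):
    out = streaming_softmax_row(scores, values, block)
    err = max(abs(a - b) for a, b in zip(out, ref))
    print(f"block={block:>2}: max abs diff {err:.1e}")
```

With one key per tile the correction fires every time a new maximum appears, and the result still agrees with the reference to float64 rounding.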

A minimal forward in Triton

The forward kernel assigns one program to each query tile. That tile is loaded once and stays resident; the inner loop streams key/value tiles through SRAM, applying the update above. At the end we write the output and the log-sum-exp L = m + log(l), a single scalar per row that the backward needs.

flash_fwd.py (minimal; batch/head pointer offsets elided)

```python
import triton
import triton.language as tl


@triton.jit
def flash_fwd(Q, K, V, L, Out, sm_scale, N_CTX,
              stride_m, stride_d,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr):
    # One program per query tile; the tile stays resident for the whole loop.
    pid = tl.program_id(0)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, D)
    q = tl.load(Q + offs_m[:, None] * stride_m + offs_d[None, :] * stride_d)
    # Running row max, denominator, and unnormalized output accumulator, all fp32.
    m = tl.full([BLOCK_M], float("-inf"), tl.float32)
    l = tl.zeros([BLOCK_M], tl.float32)
    acc = tl.zeros([BLOCK_M, D], tl.float32)
    # Stream key/value tiles through on-chip memory.
    for start_n in range(0, N_CTX, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        k = tl.load(K + offs_n[:, None] * stride_m + offs_d[None, :] * stride_d)
        v = tl.load(V + offs_n[:, None] * stride_m + offs_d[None, :] * stride_d)
        s = tl.dot(q, tl.trans(k)) * sm_scale          # [BLOCK_M, BLOCK_N]
        m_new = tl.maximum(m, tl.max(s, axis=1))
        p = tl.exp(s - m_new[:, None])
        alpha = tl.exp(m - m_new)                      # rescale state built on the old max
        l = l * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m = m_new
    acc = acc / l[:, None]                             # normalize once, after the loop
    tl.store(L + offs_m, m + tl.log(l))                # logsumexp for backward
    tl.store(Out + offs_m[:, None] * stride_m + offs_d[None, :] * stride_d,
             acc.to(Out.dtype.element_ty))             # cast fp32 accumulator to output dtype
```

This is deliberately stripped down. It omits causal masking, boundary handling for sequence lengths that are not a multiple of BLOCK_M and BLOCK_N, dropout, variable-length packing, and the batch/head dimension in the pointer math; the production kernel handles all of these with predicated loads and masks. BLOCK_M and BLOCK_N are typically 64 or 128, chosen so the query tile, the streaming key/value tile, and the accumulator co-reside on-chip; in real code you wrap this in triton.autotune over num_warps and num_stages.
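A rough on-chip budget shows why the tile sizes stop around 64 or 128. This is a crude model under our own assumptions, not Triton’s actual allocation: fp16 Q/K/V tiles in shared memory with K and V buffered num_stages deep, and the fp32 accumulator, score tile, and per-row m and l in registers, which is roughly where Triton typically keeps them. The 164 KB shared-memory budget is our approximation for an A100 SM; check your device.

```python
def tile_bytes(BLOCK_M, BLOCK_N, D, num_stages=2):
    """Rough on-chip bytes for one program of the minimal forward (assumed layout)."""
    fp16, fp32 = 2, 4
    smem = (BLOCK_M * D * fp16                        # resident query tile
            + num_stages * 2 * BLOCK_N * D * fp16)    # K and V tiles, multi-buffered
    regs = (BLOCK_M * D * fp32                        # output accumulator
            + BLOCK_M * BLOCK_N * fp32                # scores / probabilities
            + 2 * BLOCK_M * fp32)                     # running m and l
    return smem, regs

SMEM_BUDGET = 164 * 1024   # assumption: roughly an A100 SM's max shared memory
for bm, bn in ((64, 64), (128, 64), (128, 128)):
    smem, regs = tile_bytes(bm, bn, D=64)
    fits = "fits" if smem <= SMEM_BUDGET else "too big"
    print(f"BLOCK_M={bm:>3} BLOCK_N={bn:>3}: smem ~{smem / 1024:5.1f} KiB ({fits}), "
          f"regs ~{regs / 1024:5.1f} KiB")
```

The point is the shape of the trade-off: doubling BLOCK_N doubles the K/V buffers and the score tile, and deeper num_stages multiplies the buffers again, which is the space autotuning is searching.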

Validation is where you earn the word “exact.” Compare against a plain reference and assert closeness at fp16 tolerance:

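validate.py

```python
import torch

def ref_attention(q, k, v, scale):
    s = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(s, dim=-1) @ v

# Illustrative shapes: [N_CTX, D] fp16 on the GPU; N_CTX a multiple of the block sizes.
q, k, v = (torch.randn(1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
scale = q.shape[-1] ** -0.5
o_flash = flash_attn(q, k, v)          # launcher for our Triton kernel (not shown)
o_ref = ref_attention(q, k, v, scale)
torch.testing.assert_close(o_flash, o_ref, atol=1e-2, rtol=0)  # fp16 noise
```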

If you run the reference in fp32 you can tighten atol by an order of magnitude; the residual is reassociation, not algorithm error. This test is the whole reason a verification-minded reader should trust the kernel before trusting its timings.

The backward pass: recompute, don’t store

The backward is where the design choice bites, and where most from-scratch attempts quietly cheat. The gradient needs the probabilities $P$. The naive move is to save $P$ (or even $S$) from the forward and reload it, which re-materializes the $N \times N$ matrix and throws away everything the forward bought you. FlashAttention instead stores nothing quadratic and recomputes the scores tile by tile in the backward, reconstructing the exact probabilities from $Q$, $K$, and the saved log-sum-exp:

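flash_bwd.py (math sketch, per tile; L is the saved logsumexp)

```python
import torch

def bwd_tile(Q, K, V, O, dO, L, sm_scale):
    """Gradient contributions from one (query tile, key/value tile) pair."""
    S = (Q @ K.T) * sm_scale                # recomputed; never stored globally
    P = torch.exp(S - L[:, None])           # exact softmax, since L = m + log(l)
    dV = P.T @ dO                           # accumulate into this key tile's dV
    dP = dO @ V.T
    D = (dO * O).sum(-1, keepdim=True)      # per-row scalar; kernels typically precompute it once
    dS = P * (dP - D) * sm_scale            # gradient w.r.t. the unscaled Q @ K.T
    dQ = dS @ K                             # accumulate into this query tile's dQ
    dK = dS.T @ Q                           # accumulate into this key tile's dK
    return dQ, dK, dV
```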

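The one piece of cleverness is D. Backpropagating through a row’s softmax applies its Jacobian, $\operatorname{diag}(P_i) - P_i P_i^\top$, which gives $dS_i = P_i \odot (dP_i - P_i \cdot dP_i)$; the inner product $P_i \cdot dP_i$ spans the whole row, which a tiled kernel never sees at once. For softmax followed by a matmul, though, $P_i \cdot dP_i = \sum_j P_{ij}\,(dO_i \cdot V_j) = dO_i \cdot O_i$, so it collapses to a single per-row scalar, D_i = rowsum(dO_i * O_i), computable from the stored output. That makes dS = P * (dP - D) a cheap elementwise operation and keeps the backward tileable with the same memory profile as the forward.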
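The collapse to D can be checked numerically. A plain-Python sketch for one query row, with sm_scale set to 1 so the gradient is taken with respect to the scores themselves: it rebuilds P from the log-sum-exp, confirms that rowsum(dO * O) equals P · dP, and compares dS = P * (dP - D) against a central finite difference of dO · O. The helper names are ours.

```python
import math
import random

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def attend(s, V):
    """Softmax of one score row via its logsumexp, then a weighted sum of value rows."""
    L = logsumexp(s)
    P = [math.exp(x - L) for x in s]          # exact probabilities from the saved L
    d = len(V[0])
    return P, [sum(P[j] * V[j][c] for j in range(len(s))) for c in range(d)]

rng = random.Random(1)
N, d = 6, 3
s = [rng.gauss(0, 2) for _ in range(N)]
V = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(N)]
dO = [rng.gauss(0, 1) for _ in range(d)]       # upstream gradient for this row

P, O = attend(s, V)
dP = [sum(dO[c] * V[j][c] for c in range(d)) for j in range(N)]   # dP = dO @ V.T
D = sum(dO[c] * O[c] for c in range(d))                             # rowsum(dO * O)
dS = [P[j] * (dP[j] - D) for j in range(N)]                         # no Jacobian formed

def f(scores):
    """Scalar whose gradient w.r.t. the scores is exactly what dS should be."""
    return sum(a * b for a, b in zip(dO, attend(scores, V)[1]))

h = 1e-6
dS_fd = [(f(s[:j] + [s[j] + h] + s[j + 1:]) - f(s[:j] + [s[j] - h] + s[j + 1:])) / (2 * h)
         for j in range(N)]
print("D == P . dP:", abs(D - sum(a * b for a, b in zip(P, dP))) < 1e-12)
print("max |dS - finite diff|:", max(abs(a - b) for a, b in zip(dS, dS_fd)))
```

A side effect worth knowing: each row of dS sums to zero, since the row’s terms add up to D - D.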

This is gradient checkpointing, specialized and pushed into the kernel. You trade compute for memory: the backward redoes the forward’s score matmul $S = QK^\top$ on top of its own four $N \times N \times d$ products, so it does roughly a quarter more matmul work than a hypothetical store-everything backward. But the stored activations drop from $\Theta(N^2)$ to $\Theta(N)$, and because the kernel remains memory-bound, the recomputed matmuls largely hide behind the bandwidth you are no longer spending on HBM round-trips. The forward saves only $O$ and the per-row L (alongside the inputs $Q$, $K$, $V$); everything else is rebuilt on demand.
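Counting the N×N×d matmuls makes the recompute overhead concrete. By our tally, with elementwise exp and rescaling ignored and sizes chosen only for illustration, the forward performs two such products, a backward that reloads P from HBM performs four (dV, dP, dQ, dK), and the recompute backward adds a fifth to rebuild S:

```python
def matmul_flops(n_products, N, d):
    """Each N x N x d product (QK^T, PV, dO V^T, ...) costs 2*N*N*d FLOPs."""
    return n_products * 2 * N * N * d

N, d = 8192, 64                               # illustrative sizes (assumption)
forward = matmul_flops(2, N, d)               # S = QK^T, O = PV
bwd_stored = matmul_flops(4, N, d)            # dV, dP, dQ, dK with P loaded from HBM
bwd_recompute = matmul_flops(5, N, d)         # the same four, plus recomputing S = QK^T
print(f"backward, recompute vs stored: {bwd_recompute / bwd_stored:.2f}x")
print(f"forward + backward: {(forward + bwd_recompute) / (forward + bwd_stored):.2f}x")
```

Over a full training step the extra matmul work is about 17 percent, small next to the quadratic HBM traffic it removes.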

What reproduces, and what doesn’t

Our scope is minimal by design, and the matched claims are partial. Two things reproduce cleanly and are the load-bearing claims of the paper:

  • Exactness. The Triton output matches the reference within fp16 noise, as the validation above asserts. This is the claim that distinguishes FlashAttention from an approximation, and it holds.
  • Memory scaling. Sweep the sequence length and watch the allocator. The naive baseline’s footprint grows with $N^2$ and OOMs; the tiled kernel’s footprint stays flat in the attention buffers, growing only with the linear terms. The shape of the memory curve, which is the actual point of the technique, is recovered by a tutorial-grade kernel.

What does not reproduce from a minimal kernel is the reference’s absolute throughput. The FlashAttention paper (Dao et al., 2022) reports a roughly 3x end-to-end speedup training GPT-2 at sequence length 1K, a 15% gain over the MLPerf 1.1 BERT-large training record, and memory linear in sequence length. Those are end-to-end, hardware-specific, and the product of years of kernel tuning. Our partial runs are consistent with the direction (the tiled kernel pulls ahead of naive PyTorch as $N$ grows), but the absolute gap to flash-attn stays large, and we have not run a controlled sweep to put a defensible number on it.

| Metric | Reported | Reproduced |
| --- | --- | --- |
| Numerical agreement vs reference attention | Exact, algebraically identical (FlashAttention, 2022) | Within fp16 noise (passes the atol = 1e-2 check); tighter against an fp32 reference (illustrative) |
| Attention memory footprint vs sequence length | Linear in N (FlashAttention paper) | Recovered: footprint flat where the naive baseline OOMs (qualitative) |
| End-to-end GPT-2 training speedup, seq 1K | ~3x vs baseline (FlashAttention paper) | Out of scope; minimal kernel only |
| Attention-op speedup over naive PyTorch | Large, grows with sequence length (hardware-dependent) | In progress; direction consistent, absolute gap large |

Table. Reported figures are the publicly stated paper numbers (sources named in prose). Reproduced entries reflect a minimal single-GPU Triton recreation and are qualitative or in progress, not a benchmarked claim.

The honest read: a tutorial reproduces the algorithm and its asymptotics, not the engineering artifact. That distinction is the whole reason the next section exists.

From tutorial to FA2/FA3: the last several-fold speedup is engineering

The gap between our kernel and the reference is not conceptual; it is occupancy, scheduling, and precision. The follow-up papers spell out exactly where it lives.

FlashAttention-2 (Dao, arXiv:2307.08691) attacks three inefficiencies in the original. It cuts the non-matmul FLOPs (the rescaling work that runs on the slower CUDA cores rather than the tensor cores) by keeping the output accumulator unnormalized and dividing by the softmax denominator once at the end of the loop, as the minimal kernel above already does. It parallelizes across the query/sequence dimension, not just batch and heads, so long-context single-sequence workloads actually fill the GPU. And it re-partitions work across warps (splitting along Q rather than K) to slash shared-memory reads and writes. The FA2 paper reports roughly a 2x speedup over the original, reaching 50 to 73 percent of theoretical peak matmul throughput on A100; those numbers are A100-specific, and our minimal kernel does not approach them.
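The payoff of splitting the sequence dimension comes down to counting thread blocks. A sketch under stated assumptions: 108 SMs (the A100’s count), BLOCK_M = 128, and two hypothetical workloads. The coarse scheme launches one block per (batch, head) pair, which is how we read the FA2 paper’s description of the original kernel; the fine scheme launches one per query tile, as the minimal Triton kernel above does with tl.program_id(0).

```python
import math

NUM_SMS = 108   # SM count of an A100

def programs(batch, heads, N, BLOCK_M, split_seq):
    """Thread blocks launched: one per (batch, head), or one per (batch, head, query tile)."""
    tiles = math.ceil(N / BLOCK_M) if split_seq else 1
    return batch * heads * tiles

for batch, heads, N in ((32, 16, 512), (1, 16, 16384)):
    coarse = programs(batch, heads, N, 128, split_seq=False)
    fine = programs(batch, heads, N, 128, split_seq=True)
    print(f"batch={batch:>2} heads={heads} N={N:>5}: "
          f"{coarse:>5} blocks ({coarse / NUM_SMS:.2f} per SM) vs "
          f"{fine:>5} blocks ({fine / NUM_SMS:.2f} per SM)")
```

With batch 32 both schemes keep every SM busy; with a single 16K-token sequence over 16 heads, the coarse scheme launches 16 blocks and leaves about 85 percent of the SMs idle.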

FlashAttention-3 (Shah, Bikshandi, Dao, et al., arXiv:2407.08608) is a Hopper exercise and is where Triton-level tutorials structurally fall short. It leans on warp specialization (dedicated producer and consumer warps), TMA for asynchronous bulk copies, and a software pipeline that overlaps the WGMMA matmuls with the softmax of the previous tile — hiding the non-matmul work under the matmuls instead of serializing them. It adds FP8 with block quantization and incoherent processing to recover the accuracy that low precision would otherwise cost. The FA3 paper reports roughly 1.5 to 2x over FA2, reaching about 75 percent of H100 FP16 peak and approaching 1.2 PFLOP/s with FP8. These are architecture-locked results; the asynchrony primitives they depend on do not even exist on pre-Hopper hardware.
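Incoherent processing, as we understand the FA3 description, multiplies Q and K by the same random orthogonal matrix (a Hadamard transform with random sign flips) before quantizing. The product QK^T is unchanged in exact arithmetic because the rotation cancels, while a large outlier channel gets spread across all coordinates, so a per-block low-precision scale wastes less range. A plain-Python sketch of just that invariance and the flattening effect; there is no FP8 arithmetic here, and the helper names and example vectors are ours.

```python
import math
import random

def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of two)."""
    H = [[1]]
    while len(H) < n:
        H = [row + row for row in H] + [row + [-x for x in row] for row in H]
    return H

def rotate(x, H, signs):
    """Apply M = H * diag(signs) / sqrt(n), an orthogonal matrix, to vector x."""
    n = len(x)
    y = [s * v for s, v in zip(signs, x)]
    return [sum(H[i][j] * y[j] for j in range(n)) / math.sqrt(n) for i in range(n)]

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

n = 8
rng = random.Random(0)
signs = [rng.choice((-1.0, 1.0)) for _ in range(n)]   # random sign flips
H = hadamard(n)
q = [12.0, 0.3, -0.2, 0.1, 0.4, -0.1, 0.2, 0.0]        # one outlier channel
k = [rng.gauss(0.0, 1.0) for _ in range(n)]

q_rot, k_rot = rotate(q, H, signs), rotate(k, H, signs)
print("q.k preserved:", abs(dot(q, k) - dot(q_rot, k_rot)) < 1e-9)
print(f"max |q| before {max(map(abs, q)):.2f}, after {max(map(abs, q_rot)):.2f}")
```

Because every Hadamard entry is ±1, each rotated coordinate mixes the outlier with all the small channels, so its magnitude lands between roughly 3.8 and 4.7 instead of 12.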

The takeaway for anyone planning to “just write FlashAttention”: Triton hands you the algorithm, the exactness, and the memory scaling for free. The remaining several-fold speedup (register-level work partitioning, deep async pipelining, FP8 accuracy tricks, CUTLASS-grade occupancy tuning against one specific SM) is hand engineering, and it is most of the value in the shipped library. Reproducing the idea is a weekend; reproducing the artifact is the reason flash-attn is a dependency and not a snippet.

The bottom line

FlashAttention is the canonical lesson in a principle that generalizes well past attention: when a kernel is memory-bound, you restructure it for the memory hierarchy, not the FLOP count, and you can often do so without changing a single bit of the math. Three things to keep:

  • Tiling plus online softmax gives a linear memory footprint and cuts HBM traffic by a factor of roughly $M / d^2$, with no approximation. The output is exact to floating-point reassociation; validate that first.
  • The win is bandwidth, not arithmetic. The kernel does more FLOPs — rescaling on the forward, recompute on the backward — and is faster anyway, because attention was never compute-bound at these sizes.
  • A minimal Triton kernel recovers the scaling behavior; it does not recover the reference’s absolute speed. That gap is warp specialization, async copies, and low-precision engineering — real work that a tutorial, by construction, cannot reproduce.

If you build this, ship the validation test alongside the kernel. The exactness check is what converts “I implemented something attention-shaped” into “I implemented attention,” and on a verification publication that is the only claim worth making.