FlashAttention & Online Softmax
The math trick that lets attention be tiled instead of materialized.
FlashAttention is one of the most influential GPU kernels of the last few years: it computes exact attention faster and with far less memory, in both training and inference, with no change to the model. The whole thing rests on the online softmax trick.
The problem
Standard softmax over a vector x of length N:
softmax(x_i) = exp(x_i) / Σ_j exp(x_j)
The denominator sums over all elements. To compute softmax for any single element, you need every element first. But if you process attention scores in blocks because the full N × N matrix won't fit in fast on-chip memory, you only see one block at a time. How do you normalize?
Numerical stability comes first
exp(1000) overflows to infinity in floating point, and the resulting inf / inf is NaN. The fix is the safe softmax: subtract the max from every element before exponentiating.
m = max(x_1, ..., x_N)
softmax(x_i) = exp(x_i − m) / Σ_j exp(x_j − m)
Mathematically identical, numerically safe. But still needs two passes (find max, then sum). Still requires seeing everything.
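To make the overflow concrete, here is a minimal NumPy sketch comparing the direct definition with the safe version. The function names and the test vector are illustrative choices, not something the article prescribes.

```python
import numpy as np

def naive_softmax(x):
    # Direct definition: exp overflows once any x_i is large.
    e = np.exp(x)
    return e / e.sum()

def safe_softmax(x):
    # Pass 1: find the max. Pass 2: sum the shifted exponentials.
    m = np.max(x)
    e = np.exp(x - m)
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0])

# Silence the overflow/invalid warnings that the naive version triggers.
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(x)   # inf / inf, so every entry is nan

safe = safe_softmax(x)         # finite, sums to 1
```

Shifting every input by the same constant leaves softmax unchanged, which is why subtracting the max is safe: `safe_softmax(x)` matches `safe_softmax(x - 1000.0)`.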
The online trick
Maintain a running max m and a running sum ℓ. Correct them when a new block arrives.
After block 1: m₁ = max of block 1, ℓ₁ = Σ exp(x_j − m₁) over block 1.
Block 2 arrives with local max m₂. If m₂ > m₁, your old ℓ₁ was computed relative to the wrong (smaller) max. Rescale:
m_new = max(m₁, m₂)
ℓ_new = ℓ₁ · exp(m₁ − m_new) + Σ exp(x_j − m_new) (over block 2)
This works because:
exp(x_j − m₁) · exp(m₁ − m_new) = exp(x_j − m_new)
The factor exp(m₁ − m_new) retroactively rebases the old sum against the new max. If m₁ was already bigger, the factor is exp(0) = 1 and nothing changes.
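The running-max and running-sum update can be sketched in a few lines of NumPy. This is an illustration under assumptions: the function name `online_softmax_stats`, the block size, and the random test data are all made up for the example.

```python
import numpy as np

def online_softmax_stats(x, block_size):
    """Stream over x in blocks, returning the global max and the
    softmax denominator sum_j exp(x_j - max) without a second pass."""
    m = -np.inf   # running max
    l = 0.0       # running sum of exp(x_j - m)
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        m_new = max(m, np.max(block))
        # Rebase the old sum onto the new max, then add this block's terms.
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return m, l

rng = np.random.default_rng(0)
x = rng.normal(size=1000) * 50.0

m, l = online_softmax_stats(x, block_size=128)

# Two-pass reference for comparison.
m_ref = x.max()
l_ref = np.exp(x - m_ref).sum()
```

Starting from m = −∞ and ℓ = 0 means the first block needs no special case: the rescale factor is exp(−∞) = 0, multiplied by a sum that is already zero.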
Extending to attention output
In FlashAttention you're computing softmax(QKᵀ) · V. Along with m and ℓ, maintain a running output O. When a new block raises the max, rescale O the same way:
O_new = O₁ · exp(m₁ − m_new) + (new block's contribution using m_new)
At the end, divide O by the final ℓ. The result is mathematically identical to standard attention. No approximation, just a smarter order of operations.
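A NumPy sketch of the same recurrence applied to attention is below. It is an illustration of the math only: it tiles over key/value blocks, omits the usual 1/√d scaling to match the softmax(QKᵀ) · V form used here, and does not model on-chip memory. The function names and shapes are assumptions made for the example.

```python
import numpy as np

def attention_reference(Q, K, V):
    # Standard attention: materializes the full (n, n) score matrix.
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def attention_tiled(Q, K, V, block_size):
    n = Q.shape[0]
    m = np.full((n, 1), -np.inf)      # running max per row
    l = np.zeros((n, 1))              # running sum per row
    O = np.zeros((n, V.shape[1]))     # running unnormalized output
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        S = Q @ Kb.T                  # only an (n, block) slice of scores
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)     # rebases old l and O onto m_new
        P = np.exp(S - m_new)
        l = l * scale + P.sum(axis=1, keepdims=True)
        O = O * scale + P @ Vb
        m = m_new
    return O / l                      # single normalization at the end

rng = np.random.default_rng(0)
Q = rng.normal(size=(64, 16))
K = rng.normal(size=(64, 16))
V = rng.normal(size=(64, 16))

out_tiled = attention_tiled(Q, K, V, block_size=16)
out_ref = attention_reference(Q, K, V)
```

The two outputs agree to floating-point tolerance. The tiled version differs only in the order of operations, which changes rounding slightly but not the result being computed.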
What you store
Per row of the output you carry forward three things between blocks: m (scalar), ℓ (scalar), O (vector of length d). That's O(d) state per row, for a total of O(N · d). Crucially, the N × N attention matrix is never materialized in HBM, the GPU's off-chip high-bandwidth memory. That's where the memory savings come from.
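For a rough sense of scale, the sketch below compares the two footprints for a single attention head. The sequence length, head dimension, and 2-byte elements are hypothetical values chosen for illustration, and storing m and ℓ at the same width as O is a simplifying assumption (a real kernel may keep them in higher precision).

```python
def attention_memory_bytes(n, d, bytes_per_element=2):
    # Full N x N score matrix, as standard attention would materialize it.
    scores = n * n * bytes_per_element
    # Per-row carried state: O (d values) plus m and l (one value each).
    state = n * (d + 2) * bytes_per_element
    return scores, state

scores, state = attention_memory_bytes(n=8192, d=128)

print(f"score matrix: {scores / 2**20:.1f} MiB")
print(f"carried state: {state / 2**20:.2f} MiB")
```

With these assumed sizes the score matrix comes to 128 MiB per head, while the carried state is about 2 MiB. The gap widens linearly with N, since one side grows as N² and the other as N · d.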
Punchline
The online softmax trick turns a fundamentally global operation (must see all N values) into a streaming one. Process blocks one at a time, carrying only O(1) softmax state per row (m and ℓ), plus the O(d) output accumulator when computing attention. Without it, tiling attention would either be approximate or require an extra pass over memory, defeating the point.
A beautiful example of rearranging math to match hardware constraints rather than fighting them.