TensorLoop

FlashAttention & Online Softmax

The math trick that lets attention be tiled instead of materialized.

FlashAttention is one of the most influential attention kernels of the last few years, used for both training and inference: substantially faster and more memory-efficient than standard attention, with no change to the model or its outputs. The whole thing rests on the online softmax trick.

The problem

Standard softmax over a vector x of length N:

softmax(x_i) = exp(x_i) / Σ_j exp(x_j)

The denominator sums over all elements. To compute softmax for any single element, you need every element first. But if you process attention scores in blocks because the full N × N matrix won't fit in fast on-chip memory, you only see one block at a time. How do you normalize?
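A small pure-Python sketch makes the problem concrete. The scores and the two-element blocks are toy values chosen for illustration. If each block is normalized using only its own denominator, every block sums to 1 on its own, so the concatenated result is not a probability distribution and the individual weights are wrong.

```python
import math

def softmax(xs):
    # Plain softmax: the denominator sums over *every* element of xs.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 2.0, 3.0, 0.0]   # toy scores for one attention row
block1, block2 = scores[:2], scores[2:]

full = softmax(scores)
# Naive tiling: normalize each block with only its own denominator.
blockwise = softmax(block1) + softmax(block2)

print(full)
print(blockwise)
print(sum(blockwise))   # roughly 2.0: each block sums to 1 on its own
```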

Numerical stability comes first

exp(1000) overflows to infinity in floating point. The fix is the safe softmax: subtract the max from every element before exponentiating.

m = max(x_1, ..., x_N)
softmax(x_i) = exp(x_i − m) / Σ_j exp(x_j − m)

Mathematically identical, numerically safe. But it still needs two passes (find the max, then the sum), so it still requires seeing everything.
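A minimal pure-Python sketch of both versions follows; the function names are illustrative, not from any library. One caveat: Python's `math.exp` raises `OverflowError` instead of returning infinity, so the naive version fails loudly here rather than silently producing `inf` and then `nan`, as float arithmetic in C or CUDA would.

```python
import math

def naive_softmax(xs):
    # Exponentiates raw scores; large inputs overflow.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(xs):
    # Pass 1: find the max over *all* elements.
    m = max(xs)
    # Pass 2: exponentiate shifted scores (each term is <= 1) and sum them.
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    # Normalize; mathematically identical to the naive formula.
    return [e / total for e in exps]

scores = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(scores)
except OverflowError:
    # Python raises here; C or CUDA code would get inf instead.
    print("naive softmax overflowed")
print(safe_softmax(scores))
```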

The online trick

Maintain a running max m and a running sum ℓ, and correct them when a new block arrives.

After block 1: m₁ = max of block 1, ℓ₁ = Σ exp(x_j − m₁) over block 1.

Block 2 arrives with local max m₂. If m₂ > m₁, your old ℓ₁ was computed against the wrong (smaller) max. Rescale:

m_new = max(m₁, m₂)
ℓ_new = ℓ₁ · exp(m₁ − m_new) + Σ exp(x_j − m_new)   (over block 2)
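A worked example on a toy input (numbers chosen purely for illustration) shows the rescaled sum landing exactly on the sum you would get by using the global max from the start:

```latex
% Toy input x = (1, 2, 3, 0), split into block 1 = (1, 2) and block 2 = (3, 0)
\begin{aligned}
m_1 &= 2, \qquad \ell_1 = e^{1-2} + e^{2-2} = e^{-1} + 1 \\
m_2 &= 3, \qquad m_{\text{new}} = \max(m_1, m_2) = 3 \\
\ell_{\text{new}} &= \ell_1 \, e^{m_1 - m_{\text{new}}} + \left(e^{3-3} + e^{0-3}\right) \\
&= (e^{-1} + 1)\, e^{-1} + 1 + e^{-3} \\
&= e^{-3} + e^{-2} + e^{-1} + e^{0} = \sum_{j=1}^{4} e^{x_j - m_{\text{new}}} \approx 1.553
\end{aligned}
```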

This works because:

exp(x_j − m₁) · exp(m₁ − m_new) = exp(x_j − m_new)

The factor exp(m₁ − m_new) retroactively rebases the old sum against the new max. If m₁ was already bigger, the factor is exp(0) = 1 and nothing changes.
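The update fits in a few lines of Python. `online_softmax_stats` is a hypothetical helper name: it makes one pass over the blocks, keeps only m and ℓ, and applies the rebasing factor exp(m_old − m_new) at every step. Starting from m = −∞ turns the first block into an ordinary instance of the general update, since exp(−∞) = 0. This sketch assumes every block contains at least one finite score.

```python
import math

def online_softmax_stats(blocks):
    """Return (m, ell): global max and sum of exp(x - m), in one pass over blocks."""
    m = -math.inf   # running max
    ell = 0.0       # running sum, expressed relative to the current m
    for block in blocks:
        m_new = max(m, max(block))
        # Rebase the old sum onto the new max; exp(-inf) == 0.0 on the first block.
        ell = ell * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, ell

def softmax_from_stats(xs, m, ell):
    # Final normalization once the global statistics are known.
    return [math.exp(x - m) / ell for x in xs]

scores = [1.0, 2.0, 3.0, 0.0]
m, ell = online_softmax_stats([scores[:2], scores[2:]])
print(m, ell)   # 3.0 and roughly 1.553, same as a two-pass safe softmax
```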

Extending to attention output

In FlashAttention you're computing softmax(QKᵀ) · V. Along with m and ℓ, maintain a running, not-yet-normalized output O. When a new block raises the max, rescale O the same way:

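O_new = O₁ · exp(m₁ − m_new) + Σ exp(s_j − m_new) · v_j   (over block 2)

where s_j are this row's scores (entries of QKᵀ) for the keys in block 2, and v_j are the matching rows of V.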

At the end, divide O by the final ℓ. The result is mathematically identical to standard attention: no approximation, just a smarter order of operations.
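Here is the same idea for one query row of attention, as a pure-Python sketch with hypothetical helper names. It follows the text's softmax(QKᵀ) · V formulation; real kernels also scale the scores by 1/√d, which changes nothing about the online logic. Blocks of K and V stream past, and only m, ℓ and the unnormalized O survive between them.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention_row(q, keys, values):
    # Reference: build all N scores for this query row, softmax, weight the values.
    s = [dot(q, k) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    total = sum(p)
    d = len(values[0])
    return [sum(p[j] * values[j][c] for j in range(len(s))) / total for c in range(d)]

def online_attention_row(q, keys, values, block_size=2):
    # Streams over K/V blocks, carrying only m, ell and the unnormalized O.
    d = len(values[0])
    m, ell, o = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s = [dot(q, k) for k in k_blk]            # this block's scores only
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)               # rebases old ell and O onto m_new
        p = [math.exp(x - m_new) for x in s]
        ell = ell * alpha + sum(p)
        o = [o[c] * alpha + sum(p[j] * v_blk[j][c] for j in range(len(p)))
             for c in range(d)]
        m = m_new
    return [oc / ell for oc in o]                 # one division by the final ell

q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention_row(q, keys, values))
print(online_attention_row(q, keys, values, block_size=2))
```

The two functions agree up to floating-point rounding for any `block_size`, which is the "no approximation" claim in miniature. A real kernel processes whole tiles of Q rows at once rather than a single row.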

What you store

Per row of the output you carry forward three things between blocks: m (scalar), ℓ (scalar), and O (vector of length d). That's O(d) state per row, O(N · d) in total. Crucially, the N × N attention matrix is never materialized in HBM. That's where the memory savings come from.
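A back-of-the-envelope calculation shows the size of the gap. The numbers are illustrative assumptions, not measurements: one head of one sequence, N = 8192, d = 128, 2-byte (fp16) scores, and 4-byte (fp32) running state. `attention_memory_bytes` is a hypothetical helper.

```python
def attention_memory_bytes(n, d, score_bytes=2, acc_bytes=4):
    """Hypothetical per-head footprint: full score matrix vs. per-row online state."""
    matrix = n * n * score_bytes        # N x N score matrix materialized in HBM
    state = n * (2 + d) * acc_bytes     # per row: m, ell and an O vector of length d
    return matrix, state

matrix_bytes, state_bytes = attention_memory_bytes(n=8192, d=128)
print(matrix_bytes / 2**20, "MiB for the score matrix")   # 128.0
print(state_bytes / 2**20, "MiB for m, ell and O")         # 4.0625
```

Doubling N quadruples the first number but only doubles the second: the quadratic-versus-linear gap is where the savings come from.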

Punchline

The online softmax trick turns a fundamentally global operation (must see all N values) into a streaming one. Process blocks one at a time, carrying only two scalars (m and ℓ) per row alongside the output accumulator O. Without it, tiling attention would either be approximate or require an extra pass over memory, defeating the point.

A beautiful example of rearranging math to match hardware constraints rather than fighting them.
