Paper breakdown

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Tri Dao et al. · 2022 · NeurIPS 2022

Tiles the attention computation so that softmax(QK^T)V is computed block-by-block in fast on-chip SRAM, never materializing the full n-by-n attention matrix in slow HBM. Exact, not approximate. 2-4x training speedup on long sequences with no loss change.

Overview

Dao et al. (2022) re-implemented self-attention as a fused, IO-aware GPU kernel. Standard attention computes $S = QK^\top / \sqrt{d_k} \in \mathbb{R}^{n \times n}$, applies softmax row-wise to get $P = \mathrm{softmax}(S)$, then computes $O = PV$. Each of these steps writes its full $n \times n$ output to GPU high-bandwidth memory (HBM), then reads it back for the next step. For sequence length $n = 8192$ and $h = 16$ heads, the materialized attention matrix is several gigabytes per layer; the bottleneck on modern GPUs is not FLOPs but the bandwidth between HBM and on-chip SRAM.

FlashAttention computes the same exact result by tiling: load a block of $Q$ and a block of $K, V$ into SRAM, compute the local contribution to the softmax and to $O$, write only the partial output back to HBM, repeat for the next $K, V$ block, and update the running softmax statistics so that the result is mathematically identical to the non-tiled version. The intermediate $n \times n$ matrix is never materialized.

The reported wall-clock improvement on a typical training step is 2–4×, with no change to the computed function (the kernel is exact, not approximate, unlike linear-attention proposals). FlashAttention, or an equivalent fused kernel, is now the default in PyTorch (scaled_dot_product_attention), NVIDIA's cuDNN, JAX, and every serious training framework. The paper's contribution is at the systems level: a GPU implementation, not an architectural change. But its impact on the practicality of long-context transformers is foundational.

Mathematical Contributions

The standard attention computation

Given $Q, K, V \in \mathbb{R}^{n \times d}$, attention is:

$$S = \frac{QK^\top}{\sqrt{d}}, \qquad P_{ij} = \frac{\exp(S_{ij})}{\sum_k \exp(S_{ik})}, \qquad O = PV$$

where $S, P \in \mathbb{R}^{n \times n}$. The naive implementation:

  1. Load $Q, K$ from HBM, compute $S$, write $S$ to HBM. Cost: $O(n^2 d)$ FLOPs, $O(n^2)$ HBM bytes written.
  2. Load $S$ from HBM, compute $P$, write $P$ to HBM. Cost: $O(n^2)$ FLOPs, $O(n^2)$ bytes read and written.
  3. Load $P, V$ from HBM, compute $O$, write $O$. Cost: $O(n^2 d)$ FLOPs, $O(n^2)$ bytes read.

The arithmetic intensity (FLOPs per byte of HBM traffic) is $O(d)$, which is well below the GPU's compute-to-bandwidth ratio for typical $d = 64$ or $d = 128$. The kernel is bandwidth-bound, so achieved throughput is far below peak FLOPs.
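A minimal NumPy sketch of the three-step pipeline, with a back-of-envelope intensity estimate for the dominant terms. The shapes and fp16 storage are illustrative assumptions, not the paper's benchmark configuration:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Three separate steps; a framework without a fused kernel
    round-trips S and P through HBM between them."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # step 1: n x n scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # step 2: row-wise softmax
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                                   # step 3: n x d output

# Rough arithmetic-intensity estimate (fp16 storage assumed)
n, d, bytes_per_el = 4096, 128, 2
flops = 4 * n * n * d                  # two n x n x d matmuls, 2 FLOPs/MAC
hbm_bytes = 4 * n * n * bytes_per_el   # write S, read S, write P, read P
print(flops / hbm_bytes)               # d/2 = 64 FLOPs/byte: bandwidth-bound
```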

Tiling and the online softmax

The challenge is that softmax requires the full row of $S$ to compute the normalization $\sum_k \exp(S_{ik})$. The paper uses an online softmax update (a running-statistics scheme in the spirit of Welford's algorithm, due to Milakov & Gimelshein, 2018): process $K, V$ in column blocks of size $B_c$, maintain a running max $m$ and a running normalizer $\ell$ per query row, and accumulate the partial $O$.

For the scores $\mathbf{S}^{(j)}$ of column block $j$, define the row-wise block max $\tilde m^{(j)} = \mathrm{rowmax}(\mathbf{S}^{(j)})$, the block exponentials $\tilde P^{(j)} = \exp(\mathbf{S}^{(j)} - \tilde m^{(j)})$, and the block normalizer $\tilde \ell^{(j)} = \mathrm{rowsum}(\tilde P^{(j)})$. Update the running statistics:

$$m^{(j)} = \max\big(m^{(j-1)}, \tilde m^{(j)}\big)$$

$$\ell^{(j)} = e^{m^{(j-1)} - m^{(j)}}\, \ell^{(j-1)} + e^{\tilde m^{(j)} - m^{(j)}}\, \tilde \ell^{(j)}$$

$$O^{(j)} = \frac{1}{\ell^{(j)}}\left[\ell^{(j-1)}\, e^{m^{(j-1)} - m^{(j)}}\, O^{(j-1)} + e^{\tilde m^{(j)} - m^{(j)}}\, \tilde P^{(j)} V^{(j)}\right]$$

The two exponential factors $e^{m^{(j-1)} - m^{(j)}}$ and $e^{\tilde m^{(j)} - m^{(j)}}$ are $\le 1$, so the rescaling is numerically stable. After the final block, the running output $O^{(j)}$ equals the un-tiled exact answer.
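The recurrence is short enough to transcribe directly. Below is a minimal NumPy sketch of the tiled forward pass; the block size and names are illustrative, and the real kernel runs this per query-row block in SRAM with fused matrix multiplies rather than in Python:

```python
import numpy as np

def flash_attention_forward(Q, K, V, Bc=64):
    """Tiled forward pass: O, m, l updated one K/V column block at a time."""
    n, d = Q.shape
    O = np.zeros((n, d))
    m = np.full(n, -np.inf)   # running row max m^(j)
    l = np.zeros(n)           # running normalizer ell^(j)
    for j in range(0, n, Bc):
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        Sj = Q @ Kj.T / np.sqrt(d)           # local scores, shape (n, Bc)
        mj = Sj.max(axis=1)                  # block max
        Pj = np.exp(Sj - mj[:, None])        # block exponentials
        lj = Pj.sum(axis=1)                  # block normalizer
        m_new = np.maximum(m, mj)
        l_new = np.exp(m - m_new) * l + np.exp(mj - m_new) * lj
        O = (np.exp(m - m_new)[:, None] * l[:, None] * O
             + np.exp(mj - m_new)[:, None] * (Pj @ Vj)) / l_new[:, None]
        m, l = m_new, l_new
    return O

# Sanity check against the naive_attention sketch above
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
assert np.allclose(flash_attention_forward(Q, K, V), naive_attention(Q, K, V))
```

Note that this sketch keeps $O$ normalized by $\ell$ at every step, matching the recurrence above; FlashAttention-2 instead defers the division until the final block.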

Block sizes

Block sizes $B_r$ (for $Q$ rows) and $B_c$ (for $K, V$ columns) are chosen so that the working set $B_r \cdot d + B_c \cdot d + B_r \cdot B_c$ fits in SRAM. On A100 with 192 KB of SRAM per streaming multiprocessor and $d = 128$, the paper uses $B_r = B_c = 64$: the working set is $128 \cdot 64 + 128 \cdot 64 + 64^2 = 20480$ elements, about 80 KB at 4-byte fp32 precision. Total HBM traffic is $O(n^2 d / B)$ bytes for the inner matrix multiplies, an order of magnitude less than naive attention.
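The working-set arithmetic as a one-liner (the 4-byte fp32 accumulator is an illustrative assumption):

```python
# Back-of-envelope check that the quoted block sizes fit in SRAM
d, Br, Bc, bytes_per_el = 128, 64, 64, 4
working_set = (Br * d + Bc * d + Br * Bc) * bytes_per_el
print(f"{working_set / 1024:.0f} KiB of 192 KiB SRAM")  # 80 KiB: fits
```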

Recompute in the backward pass

The standard backward pass needs $P$ to compute the gradient of $V$, and the gradient of $S$ to backpropagate into $Q, K$. Storing $P$ requires $O(n^2)$ HBM, which defeats the point. FlashAttention stores only the output $O$ and the per-row softmax statistics $(m, \ell)$, then recomputes $P$ on the fly in the backward pass from $Q, K, m, \ell$. This trades extra FLOPs for less HBM traffic; on bandwidth-bound kernels it is a net win.

The recompute is exact: it is the same store-versus-recompute trade as gradient checkpointing, applied inside a single kernel rather than across layers. The forward and backward together compute the same function and the same gradients as the naive implementation, up to floating-point summation order.
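A sketch of the reconstruction, assuming the forward pass also saved the final per-row statistics $(m, \ell)$ (the function name and signature are illustrative):

```python
import numpy as np

def recompute_P_block(Q, Kj, m, l):
    """Rebuild one column block of P from the saved row statistics
    (m, l), so the full n x n matrix never needs to live in HBM."""
    d = Q.shape[-1]
    Sj = Q @ Kj.T / np.sqrt(d)                    # recomputed block scores
    return np.exp(Sj - m[:, None]) / l[:, None]   # exact rows of P
```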

Asymptotic and wall-clock complexity

FLOPs: $O(n^2 d)$, same as naive attention. HBM accesses: $O(n^2 d^2 / M)$ where $M$ is the SRAM size, vs. $O(nd + n^2)$ for naive. For typical $d = 128$, $M = 192\,\text{KB}$, $n = 4096$, this is roughly $5\times$ less HBM traffic. Empirically the wall-clock speedup on A100 training is 2–4× depending on sequence length, since some overhead is unavoidable.

The memory footprint for attention activations during training drops from $O(n^2)$ to $O(n)$. This is what makes 32k-token training feasible on a single GPU; the standard implementation runs out of memory on a single A100 around $n = 16384$.
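Back-of-envelope arithmetic behind that claim, with head count and dtype as illustrative assumptions:

```python
# Size of one layer's materialized attention matrix at n = 16384
n, heads, bytes_per_el = 16_384, 16, 2         # assume fp16, 16 heads
attn_matrix = n * n * heads * bytes_per_el
print(f"{attn_matrix / 2**30:.0f} GiB per layer for P alone")  # ~8 GiB;
# saving S for the backward pass doubles it, so a handful of layers
# exhausts an 80 GB A100 before parameters and optimizer state.
```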

What FlashAttention does not do

It does not change the architecture. The model with FlashAttention computes exactly the same function as the model without it. The paper is not in the linear-attention or sparse-attention families; the attention pattern is dense and the matrix product is exact.

It does not reduce the parameter count, and it offers little benefit for autoregressive inference at small batch and short prefix; the speedup is largest for long sequences where HBM bandwidth dominates.

It does not help on hardware without fast, software-managed on-chip SRAM. The paper targets the NVIDIA A100; the same idea has since been ported to H100 (FlashAttention-2, FlashAttention-3) and to other hardware (AMD MI300, Apple Metal).

Why It Matters Now

FlashAttention reframed the question. Before 2022, "attention is $O(n^2)$" was usually invoked as a FLOP cost; many papers proposed approximate-attention methods (Linformer, Reformer, Performer, Big Bird) to reduce that cost. FlashAttention showed that the FLOPs were not the binding constraint on real hardware; the binding constraint was memory bandwidth, and a smart kernel could fix it without changing the math.

The empirical picture since then: every public LLM training framework defaults to FlashAttention or an equivalent kernel. Approximate-attention methods have largely lost the architectural argument. Long-context work has instead moved to exact distributed kernels (Ring Attention for very long contexts) or to different sequence-model families altogether (Mamba and the selective-SSM line of work), but for canonical transformer training the exact dense-attention kernel won.

The systems lesson is broader. The paper is one of the cleanest examples of an algorithmic improvement coming from the hardware side rather than the math side. The arithmetic was already known (the online softmax dates to Milakov & Gimelshein, 2018, in a different context). The contribution is the engineering — the correct tile sizes, the correct kernel layout, the correct trade between recompute and store, all calibrated to actual A100 specs.

This kind of work is increasingly the bottleneck. Modern transformer training is co-designed with the GPU memory hierarchy: FlashAttention for attention, FSDP for parameters, micro-batching for activations, paged attention for KV-cache at inference. The next round of efficiency wins will come from systems-level redesign, not new architectural ideas.

References

Canonical:

  • Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS. arXiv:2205.14135.
  • Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR. arXiv:2307.08691. 2× speedup over FlashAttention-1; better warp-level work partition.
  • Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." NeurIPS. arXiv:2407.08608. H100-specific; warp-specialized; FP8 support.

Direct precursors:

  • Milakov, M., & Gimelshein, N. (2018). "Online normalizer calculation for softmax." arXiv:1805.02867. The online-softmax recurrence FlashAttention generalizes.
  • Rabe, M. N., & Staats, C. (2021). "Self-attention Does Not Need O(n²) Memory." arXiv:2112.05682. $O(\log n)$-memory recursive self-attention; a precursor in the "attention without materializing the matrix" line.

Approximate-attention alternatives (the prior dominant approach):

  • Wang, S., Li, B. Z., Khabsa, M., Fang, H., & Ma, H. (2020). "Linformer: Self-Attention with Linear Complexity." arXiv:2006.04768.
  • Choromanski, K. et al. (2021). "Rethinking Attention with Performers." ICLR. arXiv:2009.14794.
  • Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). "Reformer: The Efficient Transformer." ICLR. arXiv:2001.04451.

Inference-time follow-ons:

  • Kwon, W. et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP. arXiv:2309.06180. vLLM. The KV-cache analog.
  • Liu, H., Zaharia, M., & Abbeel, P. (2024). "Ring Attention with Blockwise Transformers for Near-Infinite Context." ICLR. arXiv:2310.01889. Distributes the FlashAttention computation across devices for million-token contexts.

Standard textbook:

  • Prince, S. J. D. (2023). Understanding Deep Learning. MIT Press. Section 12.4 — attention efficiency.

Last reviewed: May 5, 2026