Paper breakdown
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Tri Dao et al. · 2022 · NeurIPS 2022
Tiles the attention computation so that softmax(QK^T)V is computed block-by-block in fast on-chip SRAM, never materializing the full n-by-n attention matrix in slow HBM. Exact, not approximate. 2-4x training speedup on long sequences with no loss change.
Overview
Dao et al. (2022) re-implemented self-attention as a fused, IO-aware GPU kernel. Standard attention computes S = QK^T/√d, applies softmax row-wise to get P = softmax(S), then computes O = PV. Each of these steps writes its full output to GPU high-bandwidth memory (HBM), then reads it back for the next step. For sequence length n in the tens of thousands and many heads, the materialized n×n attention matrix is several gigabytes per layer; the bottleneck on modern GPUs is not FLOPs but the bandwidth between HBM and on-chip SRAM.
FlashAttention computes the same exact result by tiling: load a block of Q and a block of K, V into SRAM, compute the local contribution to the softmax and to O, write only the partial output back to HBM, repeat for the next block, and update the running softmax statistics so the result is mathematically identical to the non-tiled version. The intermediate matrix S = QK^T is never materialized.
The reported wall-clock improvement on a typical training step is 2–4×, with exact equality of the computed function (the kernel is exact, not approximate, unlike linear-attention proposals). FlashAttention is now the default kernel in PyTorch (scaled_dot_product_attention), NVIDIA cuDNN, JAX, and every serious training framework. The paper's contribution is at the systems level — a GPU implementation, not an architectural change — but its impact on the practicality of long-context transformers is foundational.
Mathematical Contributions
The standard attention computation
Given Q, K, V ∈ ℝ^(n×d), attention is:

O = softmax(QK^T / √d) V

where S = QK^T / √d is the score matrix. The naive implementation:
- Load Q, K from HBM, compute S = QK^T / √d, write S to HBM. Cost: O(n²d) FLOPs, O(n²) HBM bytes written.
- Load S from HBM, compute P = softmax(S) row-wise, write P to HBM. Cost: O(n²) FLOPs, O(n²) bytes read and written.
- Load P, V from HBM, compute O = PV, write O. Cost: O(n²d) FLOPs, O(n² + nd) bytes read.
The arithmetic intensity (FLOPs per byte of HBM traffic) is O(d), which is well below the GPU's compute-to-bandwidth ratio for typical d = 64 or d = 128. The kernel is bandwidth-bound, so achieved throughput is far below peak FLOPs.
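A back-of-envelope sketch of this intensity argument, under stated assumptions (fp16 activations, counting only the S/P round-trips, treating the softmax as ~5 FLOPs per element — illustrative accounting, not a profile):

```python
def naive_attention_io(n, d, bytes_per=2):
    """Rough FLOP and HBM-byte counts for the three naive attention steps.

    Illustrative fp16 accounting: counts only the O(n^2) intermediate
    traffic (write S, read S, write P, read P), ignoring Q/K/V/O loads.
    """
    flops = 2 * n * n * d      # S = Q K^T
    flops += 5 * n * n         # row-wise softmax (max, sub, exp, sum, div)
    flops += 2 * n * n * d     # O = P V
    hbm_bytes = 4 * n * n * bytes_per
    return flops, hbm_bytes, flops / hbm_bytes

flops, hbm_bytes, intensity = naive_attention_io(n=4096, d=64)
# intensity is O(d): ~33 FLOPs/byte here, far below the ~200 FLOPs/byte
# an A100 can sustain (312 TFLOPs fp16 vs ~1.5 TB/s HBM), so the kernel
# is bandwidth-bound.
```

The exact constants depend on how softmax FLOPs are counted; the point is only that intensity scales with d, not n.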
Tiling and the online softmax
The challenge is that softmax requires the full row of S to compute the normalization ℓ_i = Σ_j e^(S_ij − m_i). The paper uses an online (Welford-style) update: process K, V in column blocks of size B_c, maintain a running max m_i and a running normalizer ℓ_i per query row, and accumulate the partial output O.
For column block j of K and V, define the block row-wise max m̃ = rowmax(S^(j)), the block exponentials P̃ = exp(S^(j) − m̃), and the block normalizer ℓ̃ = rowsum(P̃). Update the running statistics:

m_new = max(m, m̃)
ℓ_new = e^(m − m_new) · ℓ + e^(m̃ − m_new) · ℓ̃
O_new = ( e^(m − m_new) · ℓ · O + e^(m̃ − m_new) · P̃ V^(j) ) / ℓ_new
The two exponential factors e^(m − m_new) and e^(m̃ − m_new) are at most 1, so the rescaling is numerically stable. The running output O after the final block equals the un-tiled exact answer.
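The recurrence above can be checked end-to-end in a few lines of numpy. This is a sketch of the forward pass only (one head, no masking, column-tiling only — the real kernel also tiles the query rows):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full n x n score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def flash_attention_forward(Q, K, V, Bc=32):
    """Online-softmax tiled forward pass: K, V are consumed in column
    blocks of size Bc; only per-row stats (m, l) and the running output
    O persist between blocks -- the n x n matrix is never formed."""
    n, d = Q.shape
    O = np.zeros((n, d))
    m = np.full(n, -np.inf)                   # running row max
    l = np.zeros(n)                           # running row normalizer
    for j in range(0, n, Bc):
        S = Q @ K[j:j+Bc].T / np.sqrt(d)      # (n, Bc) score block
        m_blk = S.max(axis=1)                 # block row max
        P = np.exp(S - m_blk[:, None])        # block exponentials
        l_blk = P.sum(axis=1)                 # block normalizer
        m_new = np.maximum(m, m_blk)
        a = np.exp(m - m_new)                 # rescales old stats, <= 1
        b = np.exp(m_blk - m_new)             # rescales new block, <= 1
        l_new = a * l + b * l_blk
        O = ((a * l)[:, None] * O
             + b[:, None] * (P @ V[j:j+Bc])) / l_new[:, None]
        m, l = m_new, l_new
    return O

rng = np.random.default_rng(0)
n, d = 256, 64
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
assert np.allclose(flash_attention_forward(Q, K, V), naive_attention(Q, K, V))
```

The first iteration is well-behaved because m = −∞ gives a = 0, so the (empty) running state contributes nothing.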
Block sizes
Block sizes B_r (for query rows) and B_c (for key/value columns) are chosen so that the working set — tiles of Q, K, V, and O — fits in SRAM. With M elements of SRAM, the paper's Algorithm 1 sets B_c = ⌈M/(4d)⌉ and B_r = min(⌈M/(4d)⌉, d); the A100 target has 192 KB of SRAM per streaming multiprocessor. Total HBM traffic for the inner matrix multiplies is O(n²d²/M) bytes, an order of magnitude less than naive attention.
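The block-size rule from Algorithm 1 is a one-liner; the SRAM budget below (100 KB of usable shared memory, 4-byte elements) is an illustrative figure, not an exact A100 occupancy calculation:

```python
import math

def flash_block_sizes(M_elems, d):
    """Block sizes per the paper's Algorithm 1: four B x d tiles
    (for Q, K, V, O) must fit in M_elems elements of on-chip SRAM."""
    Bc = math.ceil(M_elems / (4 * d))
    Br = min(Bc, d)
    return Br, Bc

# assumed budget: ~100 KB usable shared memory, fp32 (4 bytes/element)
M_elems = 100 * 1024 // 4
print(flash_block_sizes(M_elems, d=64))   # -> (64, 100)
```

Larger d shrinks the feasible block, which is one reason head dimension interacts with kernel efficiency.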
Recompute in the backward pass
The standard backward pass needs P to compute the gradient of V (dV = P^T dO) and the gradient of P to backprop into Q and K. Storing P requires O(n²) HBM, which defeats the point. FlashAttention stores only the output O and the per-row softmax statistics (m, ℓ), then recomputes S and P on the fly in the backward pass from Q, K, V. This trades extra FLOPs for less HBM traffic; on bandwidth-bound kernels it is a net win.
The recompute is exact: the backward pass rebuilds the same P the forward pass used (unlike generic gradient checkpointing, which can diverge when recomputation interacts with stochastic ops such as dropout unless RNG state is saved). The forward and backward together compute the same function and the same gradients as the naive implementation, up to floating-point summation order.
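A minimal sketch of the recompute, again in numpy (one head; the real kernel does this block-by-block in SRAM): the forward pass saves only the O(n) statistics (m, ℓ), and the backward pass reconstructs P exactly from them plus Q, K.

```python
import numpy as np

def softmax_stats(Q, K):
    """Forward pass: save only per-row max m and normalizer l, O(n) memory."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    m = S.max(axis=1)
    l = np.exp(S - m[:, None]).sum(axis=1)
    return m, l

def recompute_P(Q, K, m, l):
    """Backward pass: rebuild the exact softmax matrix P from Q, K and the
    saved (m, l), instead of having stored the O(n^2) matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    return np.exp(S - m[:, None]) / l[:, None]

rng = np.random.default_rng(1)
Q, K = rng.standard_normal((128, 32)), rng.standard_normal((128, 32))
m, l = softmax_stats(Q, K)
P = recompute_P(Q, K, m, l)
assert np.allclose(P.sum(axis=1), 1.0)   # exact softmax rows recovered
```

With P in hand, dV = P^T dO and the remaining gradients follow the usual chain rule, tile by tile.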
Asymptotic and wall-clock complexity
FLOPs: O(n²d), same as naive attention. HBM accesses: O(n²d²/M) where M is the SRAM size, vs. Θ(n² + nd) for naive. For typical d = 64–128, M on the order of 100 KB, and n in the thousands, this is roughly an order of magnitude less HBM traffic. Empirically the wall-clock speedup on A100 training is 2–4× depending on sequence length, since some overhead is unavoidable.
The memory footprint for attention activations during training drops from O(n²) to O(n). This is what makes 32k-token training feasible on a single GPU; the standard implementation runs out of memory on a single A100 at far shorter sequence lengths.
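The O(n²) vs O(n) gap is easy to make concrete. The accounting below is illustrative only (fp16, counting just the attention-specific tensors: the materialized P per head vs. the output plus per-row (m, ℓ)):

```python
def attention_activation_bytes(n, d, n_heads, bytes_per=2):
    """Illustrative per-layer activation memory for attention only.

    naive: materialized n x n softmax matrix per head, O(n^2)
    flash: output O plus per-row stats (m, l) per head, O(n)
    """
    naive = n_heads * n * n * bytes_per
    flash = n_heads * (n * d + 2 * n) * bytes_per
    return naive, flash

naive, flash = attention_activation_bytes(n=32_768, d=64, n_heads=32)
print(f"naive: {naive / 2**30:.1f} GiB, flash: {flash / 2**20:.1f} MiB")
# at n = 32k: tens of GiB per layer vs ~100 MiB per layer
```

At 32k tokens the materialized matrix alone exceeds an A100's 40–80 GB for a single layer, which is why the naive kernel cannot train at that length.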
What FlashAttention does not do
It does not change the architecture. The model with FlashAttention computes exactly the same function as the model without it. The paper is not in the linear-attention or sparse-attention families; the attention pattern is dense and the matrix product is exact.
It does not reduce parameter count, and it gives little benefit for autoregressive inference at small batch sizes and short prefixes; the speedup is largest for long sequences, where HBM bandwidth dominates.
It does not help on hardware without fast, software-managed on-chip SRAM (shared memory). The paper targets the NVIDIA A100; the same idea has since been ported to H100 (FlashAttention-2, FlashAttention-3) and to other hardware (AMD MI300, Apple Metal).
Connections to TheoremPath Topics
- Flash attention — modern treatment including FlashAttention-2/3 and the H100 warp-specialized kernel.
- Attention mechanism theory — the mathematical operation FlashAttention computes exactly.
- Attention variants and efficiency — comparison with Reformer, Performer, Linformer (approximate) vs FlashAttention (exact).
- Softmax and numerical stability — the max-subtract trick generalizes to the online softmax used here.
- KV cache — the inference-time memory bottleneck, the analog of the training-time problem FlashAttention solves.
- Sparse attention and long context — the alternative path: change the attention pattern so the matrix is structurally sparse.
- Computer architecture for ML — the GPU memory hierarchy that motivates IO-aware kernels.
Why It Matters Now
FlashAttention reframed the question. Before 2022, "attention is O(n²)" was usually invoked as a FLOP cost; many papers proposed approximate-attention methods (Linformer, Reformer, Performer, Big Bird) to reduce that cost. FlashAttention showed that the FLOPs were not the binding constraint on real hardware — the binding constraint was memory bandwidth, and a smart kernel could fix it without changing the math.
The empirical picture since then: every public LLM training framework defaults to FlashAttention or an equivalent kernel. Approximate-attention methods have largely lost the architectural argument; they are still used in some long-context settings (Mamba and selective-SSM line of work, Ring Attention for very long contexts) but for canonical transformer training the exact dense-attention kernel won.
The systems lesson is broader. The paper is one of the cleanest examples of an algorithmic improvement coming from the hardware side rather than the math side. The arithmetic was already known (the online softmax dates to Milakov & Gimelshein, 2018, in a different context). The contribution is the engineering — the correct tile sizes, the correct kernel layout, the correct trade between recompute and store, all calibrated to actual A100 specs.
This kind of work is increasingly the bottleneck. Modern transformer training is co-designed with the GPU memory hierarchy: FlashAttention for attention, FSDP for parameters, micro-batching for activations, paged attention for KV-cache at inference. The next round of efficiency wins will come from systems-level redesign, not new architectural ideas.
References
Canonical:
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS. arXiv:2205.14135.
- Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR. arXiv:2307.08691. 2× speedup over FlashAttention-1; better warp-level work partition.
- Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." NeurIPS. arXiv:2407.08608. H100-specific; warp-specialized; FP8 support.
Direct precursors:
- Milakov, M., & Gimelshein, N. (2018). "Online normalizer calculation for softmax." arXiv:1805.02867. The online-softmax recurrence FlashAttention generalizes.
- Rabe, M. N., & Staats, C. (2021). "Self-attention Does Not Need O(n²) Memory." arXiv:2112.05682. Sub-quadratic-memory recursive self-attention; precursor in the "attention without materializing the matrix" line.
Approximate-attention alternatives (the prior dominant approach):
- Wang, S., Li, B. Z., Khabsa, M., Fang, H., & Ma, H. (2020). "Linformer: Self-Attention with Linear Complexity." arXiv:2006.04768.
- Choromanski, K. et al. (2021). "Rethinking Attention with Performers." ICLR. arXiv:2009.14794.
- Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). "Reformer: The Efficient Transformer." ICLR. arXiv:2001.04451.
Inference-time follow-ons:
- Kwon, W. et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." SOSP. arXiv:2309.06180. vLLM. The KV-cache analog.
- Liu, H., Zaharia, M., & Abbeel, P. (2024). "Ring Attention with Blockwise Transformers for Near-Infinite Context." ICLR. arXiv:2310.01889. Distributes the FlashAttention computation across devices for million-token contexts.
Standard textbook:
- Prince, S. J. D. (2023). Understanding Deep Learning. MIT Press. Section 12.4 — attention efficiency.
Last reviewed: May 5, 2026