
LLM Construction

Attention Variants and Efficiency

Multi-head, multi-query, grouped-query, linear, and sparse attention: how each variant trades expressivity for efficiency, and when to use which.

Advanced | Tier 2 | Current | Supporting | ~55 min

Why This Matters

Standard scaled dot-product attention has $O(n^2 d)$ time complexity, where $n$ is the sequence length and $d$ is the model dimension. Naive implementations also materialize the $n \times n$ attention matrix in HBM, costing $O(n^2)$ memory. Tiled implementations like FlashAttention (Dao et al. 2022) keep the $O(n^2 d)$ FLOP count but cut the extra memory to $O(n)$ and sharply reduce HBM traffic, so in practice the binding constraint on long sequences is typically memory bandwidth rather than FLOPs. For long sequences (documents, code, genomics), attention remains the bottleneck. Different attention variants trade expressivity for efficiency, and the right choice depends on the deployment scenario: training throughput, inference latency, KV cache size, or sequence length.

Understanding these variants requires knowing exactly what each one changes and what it preserves.

Baseline: Multi-Head Attention (MHA)

Definition

Multi-Head Attention

Given $H$ attention heads, each head $h$ has its own projection matrices $W_h^Q, W_h^K, W_h^V \in \mathbb{R}^{d \times d_k}$ where $d_k = d/H$:

$$\text{head}_h = \text{Attention}(XW_h^Q, XW_h^K, XW_h^V) = \text{softmax}\left(\frac{XW_h^Q (XW_h^K)^\top}{\sqrt{d_k}}\right) XW_h^V$$

$$\text{MHA}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) W^O$$

where $W^O \in \mathbb{R}^{d \times d}$ is the output projection.

Parameter count per layer: $4d^2$ (three projections $W^Q, W^K, W^V$ plus output $W^O$, each $d \times d$).

KV cache per token: $2 \cdot H \cdot d_k = 2d$ values must be stored for each token in the context. For a model with $L$ layers, the total KV cache per token is $2Ld$.
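The definition above can be sketched in NumPy as follows. This is a minimal illustration, not a reference implementation: the per-head matrices $W_h^Q, W_h^K, W_h^V$ are assumed to be stacked as column blocks of single $d \times d$ matrices, and the sizes (`n = 6`, `d = 16`, `H = 4`) are arbitrary toy values.

```python
import numpy as np

def softmax(scores, axis=-1):
    # Subtract the row max for numerical stability.
    scores = scores - scores.max(axis=axis, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, H):
    """X: (n, d); W_q, W_k, W_v, W_o: (d, d); H heads of width d_k = d // H."""
    n, d = X.shape
    d_k = d // H
    # Project once, then split the feature axis into H heads: (H, n, d_k).
    Q = (X @ W_q).reshape(n, H, d_k).transpose(1, 0, 2)
    K = (X @ W_k).reshape(n, H, d_k).transpose(1, 0, 2)
    V = (X @ W_v).reshape(n, H, d_k).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (H, n, n)
    heads = softmax(scores) @ V                        # (H, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d)    # Concat(head_1, ..., head_H)
    return concat @ W_o

rng = np.random.default_rng(0)
n, d, H = 6, 16, 4
X = rng.standard_normal((n, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
out = multi_head_attention(X, W_q, W_k, W_v, W_o, H)

param_count = sum(W.size for W in (W_q, W_k, W_v, W_o))  # 4 d^2
kv_values_per_token = 2 * H * (d // H)                   # 2 d
```

The two counters at the end mirror the parameter and KV cache formulas stated above for a single layer.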

Proposition

Multi-Head Expressivity

Statement

Each attention head produces an $n \times d_k$ output whose rank is at most $d_k$: every row of the attention-weighted value matrix in head $h$ is a convex combination of rows of $XW_h^V$, so the output's row space lies in the row space of $XW_h^V$. Concatenating $H$ heads gives a matrix of rank at most $H d_k = d$, and the output projection $W^O$ mixes these $H$ rank-$d_k$ pieces.

The structural observation is that MHA decomposes the layer's output into $H$ parallel rank-$d_k$ channels with independent attention patterns, rather than a single rank-$d$ channel driven by one pattern. Different heads can attend to different subspaces of the key-value space, producing outputs that no single head of dimension $d$ with one attention matrix can produce.
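The rank bounds in the statement can be checked numerically. The sketch below uses random projections and arbitrary toy sizes (all assumptions for illustration); it only confirms the upper bounds and says nothing about what trained heads learn.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, H = 32, 16, 4
d_k = d // H
X = rng.standard_normal((n, d))

heads = []
for _ in range(H):
    W_q, W_k, W_v = (rng.standard_normal((d, d_k)) for _ in range(3))
    scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d_k)
    scores -= scores.max(axis=1, keepdims=True)
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)      # row-stochastic attention matrix
    heads.append(A @ (X @ W_v))            # (n, d_k) per-head output

# Each per-head output has only d_k columns, so its rank cannot exceed d_k.
head_ranks = [int(np.linalg.matrix_rank(h)) for h in heads]
# The concatenation has H * d_k = d columns, so its rank cannot exceed d.
concat_rank = int(np.linalg.matrix_rank(np.concatenate(heads, axis=1)))
```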

Intuition

Think of each head as asking a different question about the input. Head 1 might attend based on syntactic relationships (subject to verb). Head 2 might attend based on semantic similarity. Head 3 might attend to positional proximity. By running these in parallel and combining, the model captures multiple types of dependencies simultaneously.

Proof Sketch

A single softmax attention head produces a convex combination of value vectors along one rank-$d_k$ channel, so it cannot simultaneously place sharp mass on two incompatible sets of positions (e.g., nearest token and most semantically similar token). With two heads, each can place sharp mass on one pattern, and $W^O$ linearly combines the two rank-$d_k$ outputs. Each per-head output matrix has rank at most $d_k$, and the concatenation has rank at most $H d_k = d$, so once the heads jointly span the model dimension the linear part of the layer is saturated. This is a heuristic for why returns diminish past $H$ heads at fixed $d$, not a hard expressivity theorem: each head still contributes its own input-dependent softmax pattern, so adding more narrow heads can in principle express attention combinations that a smaller number of heads cannot, even with the same total channel width. The rank-saturation argument should be read as motivating the standard $d_k = d/H$ choice, not as a proof that more heads strictly add nothing.

Why It Matters

This explains why multi-head attention outperforms single-head attention with the same total dimension. The computational cost is identical (same number of parameters and FLOPs), but the representational capacity increases.

Failure Mode

In practice, many heads learn redundant or near-identical patterns. Michel et al. (2019) showed that most heads can be pruned after training with minimal quality loss. The theoretical expressivity advantage is not always used. This observation motivates the reduced-head variants below.

Multi-Query Attention (MQA)

Definition

Multi-Query Attention

Multi-query attention (Shazeer, 2019) uses separate $W_h^Q$ projections per head but shares a single $W^K$ and $W^V$ across all heads:

$$\text{head}_h = \text{softmax}\left(\frac{XW_h^Q (XW^K)^\top}{\sqrt{d_k}}\right) XW^V$$

All heads compute attention over the same key-value pairs but with different queries.

KV cache per token: $2d_k$ (one set of keys and values instead of $H$ sets). This is an $H$-fold reduction compared to MHA.

Tradeoff: MQA reduces KV cache by a factor of $H$. For models with $H$ in the 8-to-64 range (for example, LLaMA-2-70B has $H = 64$, while LLaMA-3-8B and Mistral 7B have $H = 32$), this is roughly an 8x to 64x cache reduction relative to MHA, enabling much longer contexts and larger batch sizes during inference. The original MQA experiments (Shazeer 2019) used a Transformer baseline with $H = 8$, which sits at the lower end of this range. Quality degrades slightly because heads can no longer attend to different subspaces of keys and values. For many tasks, the reported degradation is small (on the order of 1% or less on benchmarks).

Grouped-Query Attention (GQA)

Definition

Grouped-Query Attention

Grouped-query attention (Ainslie et al., 2023) divides the $H$ query heads into $G$ groups. Each group shares one set of key-value projections:

$$\text{head}_h = \text{softmax}\left(\frac{XW_h^Q (XW_{g(h)}^K)^\top}{\sqrt{d_k}}\right) XW_{g(h)}^V$$

where $g(h) = \lfloor hG/H \rfloor$ maps head $h$ to its group. Heads are 0-indexed: $h \in \{0, 1, \ldots, H-1\}$ and $g(h) \in \{0, 1, \ldots, G-1\}$, so $g(0) = 0$ and $g(H-1) = \lfloor (H-1)G/H \rfloor = G-1$ whenever $H \geq G$. In practice $H$ must be divisible by $G$ so that groups are balanced (each group owns exactly $H/G$ query heads).

KV cache per token: $2G \cdot d_k$. When $G = H$, GQA is MHA. When $G = 1$, GQA is MQA.

GQA is the current standard for large language models (Llama 2 70B, Llama 3, Mistral). It recovers most of MHA's quality while getting most of MQA's cache efficiency. Typical choices: $H = 32$ query heads with $G = 8$ KV groups (4x cache reduction).
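A minimal NumPy sketch of the grouping rule, assuming queries, keys, and values have already been projected and split into heads. The function name and toy sizes are illustrative, not taken from any library. Passing $G = H$ KV heads gives MHA and passing a single KV head gives MQA.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(Q, K, V):
    """Q: (H, n, d_k); K, V: (G, n, d_k) with H divisible by G.
    Returns per-head outputs of shape (H, n, d_k)."""
    H, n, d_k = Q.shape
    G = K.shape[0]
    assert H % G == 0, "H must be divisible by G"
    # g(h) = floor(h * G / H): blocks of H // G consecutive query heads share one KV head.
    group_of_head = (np.arange(H) * G) // H
    K_full = K[group_of_head]   # expand the G stored KV heads to H, shape (H, n, d_k)
    V_full = V[group_of_head]
    scores = Q @ K_full.transpose(0, 2, 1) / np.sqrt(d_k)
    return softmax(scores) @ V_full

rng = np.random.default_rng(0)
H, n, d_k = 8, 5, 4
Q = rng.standard_normal((H, n, d_k))
K = rng.standard_normal((H, n, d_k))
V = rng.standard_normal((H, n, d_k))

mha_out = grouped_query_attention(Q, K, V)            # G = H (MHA)
gqa_out = grouped_query_attention(Q, K[:2], V[:2])    # G = 2
mqa_out = grouped_query_attention(Q, K[:1], V[:1])    # G = 1 (MQA)
group_of_head = ((np.arange(H) * 2) // H).tolist()    # mapping used when G = 2
```

Only the `G` key and value heads need to be cached; the expansion to `H` heads happens at attention time, which is where the $2G \cdot d_k$ cache size comes from.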

Summary of Head-Sharing Variants

| Variant | KV heads | KV cache per token | Quality | Inference speed |
|---|---|---|---|---|
| MHA | $H$ | $2Hd_k = 2d$ | Best | Slowest (large cache) |
| GQA | $G$ ($1 < G < H$) | $2Gd_k$ | Near MHA | Fast |
| MQA | $1$ | $2d_k$ | Slightly worse | Fastest |

Linear Attention

Standard attention computes $\text{softmax}(QK^\top/\sqrt{d_k})V$, which in a naive implementation requires materializing the $n \times n$ attention matrix. Linear attention replaces the softmax with a kernel function to avoid this.

The practical motivation is decoding cost: per-token autoregressive decode costs $O(m d_v)$, constant in sequence length $n$, where $m$ is the output dimension of the feature map $\phi$ defined below. Causal linear attention admits a recurrent formulation: maintain a running $m \times d_v$ state $S_t = \sum_{j \leq t} \phi(k_j) v_j^\top$ and a running $m$-vector $z_t = \sum_{j \leq t} \phi(k_j)$, each updated in $O(m d_v)$ per step. Training and prefill over a length-$n$ sequence still cost $O(n m d_v)$ (one update per position); the $O(1)$-in-$n$ advantage only applies to per-step generation after the recurrent state is warmed up, and only relative to the softmax cache read cost. Softmax attention, by contrast, has a per-step decoding cost that grows linearly in $n$ because it reads the entire KV cache.
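The recurrence can be sketched as below and checked against the masked quadratic form. This is a toy NumPy version that assumes the ELU+1 feature map (so $m = d_k$) and arbitrary small sizes; function names are illustrative.

```python
import numpy as np

def elu_plus_one(x):
    # phi(x) = ELU(x) + 1: x + 1 for x > 0 and exp(x) otherwise, so it is always positive.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def causal_linear_attention_recurrent(Q, K, V):
    """Q, K: (n, d_k); V: (n, d_v). Each step does O(m d_v) work with m = d_k."""
    n, d_v = V.shape
    phi_q, phi_k = elu_plus_one(Q), elu_plus_one(K)
    m = phi_k.shape[1]
    S = np.zeros((m, d_v))   # running sum of phi(k_j) v_j^T
    z = np.zeros(m)          # running sum of phi(k_j)
    out = np.empty((n, d_v))
    for t in range(n):
        S += np.outer(phi_k[t], V[t])
        z += phi_k[t]
        out[t] = (phi_q[t] @ S) / (phi_q[t] @ z)
    return out

def causal_linear_attention_quadratic(Q, K, V):
    # Reference: build the full n x n kernel matrix and mask out future positions.
    A = np.tril(elu_plus_one(Q) @ elu_plus_one(K).T)
    return (A @ V) / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, d_k, d_v = 12, 4, 3
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))
out_recurrent = causal_linear_attention_recurrent(Q, K, V)
out_quadratic = causal_linear_attention_quadratic(Q, K, V)
```

The state `(S, z)` has a fixed size regardless of how many tokens have been consumed, which is what makes the per-step cost independent of $n$.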

Proposition

Linear Attention as Kernel Approximation

Statement

Define $\phi: \mathbb{R}^{d_k} \to \mathbb{R}^m$ as a feature map. Linear attention replaces the softmax attention with:

$$\text{LinearAttn}(Q, K, V)_i = \frac{\phi(q_i)^\top \sum_{j=1}^{n} \phi(k_j) v_j^\top}{\phi(q_i)^\top \sum_{j=1}^{n} \phi(k_j)}$$

The key observation: $\sum_{j=1}^{n} \phi(k_j) v_j^\top$ is an $m \times d_v$ matrix that can be computed once and reused for all queries. This gives $O(nmd_v)$ complexity instead of $O(n^2 d_k)$.

Intuition

Standard attention computes pairwise similarity between all query-key pairs, then uses these similarities to weight values. Linear attention factorizes this: first summarize all key-value pairs into a fixed-size matrix, then match each query against this summary. The summary is a compressed representation of the entire context.

Proof Sketch

Write the standard kernel attention: $\text{Attn}(q_i, K, V) = \sum_j \kappa(q_i, k_j) v_j / \sum_j \kappa(q_i, k_j)$ where $\kappa(q, k) = \phi(q)^\top \phi(k)$. Substitute the kernel decomposition and rearrange: the numerator becomes $\phi(q_i)^\top \left(\sum_j \phi(k_j) v_j^\top\right)$, which separates the query from the key-value aggregation.
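The rearrangement is just associativity of matrix multiplication, which a short non-causal check makes concrete. The sketch assumes the ELU+1 feature map and toy sizes; both orderings give the same output, but only the first ever forms an $n \times n$ matrix.

```python
import numpy as np

def phi(x):
    # ELU(x) + 1, an elementwise positive feature map (assumed here for illustration).
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

rng = np.random.default_rng(0)
n, d_k, d_v = 64, 8, 8
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))
V = rng.standard_normal((n, d_v))
fq, fk = phi(Q), phi(K)

# Quadratic order: form all n x n kernel values, then normalize each row.
A = fq @ fk.T                                    # (n, n)
out_quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear order: summarize keys and values first, then match each query to the summary.
S = fk.T @ V                                     # (m, d_v), computed once
z = fk.sum(axis=0)                               # (m,)
out_linear = (fq @ S) / (fq @ z)[:, None]
```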

Why It Matters

Linear attention achieves $O(n)$ complexity in sequence length (for fixed $m$ and $d_v$). This enables attention over sequences of length $10^5$ or more, which is costly with standard $O(n^2)$ attention.

Failure Mode

The feature map $\phi$ must map into the non-negative orthant for the denominator $\phi(q_i)^\top \sum_j \phi(k_j)$ to be a valid (positive) normalizer: otherwise cancellations can produce near-zero or negative denominators and the output blows up or flips sign. Naive random Fourier features take values in $\mathbb{R}$ and violate this. Two standard safe choices: $\phi(x) = \text{ELU}(x) + 1$ (Katharopoulos et al. 2020), elementwise and non-negative; and Performer's positive random features $\phi(x) \propto \exp(\omega^\top x - \|x\|^2/2)$ (Choromanski et al. 2020), which give an unbiased estimate of the softmax kernel while staying non-negative.

Even with a valid feature map, $\phi$ must approximate the softmax kernel well. On language modeling benchmarks, linear attention models consistently underperform softmax attention by a meaningful margin. The $n \times n$ matrix computed by softmax attention contains fine-grained pairwise information that the $m \times d_v$ summary cannot fully capture when $m \ll n$.
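A small sketch of the positive random feature map, assuming Gaussian $\omega \sim \mathcal{N}(0, I)$ and a $1/\sqrt{m}$ normalization so that inner products of features estimate $\exp(q^\top k)$. The $1/\sqrt{d_k}$ softmax temperature is assumed to be folded into $q$ and $k$ beforehand, and the sizes are toy values; the published Performer uses further refinements (such as orthogonal features) that are not shown.

```python
import numpy as np

def positive_random_features(X, omega):
    """X: (n, d); omega: (m, d) with rows drawn from N(0, I).
    Returns (n, m) strictly positive features whose inner products estimate exp(x . y)."""
    m = omega.shape[0]
    half_sq_norm = 0.5 * np.sum(X**2, axis=1, keepdims=True)
    return np.exp(X @ omega.T - half_sq_norm) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 4, 20000
q = 0.3 * rng.standard_normal((1, d))
k = 0.3 * rng.standard_normal((1, d))
omega = rng.standard_normal((m, d))

phi_q = positive_random_features(q, omega)
phi_k = positive_random_features(k, omega)
estimate = (phi_q @ phi_k.T).item()   # Monte Carlo estimate of the softmax kernel
exact = np.exp(q @ k.T).item()        # exp(q . k)
```

Because every feature is an exponential, the normalizer in linear attention stays positive no matter which $\omega$ are drawn; the estimate itself is stochastic and tightens as `m` grows.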

Sparse Attention

Instead of replacing softmax, sparse attention restricts which query-key pairs are computed.

Local window attention: each token attends only to the $w$ nearest tokens. Complexity: $O(nw)$. Captures local dependencies but misses long-range ones.

Strided attention: attend to every $s$-th token plus a local window. Captures periodic structure and some long-range dependencies.

Hash-based (LSH) attention (Kitaev et al., 2020): use locality-sensitive hashing to find the most similar keys for each query. Only compute attention for likely-high-similarity pairs. Expected complexity: $O(n \log n)$.

Block-sparse patterns: combine local windows with a few global tokens that attend to everything. Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020) use this approach.
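The fixed patterns above can be written as boolean masks. The sketch below is illustrative only: it uses bidirectional masks, hypothetical function names, and a dense masked softmax, so it shows which pairs are kept but does not itself realize the $O(nw)$ cost (that requires a kernel that skips the masked blocks).

```python
import numpy as np

def local_window_mask(n, w):
    # Query i may attend to key j when |i - j| <= w // 2.
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w // 2

def strided_mask(n, w, s):
    # Local window plus every s-th position relative to the query.
    idx = np.arange(n)
    return local_window_mask(n, w) | ((idx[:, None] - idx[None, :]) % s == 0)

def global_plus_local_mask(n, w, num_global):
    # Local window plus a few leading tokens that see, and are seen by, every position.
    mask = local_window_mask(n, w)
    mask[:num_global, :] = True
    mask[:, :num_global] = True
    return mask

def masked_attention(Q, K, V, mask):
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores = np.where(mask, scores, -np.inf)        # disallowed pairs get zero weight
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d_k, w = 16, 8, 4
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
local = local_window_mask(n, w)
out, weights = masked_attention(Q, K, V, local)
kept_pairs = int(local.sum())   # grows like n * w rather than n^2
```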

The practical problem: sparse patterns must be known in advance or computed cheaply. For autoregressive generation, where the attention pattern depends on the content being generated, fixed sparse patterns may miss important long-range dependencies.

Mistral 7B (Jiang et al. 2023) combines sliding-window attention with a rolling KV buffer, trading exact long-range coverage for practical throughput. The attention-sink strategy from StreamingLLM (Xiao et al. 2023) is an inference-time KV-retention trick and is not part of Mistral's trained architecture, though the two can be composed at deployment time. DeepSeek (Yuan et al. 2025) proposed Native Sparse Attention, which trains the sparse pattern jointly with the model rather than fixing it by hand.
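A rolling KV buffer can be sketched as a fixed-size array in which the entry for position $t$ is written to slot $t \bmod W$. The class below is a hypothetical single-head, single-query illustration in NumPy, not Mistral's implementation; positional information is assumed to be already encoded in the keys, so the slot order does not matter.

```python
import numpy as np

class RollingKVCache:
    """Fixed-size KV buffer: position t is stored in slot t % window."""

    def __init__(self, window, d_k, d_v):
        self.window = window
        self.keys = np.zeros((window, d_k))
        self.values = np.zeros((window, d_v))
        self.count = 0   # number of tokens seen so far

    def append(self, k, v):
        slot = self.count % self.window   # overwrites the oldest entry once full
        self.keys[slot] = k
        self.values[slot] = v
        self.count += 1

    def attend(self, q):
        valid = min(self.count, self.window)
        K, V = self.keys[:valid], self.values[:valid]
        scores = K @ q / np.sqrt(q.shape[0])
        weights = np.exp(scores - scores.max())
        return (weights / weights.sum()) @ V

rng = np.random.default_rng(0)
window, d_k = 4, 8
keys = rng.standard_normal((10, d_k))
values = rng.standard_normal((10, d_k))
cache = RollingKVCache(window, d_k, d_k)
for k, v in zip(keys, values):
    cache.append(k, v)

q = rng.standard_normal(d_k)
rolling_out = cache.attend(q)

# Reference: softmax attention restricted to the last `window` tokens.
s = keys[-window:] @ q / np.sqrt(d_k)
p = np.exp(s - s.max())
reference_out = (p / p.sum()) @ values[-window:]
```

Memory stays at `window` entries per layer however long generation runs, at the price of dropping everything older than the window.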

Hardware-Aware and Deployment Variants

A parallel line of work leaves the mathematical definition of attention unchanged and attacks the memory hierarchy instead. This matters because on modern GPUs, nominally more efficient variants often lose to a well-implemented quadratic kernel.

FlashAttention (Dao et al. 2022, Dao 2023 for v2, Shah et al. 2024 for v3) fuses the attention computation into a single tiled kernel that never materializes the $n \times n$ matrix in HBM. Standard MHA with FlashAttention is frequently faster than linear or sparse variants because attention on GPUs is memory-bound, not FLOP-bound, and the fused kernel removes most of the memory traffic. This reframes the efficiency conversation: the asymptotic $O(n^2)$ is not the binding constraint at typical lengths; HBM traffic is.
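The core numerical idea behind such tiled kernels is an online softmax: process keys and values block by block while carrying a running row maximum and running denominator. The NumPy sketch below illustrates only that recurrence, under simplifying assumptions (it tiles keys but not queries, and runs on the CPU), so it says nothing about the SRAM scheduling that gives the real kernel its speed.

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])              # full (n, n) score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=16):
    """Exact attention computed one key/value block at a time."""
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full(n, -np.inf)   # running max of scores per query row
    row_sum = np.zeros(n)           # running softmax denominator per query row
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d_k)                 # (n, block) tile of scores
        new_max = np.maximum(row_max, S.max(axis=1))
        scale = np.exp(row_max - new_max)           # rescales what was accumulated so far
        P = np.exp(S - new_max[:, None])
        row_sum = row_sum * scale + P.sum(axis=1)
        out = out * scale[:, None] + P @ Vb
        row_max = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
n, d_k = 50, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
out_naive = naive_attention(Q, K, V)
out_tiled = tiled_attention(Q, K, V, block=16)
```

The tiled version never holds more than an `(n, block)` slice of scores, yet it returns the same result as the naive version up to floating-point rounding.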

PagedAttention (Kwon et al. 2023, the kernel behind vLLM) manages the KV cache as a paged virtual-memory allocator, cutting fragmentation and enabling high-throughput batched serving. It is orthogonal to MHA/GQA/MQA and composes with all of them.

Ring Attention (Liu et al. 2023) shards the KV sequence across devices and overlaps communication with computation in a ring, enabling context parallelism for sequences that do not fit in one device's memory.

Multi-head Latent Attention (MLA) (DeepSeek-V2, DeepSeek-AI 2024; DeepSeek-V3 2024) compresses keys and values into a low-rank latent vector per token and caches only the latent (plus a small decoupled positional key). The DeepSeek-V2 report puts the resulting KV cache at roughly that of GQA with 2.25 groups, far below MHA's, while reporting quality at or above MHA. MLA was the main inference-cost lever behind DeepSeek-V2/V3 at deployment scale.
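The low-rank caching idea can be sketched as follows. This is a simplified illustration with hypothetical names and toy sizes: it omits the decoupled rotary-position key, query compression, and the output projection used in the published architecture, and only shows that per-head keys and values can be rebuilt from one shared latent.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d, H, d_k, d_c = 10, 64, 8, 8, 16    # d_c: latent width (toy value)

X = rng.standard_normal((n, d))
W_q = rng.standard_normal((d, H * d_k)) / np.sqrt(d)
W_down = rng.standard_normal((d, d_c)) / np.sqrt(d)          # shared KV down-projection
W_up_k = rng.standard_normal((d_c, H * d_k)) / np.sqrt(d_c)  # latent -> per-head keys
W_up_v = rng.standard_normal((d_c, H * d_k)) / np.sqrt(d_c)  # latent -> per-head values

latent = X @ W_down    # (n, d_c): the only per-token tensor kept in the cache

def split_heads(M):
    return M.reshape(n, H, d_k).transpose(1, 0, 2)   # (H, n, d_k)

Q = split_heads(X @ W_q)
K = split_heads(latent @ W_up_k)   # keys reconstructed from the cached latent
V = split_heads(latent @ W_up_v)   # values reconstructed from the cached latent
heads = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)) @ V

cached_per_token_mla = d_c              # latent only, in this simplified sketch
cached_per_token_mha = 2 * H * d_k      # full keys and values
key_rank = int(np.linalg.matrix_rank(latent @ W_up_k))   # bounded by d_c
```

Unlike MQA or GQA, every head still gets its own keys and values; the saving comes from all of them being linear functions of the same `d_c`-dimensional latent.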

Differential Attention (Ye et al. 2024) computes two softmax attention maps and takes their difference, reducing attention noise on irrelevant tokens. The math is standard attention applied twice; the change is in how the outputs are combined.
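A minimal single-head sketch of the difference-of-softmaxes idea, assuming a fixed scalar weight `lam` on the second map (in the paper this weight is learned) and omitting the normalization and multi-head details of the published architecture. Names and sizes are illustrative.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def differential_attention(X, W_q1, W_k1, W_q2, W_k2, W_v, lam):
    d_k = W_q1.shape[1]
    A1 = softmax((X @ W_q1) @ (X @ W_k1).T / np.sqrt(d_k))
    A2 = softmax((X @ W_q2) @ (X @ W_k2).T / np.sqrt(d_k))
    # Weight that both maps place on the same irrelevant tokens cancels in the difference.
    diff_map = A1 - lam * A2    # rows sum to 1 - lam; individual entries may be negative
    return diff_map @ (X @ W_v), diff_map

rng = np.random.default_rng(0)
n, d, d_k, lam = 6, 16, 4, 0.5
X = rng.standard_normal((n, d))
W_q1, W_k1, W_q2, W_k2 = (rng.standard_normal((d, d_k)) for _ in range(4))
W_v = rng.standard_normal((d, d_k))
out, diff_map = differential_attention(X, W_q1, W_k1, W_q2, W_k2, W_v, lam)
```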

Hybrid architectures replace some attention layers with selective state space models. Mamba (Gu and Dao 2023) and Mamba-2 (Dao and Gu 2024) give $O(n)$ sequence mixing with input-dependent dynamics. Jamba (Lieber et al. 2024) and Zamba interleave Mamba blocks with a minority of attention layers, trading some recall ability for long-context throughput.

Common Confusions

Watch Out

MQA does not reduce training FLOPs

MQA and GQA save memory for the KV cache during inference, which allows larger batch sizes and longer contexts. During training, the dominant cost is the matrix multiplications for $QK^\top$ and the attention-weighted sum over values, which are the same for all variants (the shared K, V projections are simply broadcast across heads). Only the smaller K and V projection matmuls shrink, so the training speedup from MQA is modest; the inference speedup is large.

Watch Out

Linear attention is not just softmax without exp

Dropping the exponential from softmax gives raw dot-product scores, not a usable linear attention. Linear attention specifically uses a kernel decomposition $\kappa(q, k) = \phi(q)^\top \phi(k)$ with a non-negative feature map $\phi$ to factorize the computation while keeping the normalizer valid. Simply removing exp leaves weights $q^\top k$ that can be negative, so the row normalizer can vanish or change sign and the output is no longer a well-defined weighted average.

Watch Out

Asymptotic complexity is not the binding constraint

Naive standard attention is $O(n^2)$ in both time and memory, but on current GPUs it is usually memory-bandwidth bound, not FLOP bound. FlashAttention (Dao et al. 2022) made standard MHA faster than many "more efficient" linear and sparse variants at sequence lengths up to tens of thousands, because it reduces HBM traffic rather than FLOPs. When comparing variants, compare wall-clock time with a tuned kernel, not asymptotic complexity alone.

Canonical Examples

Example

KV cache savings with GQA

Consider a model with $d = 4096$, $H = 32$ heads, $d_k = 128$, $L = 32$ layers, serving at sequence length $n = 8192$ in float16 (2 bytes per value). The per-sequence KV cache is $2 \cdot (\text{kv heads}) \cdot d_k \cdot L \cdot n \cdot 2$ bytes. MHA: $2 \times 32 \times 128 \times 32 \times 8192 \times 2 \approx 4.29$ GB. GQA with $G = 8$: $2 \times 8 \times 128 \times 32 \times 8192 \times 2 \approx 1.07$ GB (4x reduction, as expected from $H/G = 4$). MQA: $2 \times 1 \times 128 \times 32 \times 8192 \times 2 \approx 0.134$ GB (32x reduction). On an 80 GB GPU, the difference between 4.29 GB and 0.134 GB per sequence determines whether you can serve on the order of 15 or several hundred concurrent requests at this length, once memory for the weights is set aside. These are the numbers that drive deployment-time variant choice.
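The arithmetic in this example can be reproduced with a few lines of Python. The helper name is hypothetical, and GB here means $10^9$ bytes.

```python
def kv_cache_bytes(kv_heads, d_k, num_layers, seq_len, bytes_per_value=2):
    # The leading factor 2 counts keys and values.
    return 2 * kv_heads * d_k * num_layers * seq_len * bytes_per_value

d_k, num_layers, seq_len = 128, 32, 8192
mha_bytes = kv_cache_bytes(32, d_k, num_layers, seq_len)   # H = 32 KV heads
gqa_bytes = kv_cache_bytes(8, d_k, num_layers, seq_len)    # G = 8 KV heads
mqa_bytes = kv_cache_bytes(1, d_k, num_layers, seq_len)    # 1 KV head

for name, size in [("MHA", mha_bytes), ("GQA", gqa_bytes), ("MQA", mqa_bytes)]:
    print(f"{name}: {size / 1e9:.3f} GB per sequence")
```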

Exercises

ExerciseCore

Problem

A Transformer has $H = 64$ query heads and uses GQA with $G = 8$ KV groups. How many query heads share each KV group? What is the KV cache reduction factor compared to standard MHA?

ExerciseAdvanced

Problem

Show that standard softmax attention can be written as a kernel: $\text{Attn}(q_i, K, V) = \sum_j \kappa(q_i, k_j) v_j / \sum_j \kappa(q_i, k_j)$ with $\kappa(q, k) = \exp(q^\top k / \sqrt{d_k})$. Why can this kernel not be exactly decomposed as $\phi(q)^\top \phi(k)$ for a finite-dimensional $\phi$?

References

Canonical:

  • Vaswani et al., "Attention Is All You Need" (NeurIPS 2017), Section 3.2
  • Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need" (arXiv 2019). MQA.

Head-sharing variants:

  • Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (EMNLP 2023)
  • DeepSeek-AI, "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model" (arXiv 2024), Section 2.1. Multi-head Latent Attention.
  • DeepSeek-AI, "DeepSeek-V3 Technical Report" (arXiv 2024), Section 2. MLA at deployment scale.

Linear attention:

  • Katharopoulos et al., "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (ICML 2020). ELU+1 feature map and the causal recurrence.
  • Choromanski et al., "Rethinking Attention with Performers" (ICLR 2021). Positive random features for unbiased softmax-kernel estimation.
  • Wang et al., "Linformer: Self-Attention with Linear Complexity" (arXiv 2020). Low-rank projection of keys and values.

Sparse and hybrid:

  • Kitaev et al., "Reformer: The Efficient Transformer" (ICLR 2020). LSH attention.
  • Beltagy et al., "Longformer: The Long-Document Transformer" (arXiv 2020)
  • Zaheer et al., "Big Bird: Transformers for Longer Sequences" (NeurIPS 2020)
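  • Jiang et al., "Mistral 7B" (arXiv 2023), Section 2. Sliding-window attention with a rolling buffer cache.
  • Yuan et al. (DeepSeek-AI), "Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention" (arXiv 2025). Trainable sparse pattern.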
  • Ye et al., "Differential Transformer" (arXiv 2024). Noise-cancelling attention.

Hardware-aware and deployment:

  • Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
  • Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (arXiv 2023)
  • Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision" (arXiv 2024)
  • Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention" (SOSP 2023). vLLM.
  • Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context" (arXiv 2023)

State-space hybrids:

  • Gu and Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (arXiv 2023)
  • Dao and Gu, "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (ICML 2024). Mamba-2.
  • Lieber et al., "Jamba: A Hybrid Transformer-Mamba Language Model" (arXiv 2024)

Pruning and redundancy:

  • Michel, Levy, Neubig, "Are Sixteen Heads Really Better than One?" (NeurIPS 2019). Empirical redundancy of heads.

Next Topics

  • KV cache: how these attention variants affect the memory and speed of autoregressive generation

Last reviewed: April 26, 2026
