
Forgetting Transformer (FoX)

FoX adds a data-dependent forget gate to softmax attention. The gate down-weights unnormalized attention scores between past and present positions, giving the transformer a learned, recency-biased decay. FoX is FlashAttention-compatible, works without positional embeddings, and improves long-context language modeling and length extrapolation.

Advanced · Tier 2 · Frontier · Frontier watch · ~40 min

Why This Matters

Standard softmax attention has no built-in notion of recency. Position is injected externally through positional embeddings (sinusoidal, learned, RoPE, ALiBi), and the attention weights themselves are computed from content alone. Recurrent models do the opposite: LSTMs and state-space models carry a learned forget gate that decides, at each step, how much of the past to keep.

FoX (Forgetting Transformer) asks whether softmax attention can borrow that mechanism. The answer is yes, with a small and clean modification: compute a scalar forget gate per token and per head, take a running product of these gates between positions $j$ and $i$, and add the log of this product to the pre-softmax attention logits. That is it. No positional embeddings are required. The resulting model trains at the same speed as a standard transformer because the modification is compatible with the FlashAttention kernel. On long-context language modeling and length extrapolation, FoX beats a tuned RoPE baseline. On short-context downstream tasks it performs competitively.

The paper is Lin, Nikishin, He, and Courville, Forgetting Transformer: Softmax Attention with a Forget Gate, ICLR 2025 (arXiv:2503.02130). The contribution is narrow but instructive: a single data-dependent scalar per token, injected in the right place, gives transformers the recency inductive bias of recurrent models without giving up parallel training or the FlashAttention fast path.

Mental Model

Think of softmax attention as a content-based lookup with no decay: token $i$ can attend equally to any earlier token $j$, and nothing in the attention weight itself punishes distance. Positional embeddings fix this indirectly by encoding $i - j$ as a content signal.

FoX adds a direct, multiplicative decay. Each token $t$ emits a forget value $f_t \in (0, 1)$ per attention head. The "keep mass" from token $j$ to token $i$ is the product

$$F_{ij} = \prod_{\ell = j+1}^{i} f_\ell.$$

If every intermediate $f_\ell$ is near 1, the keep mass is near 1 and FoX recovers ordinary attention. If one intermediate $f_\ell$ is near 0, it closes the gate: the contribution of every token before position $\ell$ is suppressed for all queries at or after position $\ell$. The gate is data-dependent, so the model can decide, based on the current input, when to wipe context.
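A minimal Python sketch of the keep mass makes the "closing" effect concrete. The gate values are hypothetical, positions are 1-indexed to match $F_{ij}$, and the name `keep_mass` is illustrative rather than taken from the paper's code.

```python
import math

def keep_mass(gates, j, i):
    """Keep mass F_ij = f_{j+1} * ... * f_i for 1-indexed positions j <= i.

    gates[t - 1] holds f_t, the gate emitted at position t. When j == i the
    product is empty and F_ii = 1, so a token never decays against itself.
    """
    return math.prod(gates[j:i])

# Hypothetical gates for positions 1..5: position 3 emits a near-zero gate.
gates = [0.99, 0.99, 0.01, 0.99, 0.99]
for j in range(1, 6):
    print(f"F_5{j} = {keep_mass(gates, j, 5):.4f}")
```

Query 5 keeps under 1% of the weight on keys 1 and 2, because the near-zero gate at position 3 sits between them and the query, while keys 3 to 5 are almost untouched.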

In log space this decay becomes additive. That is the key to efficient implementation: adding $\log F_{ij}$ to the attention logits is the same kind of operation ALiBi already performs, except the slope is data-dependent instead of a fixed head-specific constant.
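The additive form needs only one prefix sum. This illustrative sketch (0-indexed positions; the name `log_decay_bias` is hypothetical) builds the full bias matrix purely to show the identity $\log F_{ij} = c_i - c_j$, where $c$ is the running sum of $\log f$; an efficient kernel would compute each entry on demand instead of storing it.

```python
import math
from itertools import accumulate

def log_decay_bias(gates):
    """Additive bias D[i][j] = log f_{j+1} + ... + log f_i for j <= i, -inf for j > i.

    Positions are 0-indexed: gates[t] is f_t. One O(n) prefix sum
    c[t] = log f_0 + ... + log f_t gives every entry as c[i] - c[j].
    """
    c = list(accumulate(math.log(f) for f in gates))
    n = len(gates)
    return [[c[i] - c[j] if j <= i else -math.inf for j in range(n)] for i in range(n)]

gates = [0.9, 0.5, 0.8, 0.7]
D = log_decay_bias(gates)
# D[3][1] should match the log of the direct product f_2 * f_3.
print(D[3][1], math.log(0.8 * 0.7))
```

The first gate never appears in any entry, since every sum starts at $j + 1$; it cancels in $c_i - c_j$.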

Formal Setup

Definition

Forget Gate (FoX)

For each attention head $h$ and each token position $t$, the forget gate is a scalar in $(0, 1)$ computed from the current hidden state $x_t$:

$$f_t^{(h)} = \sigma\bigl( w_f^{(h) \top} x_t + b_f^{(h)} \bigr), \qquad f_t^{(h)} \in (0, 1),$$

where $\sigma$ is the sigmoid, and $w_f^{(h)} \in \mathbb{R}^d$, $b_f^{(h)} \in \mathbb{R}$ are per-head learnable parameters. The gate is a scalar per head, not a vector, so its parameter cost is negligible.
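A minimal sketch of the gate computation at one position, assuming plain Python lists for the hidden state and per-head parameters. The names `forget_gates`, `w_f`, and `b_f` and the toy sizes are illustrative, not the paper's code.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def forget_gates(x_t, w_f, b_f):
    """One scalar forget gate per head at position t.

    x_t: hidden state of length d.
    w_f: H weight vectors of length d (one per head).
    b_f: H scalar biases.
    Returns [f_t^(1), ..., f_t^(H)], each in (0, 1).
    """
    return [sigmoid(sum(w * x for w, x in zip(w_h, x_t)) + b_h)
            for w_h, b_h in zip(w_f, b_f)]

# Toy sizes (hypothetical): d = 4, H = 2, so the gate layer has H * (d + 1) = 10 parameters.
x_t = [1.0, 0.0, -1.0, 0.5]
w_f = [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
b_f = [0.0, 2.0]
print(forget_gates(x_t, w_f, b_f))
```

Each head costs $d + 1$ parameters, which is tiny next to the query, key, and value projections.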

Definition

Forgetting Attention

Let $q_i, k_j, v_j$ denote the query, key, and value vectors at positions $i, j$ for a fixed head. Define the cumulative log forget gate

$$D_{ij} = \sum_{\ell = j+1}^{i} \log f_\ell \quad (j \leq i), \qquad D_{ij} = -\infty \quad (j > i).$$

Forgetting Attention is the causal attention

$$o_i = \frac{\sum_{j=1}^{i} \exp(q_i^{\top} k_j + D_{ij}) \, v_j}{\sum_{j=1}^{i} \exp(q_i^{\top} k_j + D_{ij})}.$$

Equivalently, letting $F_{ij} = \prod_{\ell = j+1}^{i} f_\ell = \exp(D_{ij})$,

$$o_i = \frac{\sum_{j=1}^{i} F_{ij} \exp(q_i^{\top} k_j) \, v_j}{\sum_{j=1}^{i} F_{ij} \exp(q_i^{\top} k_j)}.$$

The factor $F_{ij}$ down-weights the unnormalized attention score between positions $i$ and $j$. Because the same factor appears in the numerator and the denominator, the output is still a normalized weighted average of the values.
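A naive $O(n^2)$ reference for one head, written twice to check that the log-bias definition and the multiplicative form agree. This is a minimal illustrative sketch, not the paper's implementation; positions are 0-indexed, the input values are arbitrary, and, following the formulas above, the logits are not scaled by $1/\sqrt{d}$.

```python
import math

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

def forgetting_attention(q, k, v, log_f):
    """Log-bias form: softmax over j <= i of q_i . k_j + D_ij.

    D_ij = log f_{j+1} + ... + log f_i. q, k, v are lists of n vectors;
    log_f[t] is log f_t. Positions are 0-indexed; logits are unscaled.
    """
    out = []
    for i in range(len(q)):
        s = [dot(q[i], k[j]) + sum(log_f[j + 1:i + 1]) for j in range(i + 1)]
        m = max(s)                                   # subtract the max for stability
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / z
                    for c in range(len(v[0]))])
    return out

def forgetting_attention_multiplicative(q, k, v, f):
    """Multiplicative form: weights F_ij * exp(q_i . k_j), F_ij = f_{j+1} * ... * f_i."""
    out = []
    for i in range(len(q)):
        w = [math.prod(f[j + 1:i + 1]) * math.exp(dot(q[i], k[j]))
             for j in range(i + 1)]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / z
                    for c in range(len(v[0]))])
    return out

# Toy inputs (arbitrary values): 4 tokens, 2-dimensional heads.
q = [[1.0, 0.0], [0.5, -0.5], [0.0, 1.0], [1.0, 1.0]]
k = [[0.2, 0.1], [-0.3, 0.8], [1.0, -1.0], [0.0, 0.5]]
v = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0]]
f = [0.9, 0.3, 0.8, 0.6]
a = forgetting_attention(q, k, v, [math.log(x) for x in f])
b = forgetting_attention_multiplicative(q, k, v, f)
# Largest disagreement between the two forms: rounding-level only.
print(max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)))
```

Passing all-zero log gates (every $f_\ell = 1$) to `forgetting_attention` yields ordinary causal softmax attention.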

The modification lives entirely inside the attention logits. Adding a bias $D_{ij}$ to $q_i^{\top} k_j$ follows the same access pattern FlashAttention already handles for causal masks and ALiBi, so FoX needs only a small modification of the FlashAttention kernel (computing $D_{ij}$ on the fly from cumulative sums of $\log f_\ell$), not a new algorithm.
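To see why the bias fits the FlashAttention pattern, here is a minimal online-softmax sketch in plain Python. It is not the actual fused GPU kernel, only the same blockwise arithmetic: each query visits keys tile by tile, keeping a running max, normalizer, and weighted sum. The only FoX-specific term is the bias `c[i] - c[j]`, recomputed per tile from an $O(n)$ prefix sum. The function name and the tile-size parameter `block` are illustrative.

```python
import math
from itertools import accumulate

def forgetting_attention_tiled(q, k, v, log_f, block=2):
    """FlashAttention-style online softmax for one head of Forgetting Attention.

    Keys are visited in tiles of size `block`. The bias D_ij = c[i] - c[j] is
    rebuilt per tile from an O(n) prefix sum c of log f, so no n x n bias
    matrix is ever stored. Positions are 0-indexed; logits are unscaled.
    """
    c = list(accumulate(log_f))                     # c[t] = log f_0 + ... + log f_t
    out = []
    for i in range(len(q)):
        m, l = -math.inf, 0.0                       # running max and normalizer
        acc = [0.0] * len(v[0])                     # running unnormalized output
        for start in range(0, i + 1, block):        # causal: only tiles with j <= i
            js = range(start, min(start + block, i + 1))
            s = [sum(a * b for a, b in zip(q[i], k[j])) + (c[i] - c[j]) for j in js]
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)             # rescale old statistics to the new max
            p = [math.exp(x - m_new) for x in s]
            l = l * scale + sum(p)
            acc = [a * scale + sum(pj * v[j][col] for pj, j in zip(p, js))
                   for col, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out

# Toy inputs (arbitrary values): 5 tokens, 2-dimensional heads.
q = [[1.0, 0.0], [0.5, -0.5], [0.0, 1.0], [1.0, 1.0], [-0.5, 0.2]]
k = [[0.2, 0.1], [-0.3, 0.8], [1.0, -1.0], [0.0, 0.5], [0.4, 0.4]]
v = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.5, 0.5]]
log_f = [math.log(x) for x in [0.9, 0.3, 0.8, 0.6, 0.95]]
print(forgetting_attention_tiled(q, k, v, log_f, block=2)[-1])
```

Any tile size gives the same output up to rounding, which is the property a blockwise kernel relies on.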

Main Theorems

Proposition

Forgetting Attention is Attention With Log-Domain Decay

Statement

Let $F_{ij} = \prod_{\ell = j+1}^{i} f_\ell$ with $f_\ell \in (0,1)$. Then Forgetting Attention can be written as standard softmax attention on modified logits

$$\tilde{s}_{ij} = q_i^{\top} k_j + \log F_{ij}, \quad \mathrm{FoXAttn}(q_i, K, V) = \mathrm{softmax}_j(\tilde{s}_{ij}) V.$$

In particular, when $f_\ell \equiv 1$ the bias vanishes and FoX reduces to standard causal softmax attention.

Intuition

The forget gate never enters the softmax denominator separately. It is folded into the logits as a position-dependent bias. This is the structural reason FoX keeps FlashAttention compatibility: the kernel only needs one extra bias term per $(i, j)$ pair, which can be computed on the fly from a prefix sum of $\log f_\ell$.

Proof Sketch

Start from the Forgetting Attention definition and write $F_{ij} = \exp(\log F_{ij})$. Then $F_{ij} \exp(q_i^{\top} k_j) = \exp(q_i^{\top} k_j + \log F_{ij})$. Numerator and denominator are built from the same exponentials, so the ratio is exactly the softmax of the shifted logits $\tilde{s}_{ij}$. When $f_\ell = 1$ for all $\ell$, $\log F_{ij} = 0$, recovering standard attention.

Why It Matters

This reformulation is what makes FoX trainable at transformer speed. Prefix sums of scalars are cheap, and the shifted logits drop into any attention kernel that supports causal masking plus a bias. Compare linear-attention variants, where the decay must be baked into a custom recurrence: FoX leaves the softmax, and the blockwise online-softmax computation that FlashAttention relies on, untouched.

Failure Mode

If all gates saturate near 1, FoX becomes a plain transformer with wasted parameters. If all gates saturate near 0, $F_{ij} \approx 0$ for every $j < i$, so each query effectively attends only to its own position (where $F_{ii} = 1$), collapsing the effective context to a single token. Initialization that pushes $b_f^{(h)}$ to a positive value (analogous to LSTM forget-gate bias tricks) is important to keep gates open early in training.
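Under the simplifying assumption that $w_f^{(h)\top} x_t \approx 0$ at initialization, the gate is just $\sigma(b_f^{(h)})$, and the bias alone sets how fast keep mass decays. This sketch (names illustrative, bias values arbitrary) converts a bias into a gate and the gate into a half-life measured in tokens.

```python
import math

def gate_from_bias(b):
    """Gate value f = sigmoid(b), assuming the input term w_f . x_t is near zero."""
    return 1.0 / (1.0 + math.exp(-b))

def half_life(f):
    """Gap g at which a constant gate's keep mass f**g falls to 1/2."""
    return math.log(0.5) / math.log(f)

for b in (0.0, 2.0, 5.0):
    f = gate_from_bias(b)
    print(f"b_f = {b}: f = {f:.4f}, keep mass halves every {half_life(f):.1f} tokens")
```

With $b_f = 0$ the keep mass halves every token, while $b_f = 5$ stretches the half-life to roughly a hundred tokens. This only illustrates why a positive bias keeps gates open early; it does not reproduce the paper's initialization.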

Proposition

Recency Bias of Forgetting Attention (Bounded-Logit Variant)

Statement

Fix a head and a query position $i$. Suppose the forget gates satisfy $f_\ell \leq \bar{f} < 1$ for all $\ell$, and the QK logits are bounded: $\lvert q_i^{\top} k_j \rvert \leq B$ for all $j \leq i$ (as enforced by QK-norm in the FoX Pro block). Then the multiplicative weight on the key at position $j < i$ is at most

$$F_{ij} = \prod_{\ell = j+1}^{i} f_\ell \leq \bar{f}^{\,i - j},$$

and the normalized attention weight $F_{ij} \exp(q_i^{\top} k_j) / Z_i$, with $Z_i = \sum_{j' \leq i} F_{ij'} \exp(q_i^{\top} k_{j'})$, is bounded by $\bar{f}^{\,i-j} \exp(2B)$. This bound decays at least geometrically in the gap $i - j$, uniformly in content similarity.

Intuition

Once the gate stays strictly below 1 and the QK logits are bounded, FoX imposes a ceiling on how much a far-away token can contribute, regardless of how well its key matches the query. Without the bounded-logit assumption, a single query-key pair with arbitrarily large inner product can dominate the denominator and void the geometric decay. QK-norm (part of the FoX Pro block) is what upgrades the recency bias from an empirical tendency to a provable bound.

Proof Sketch

Each factor in $\prod_{\ell = j+1}^{i} f_\ell$ is at most $\bar{f}$, and there are $i - j$ factors, giving $F_{ij} \leq \bar{f}^{\,i - j}$. The numerator contribution of key $j$ is $F_{ij} \exp(q_i^{\top} k_j) v_j$, and it is divided by $Z_i \geq F_{ii} \exp(q_i^{\top} k_i) = \exp(q_i^{\top} k_i) \geq \exp(-B)$, where $F_{ii} = 1$ because the product is empty. So the normalized weight on $v_j$ is at most $\bar{f}^{\,i-j} \exp(q_i^{\top} k_j + B) \leq \bar{f}^{\,i-j} \exp(2B)$. Without the bound on $q_i^{\top} k_j$, this ratio is not controlled.
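A small numerical check of the bound under the proposition's assumptions: every gate equal to $\bar{f} = 0.8$, $B = 1$, and an adversarial choice of logits inside $[-B, B]$. A second call drops the logit bound to show the failure mode. All values are arbitrary, and the function name is illustrative.

```python
import math

def forgetting_weights(logits, log_f, i):
    """Normalized Forgetting Attention weights for query i (0-indexed).

    logits[j] is q_i . k_j for j = 0..i; log_f[t] is log f_t.
    The weight on key j is F_ij * exp(logits[j]) / Z_i.
    """
    s = [logits[j] + sum(log_f[j + 1:i + 1]) for j in range(i + 1)]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    return [x / z for x in e]

f_bar, B, i = 0.8, 1.0, 12
log_f = [math.log(f_bar)] * (i + 1)          # every gate sits at the ceiling f_bar
# Adversarial bounded logits: the oldest key matches best (+B), all others worst (-B).
bounded = forgetting_weights([B] + [-B] * i, log_f, i)
bounds = [f_bar ** (i - j) * math.exp(2 * B) for j in range(i + 1)]
print(all(w <= bd for w, bd in zip(bounded, bounds)))

# Drop the bound on logits: one huge q_i . k_0 swamps the geometric decay.
unbounded = forgetting_weights([50.0] + [0.0] * i, log_f, i)
print(unbounded[0], bounds[0])
```

With the bound in place, the oldest key's weight stays under $\bar{f}^{12} e^{2} \approx 0.51$; with a single logit of 50 it takes essentially all the mass.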

Why It Matters

Length extrapolation and stable long-context behavior require the model to refuse to attend to arbitrarily old tokens with arbitrary confidence. The Pro variant of FoX enforces this structurally through the combination of gated decay and QK-norm. The base variant without QK-norm exhibits recency bias empirically (as the paper's ablations in Section 5 show) but lacks a distribution-free guarantee. This distinction matters when reading FoX results: the Pro block is the configuration that carries the bounded-logit story.

Failure Mode

If QK-norm is dropped, a single large $q_i^{\top} k_j$ can overwhelm the geometric decay and place most of the attention mass on a distant token. The empirical recency bias of unnormalized FoX is only a tendency on typical data. A second failure mode: if the gate is not well-regulated, a single near-zero $f_\ell$ kills all information from before position $\ell$, which is harmful for needle-in-a-haystack retrieval, where the gate must learn to stay open for task-relevant anchors.

The Pro Block

The paper also introduces a "Pro" block design that layers several small components from recurrent and efficient-attention literature around the Forgetting Attention core. The components are:

  • Output gate: a sigmoid gate applied to the attention output before the residual, similar to gated linear attention variants.
  • Output normalization: an RMSNorm on the attention output prior to the output projection.
  • QK-norm: RMSNorm applied to queries and keys before the dot product, which stabilizes logits at long context.
  • KV-shift: a simplified, learned, data-dependent token-shift on keys and values, borrowed from the short-convolution tradition in RWKV and Mamba-like designs.
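QK-norm is what supplies the logit bound $B$ in the recency-bias proposition. A minimal sketch, assuming RMSNorm with unit gain and $\epsilon = 0$ (learned gains, if present, would scale the bound by the largest absolute query and key gains): every normalized vector has squared norm exactly $d$, so Cauchy-Schwarz gives $\lvert q_i^{\top} k_j \rvert \leq d$ however large the raw projections are. Function names and inputs are illustrative.

```python
import math

def rms_norm(x, eps=0.0):
    """RMSNorm with unit gain: x / sqrt(mean(x^2) + eps)."""
    r = math.sqrt(sum(t * t for t in x) / len(x) + eps)
    return [t / r for t in x]

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(s * t for s, t in zip(a, b))

d = 4
q = rms_norm([3.0, -1.0, 2.0, 0.5])
k = rms_norm([-10.0, 40.0, 0.1, 7.0])
# Each normalized vector has squared norm d, so |q . k| <= d by Cauchy-Schwarz.
print(dot(q, q), abs(dot(q, k)) <= d)
```

Because RMSNorm divides out the vector's scale, multiplying a raw query by 100 leaves the normalized query, and therefore the logit bound, unchanged.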

Each piece is cheap in parameters and compatible with standard attention kernels. The paper reports that Pro blocks improve both FoX and the ordinary transformer, but the improvement for FoX is larger. The takeaway: the forget gate is the architectural change of interest, and the Pro block is a collection of well-trodden stabilizers that happen to pair well with it.

Historical Context

Multiplicative gating in sequence models goes back to LSTM (Hochreiter and Schmidhuber, 1997), which introduced input and output gates to stabilize gradient flow through long sequences; the forget gate was added by Gers, Schmidhuber, and Cummins (2000). Highway Networks (Srivastava et al., 2015) carried the idea into feedforward depth. GRU (Cho et al., 2014) simplified LSTM gating to an update gate and a reset gate, with the update gate taking over the combined role of the input and forget gates.

In attention, ALiBi (Press et al., 2022) added a fixed, non-learned linear bias in the attention logits that decays with distance. FoX can be read as a data-dependent generalization of ALiBi: instead of a constant per-head slope, each head gets a sequence of gates whose cumulative log acts as the bias. The removal of explicit positional embeddings also echoes NoPE (Kazemnejad et al., 2023), which observed that causal decoders can learn positional structure from the causal mask alone. FoX's cumulative log-gate can itself be read as a learned positional signal.

Linear-attention and state-space variants reached similar conclusions from a different direction. GLA (Gated Linear Attention, Yang et al., 2024) is the closest linear-attention analogue: it attaches data-dependent forget gates, one per key dimension, to a linear-attention recurrence. FoX is the softmax-attention analogue of GLA, with a single scalar gate per head. HGRN (Qin et al., 2023) placed data-dependent forget gates in a hierarchical linear recurrence. Mamba (Gu and Dao, 2023) and Mamba-2 (Dao and Gu, 2024) built selective, data-dependent decay into a structured state-space model. FoX keeps softmax attention and inherits its expressivity while borrowing the per-step selective-forgetting primitive these recurrent lines developed.

Common Confusions

Watch Out

FoX does not gate the FFN

An earlier draft of this page (and some secondhand descriptions online) claimed that FoX adds a forget gate to the feed-forward block. That is wrong. The forget gate modifies the unnormalized attention scores inside the softmax, not the FFN output. The FFN block in FoX is the standard MLP or SwiGLU that the baseline uses.

Watch Out

The gate is a scalar per head, not a vector per dimension

Unlike LSTM forget gates, which are vectors that gate each hidden dimension independently, the FoX forget gate is a single scalar per attention head at each position. The scalar multiplies the whole attention contribution between positions $i$ and $j$. This is what makes the cumulative product well-defined and FlashAttention-friendly.

Watch Out

FoX replaces positional embeddings, not positional information

FoX is trained and evaluated without RoPE, ALiBi, or sinusoidal positional embeddings. It is not the case that FoX is position-agnostic: the cumulative product $F_{ij}$ depends on the gap $i - j$ and on the intermediate content, which gives the model an implicit, data-shaped sense of distance. The paper's claim is that this implicit signal is strong enough to make explicit positional embeddings unnecessary.

Watch Out

FoX is not a linear-attention method

FoX keeps softmax over the attention logits. Its decay enters as an additive bias in log space, not as a state update in a recurrence. That means FoX retains the full $O(n^2)$ training cost of attention, but also retains its expressivity. Compare Mamba or RWKV, which use a recurrent state to achieve $O(n)$ cost at the price of changing the computation.

Exercises

Exercise (Core)

Problem

Show that when every forget gate $f_\ell = 1$, Forgetting Attention is exactly equal to standard causal softmax attention. Then show that when every $f_\ell = f \in (0, 1)$ is a head-specific constant, Forgetting Attention reduces to attention with an ALiBi-like linear bias. Give the exact slope.

Exercise (Advanced)

Problem

Let $f_\ell \in (0, 1)$ be the per-head forget gates, and let $F_{ij} = \prod_{\ell = j+1}^{i} f_\ell$. Suppose your hardware only supports causal attention kernels with a single additive bias term per $(i, j)$ pair. Describe an $O(n)$ preprocessing step that allows Forgetting Attention to be computed by such a kernel with no change to the kernel itself. Why does this preprocessing not apply to an FFN-level forget gate?

References

Canonical:

  • Lin, Nikishin, He, and Courville, Forgetting Transformer: Softmax Attention with a Forget Gate (2025), arXiv:2503.02130. Published at ICLR 2025. Primary source for FoX, Forgetting Attention, and the Pro block.
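  • Hochreiter and Schmidhuber, Long Short-Term Memory (1997). Origin of the gated LSTM cell. The forget gate itself was added in Gers, Schmidhuber, and Cummins, Learning to Forget: Continual Prediction with LSTM (Neural Computation, 2000).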

Context for attention-level decay:

  • Press, Smith, and Lewis, Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation (ICLR 2022). ALiBi: fixed, non-data-dependent linear bias on attention logits. FoX generalizes this.
  • Su, Lu, Pan, et al., RoFormer: Enhanced Transformer with Rotary Position Embedding (2021), Sections 3 and 4. RoPE is the standard baseline FoX drops.

Context for recurrent and linear-attention decay:

  • Yang, Wang, Shen, Panda, and Kim, Gated Linear Attention Transformers with Hardware-Efficient Training (2024), arXiv:2312.06635. GLA attaches data-dependent forget gates (one per key dimension) to a linear-attention recurrence. FoX is the softmax-attention analogue, with one scalar gate per head; GLA is the linear-attention analogue.
  • Gu and Dao, Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023), Sections 3.2 and 3.5. Data-dependent selective decay in a state-space model, the closest non-attention counterpart to the FoX gate.
  • Dao and Gu, Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (ICML 2024). Mamba-2 formalizes the duality between SSMs and linear attention, which clarifies why the forget-gate primitive transfers across architectures.
  • Qin, Yang, and Zhong, Hierarchically Gated Recurrent Neural Network for Sequence Modeling (NeurIPS 2023). HGRN: data-dependent forget gates in a hierarchical linear recurrence.
  • Peng et al., RWKV: Reinventing RNNs for the Transformer Era (EMNLP Findings 2023), Sections 4 and 5. Data-dependent decay in a recurrent formulation.

Context for positional information without explicit embeddings:

  • Kazemnejad, Padhi, Ramamurthy, Das, and Reddy, The Impact of Positional Encoding on Length Generalization in Transformers (NeurIPS 2023), arXiv:2305.19466. NoPE: causal decoders can learn positional structure from the causal mask alone. FoX's cumulative log-gate is itself a learned positional signal, and the paper reports no explicit positional embedding is needed.

Infrastructure:

  • Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023). The kernel FoX builds on via its additive-bias structure.

Further reading:

  • Jurafsky and Martin, Speech and Language Processing (3rd ed. draft), Chapters 9 and 10. Background on attention and long-context modeling.

Benchmark caveat: The FoX paper's headline comparisons (long-context LM perplexity, length-extrapolation, and the Table 2 common-sense reasoning suite) hold parameter count, token budget, and training context length fixed across FoX, Transformer, and Transformer++ baselines. When quoting these numbers, state the setup explicitly. The "beats RoPE on long-context" claim is specifically at matched compute and training-context length; FoX's advantage widens when evaluated at sequence lengths beyond training.


Last reviewed: April 18, 2026
