
Training Techniques

Mixed Precision Training

Train with FP16 or BF16 for speed while keeping FP32 master weights for accuracy. Loss scaling, overflow prevention, and when mixed precision fails.

Core · Tier 2 · Current · Supporting · ~45 min

Why This Matters

Training a large model in FP32 is significantly slower than in FP16/BF16 on modern GPUs: 4-8x slower for matmuls on A100/H100 tensor cores, and roughly 2x slower for non-tensor-core ops or on older hardware. It also uses 2x more memory for weights and activations. Mixed precision training captures most of this speedup while maintaining FP32-level accuracy. Essentially every large language model trained since 2018 has used some form of mixed precision, and understanding the underlying floating-point arithmetic is essential for diagnosing training failures.

Mental Model

Keep a "master copy" of the weights in FP32. For each training step: cast the weights to FP16/BF16, run the forward and backward pass in reduced precision (fast), then update the FP32 master weights with the computed gradients. The reduced-precision passes are fast because GPU tensor cores deliver a multiple of FP32 throughput for FP16/BF16 (roughly 2x over TF32 tensor-core math, and far more over plain FP32).
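This loop can be sketched with NumPy. Everything here is illustrative (a toy one-layer linear model with squared-error loss, made-up variable names), not code from any particular framework:

```python
import numpy as np

# Toy mixed-precision training step: FP32 master weights,
# FP16 forward/backward, FP32 update.
rng = np.random.default_rng(0)
master_w = rng.normal(size=(4,)).astype(np.float32)  # FP32 master copy
x = rng.normal(size=(4,)).astype(np.float32)
y_target = np.float32(1.0)
lr = np.float32(0.1)

# 1. Cast weights (and inputs) down to FP16 for the fast pass.
w16, x16 = master_w.astype(np.float16), x.astype(np.float16)

# 2. Forward pass in FP16.
y16 = np.dot(w16, x16)                                # prediction
loss16 = np.float16(0.5) * (y16 - np.float16(y_target)) ** 2

# 3. Backward pass in FP16: dL/dw = (y - y_target) * x.
grad16 = (y16 - np.float16(y_target)) * x16

# 4. Cast the gradient up and update the FP32 master weights.
master_w -= lr * grad16.astype(np.float32)
```

On real hardware steps 2-3 are where the speedup comes from; the FP32 update in step 4 is cheap by comparison.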

Formal Setup

Definition

Mixed Precision Training

A training procedure that maintains model weights $\theta$ in FP32 (the master weights) while computing forward activations $a_l = f_l(W_l^{\text{fp16}} \cdot a_{l-1})$ and gradients in FP16 or BF16. Weight updates are applied in FP32:

$$\theta^{\text{fp32}}_{t+1} = \theta^{\text{fp32}}_t - \eta \cdot g_t^{\text{fp32}}$$

where $g_t^{\text{fp32}}$ is the gradient cast back to FP32 before the update.

FP16 vs BF16

FP16 (IEEE half precision) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The representable range is approximately $[6 \times 10^{-8}, 65504]$ (the $6 \times 10^{-8}$ figure is the minimum positive subnormal; the minimum positive normal is $\approx 6.1 \times 10^{-5}$).

BF16 (brain floating point) uses 1 sign bit, 8 exponent bits, and 7 mantissa bits. The representable range is approximately $[10^{-38}, 3.4 \times 10^{38}]$.

The critical difference: BF16 has the same exponent range as FP32 but lower precision. FP16 has higher precision but a much smaller exponent range. For training, the exponent range matters more than mantissa precision. Gradients that are very small (below $6 \times 10^{-8}$) underflow to zero in FP16 but are representable in BF16. Activations that are large (above 65504) overflow in FP16 but not in BF16.

This is why BF16 has largely replaced FP16 for training on hardware that supports it (A100 and later GPUs, TPUs).
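The range difference is easy to see numerically. NumPy has no native bfloat16 dtype, so the sketch below emulates BF16 rounding by truncating the low 16 mantissa bits of an FP32 value (round-toward-zero; real hardware rounds to nearest, but the range behavior is the same):

```python
import numpy as np

def to_bf16(x: float) -> np.float32:
    """Emulate BF16 by zeroing the low 16 bits of an FP32 bit pattern."""
    bits = np.float32(x).view(np.uint32)
    return np.uint32(bits & np.uint32(0xFFFF0000)).view(np.float32)

tiny, huge = 1e-8, 1e5

# FP16: tiny underflows to zero, huge overflows to inf.
assert np.float16(tiny) == 0.0
assert np.isinf(np.float16(huge))

# BF16: both survive, because its 8 exponent bits match FP32's range.
assert to_bf16(tiny) > 0.0
assert np.isfinite(to_bf16(huge))
```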

Loss Scaling

Definition

Loss Scaling

Loss scaling multiplies the loss by a constant $s > 1$ before the backward pass, then divides gradients by $s$ after:

$$\tilde{g} = \frac{1}{s} \nabla_\theta \big(s \cdot L(\theta)\big)$$

By linearity of differentiation, $\tilde{g} = \nabla_\theta L(\theta)$. The purpose is to shift the gradient distribution into the representable range of FP16 during the backward pass.

In practice, dynamic loss scaling starts with a large $s$ (e.g., $2^{16}$) and halves $s$ whenever an overflow (NaN/Inf) is detected. If no overflow occurs for $N$ consecutive steps, $s$ is doubled. This adapts to the gradient magnitude throughout training.
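A minimal sketch of that recipe (the class name, method names, and growth interval are illustrative, not from any particular library; production scalers such as PyTorch's GradScaler add more machinery):

```python
import numpy as np

class DynamicLossScaler:
    """Start large, halve on overflow, double after N clean steps."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, grads):
        """Unscale grads, or return None (skip step) on overflow."""
        if any(not np.all(np.isfinite(g)) for g in grads):
            self.scale /= 2.0        # overflow: back off and skip
            self._clean_steps = 0
            return None
        unscaled = [g / self.scale for g in grads]
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:
            self.scale *= 2.0        # stable for a while: try a larger scale
            self._clean_steps = 0
        return unscaled

scaler = DynamicLossScaler()
ok = scaler.update([np.array([1.0, 2.0]) * scaler.scale])   # finite: unscaled
bad = scaler.update([np.array([np.inf])])                   # overflow: skipped
```

Note that overflow detection doubles as the step-skip signal: the optimizer simply does not apply an update on steps where `update` returns `None`.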

Main Theorems

Proposition

Loss Scaling Preserves Gradient Direction

Statement

Let $L(\theta)$ be a differentiable loss function and $s > 0$ a finite scaling factor. If no floating-point overflow occurs during computation of $\nabla_\theta(s \cdot L(\theta))$, then:

$$\frac{1}{s}\nabla_\theta(s \cdot L(\theta)) = \nabla_\theta L(\theta)$$

in exact arithmetic. In FP16 arithmetic, the scaled computation preserves gradient components that would otherwise underflow to zero, at the cost of potentially overflowing large components.

Intuition

Loss scaling shifts the entire gradient histogram to the right on a log scale. Components that were below the FP16 minimum ($6 \times 10^{-8}$) are lifted into the representable range. Components near the FP16 maximum ($65504$) may overflow. Dynamic scaling finds the sweet spot automatically.

Proof Sketch

By the chain rule, $\nabla_\theta(s \cdot L) = s \cdot \nabla_\theta L$. Dividing by $s$ recovers $\nabla_\theta L$ exactly. In floating-point arithmetic, the multiplication by $s$ shifts the exponent of each gradient component up by $\log_2(s)$ positions, preventing underflow for components whose exponent was within $\log_2(s)$ of the minimum.
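A quick numeric check of the exponent-shift argument (the gradient value and scale are illustrative):

```python
import numpy as np

g_fp32 = np.float32(1e-8)           # true gradient component
s = np.float32(2.0 ** 10)           # loss scale shifts the exponent by 10

unscaled = np.float16(g_fp32)       # below FP16's subnormal range: lost
scaled = np.float16(g_fp32 * s)     # survives as an FP16 subnormal
recovered = np.float32(scaled) / s  # unscale back in FP32

assert unscaled == 0.0
assert scaled > 0.0
```

The recovered value matches the original up to subnormal rounding error, while the unscaled path returns exactly zero.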

Why It Matters

Without loss scaling, FP16 training of deep networks fails because a large fraction of gradient components (often 50%+) fall below the FP16 minimum and become zero. Loss scaling is what makes FP16 training possible in practice. This is particularly important when combined with SGD convergence guarantees that assume nonzero gradient information.

Failure Mode

Loss scaling cannot help when gradients span a range wider than FP16 can represent (about 12 orders of magnitude from the smallest subnormal to the largest finite value). If the largest gradient component overflows even at scale $s = 1$, or the smallest underflows even at maximum scale, mixed precision with FP16 breaks. BF16 avoids this by having a much wider exponent range (about 76 orders of magnitude, matching FP32).

When Mixed Precision Fails

Gradient accumulation errors. When accumulating gradients across microbatches in FP16, the running sum can lose precision. Small gradient contributions get rounded away when added to a large accumulator. The fix: accumulate in FP32.
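The rounding-away effect is easy to reproduce: near 2048 the FP16 spacing is 2.0, so repeatedly adding 1.0 changes nothing, while an FP32 accumulator keeps every contribution (toy numbers chosen to sit exactly at that threshold):

```python
import numpy as np

acc16 = np.float16(2048.0)   # FP16 ulp at 2048 is 2.0
acc32 = np.float32(2048.0)
for _ in range(256):
    acc16 = np.float16(acc16 + np.float16(1.0))  # each +1.0 rounds away
    acc32 = acc32 + np.float32(1.0)              # each +1.0 is kept

assert acc16 == 2048.0   # the FP16 accumulator never moved
assert acc32 == 2304.0   # 2048 + 256, exact in FP32
```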

Attention logit overflow. In transformers, the attention logits $QK^T / \sqrt{d}$ can exceed 65504 in FP16, causing NaN. This happens with long sequences or poorly scaled attention. The fix: compute attention in FP32, or use BF16, whose wider range accommodates the large logits.
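A toy reproduction of how one oversized logit poisons the whole softmax (values illustrative; the max-subtraction is the standard softmax stabilization, which cannot rescue an already-infinite input):

```python
import numpy as np

logits = np.array([70000.0, 1.0], dtype=np.float16)  # 70000 > 65504: inf
shifted = logits - logits.max()                      # inf - inf = nan
probs = np.exp(shifted) / np.exp(shifted).sum()      # NaN spreads everywhere

assert np.isinf(logits[0])
assert np.isnan(probs).any()
```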

Small weight updates. When the learning rate is very small and the gradient is moderate, the update $\eta \cdot g$ can underflow in FP16. The master weight strategy already handles this by updating in FP32, but naive implementations that skip master weights will fail.
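A related and more common variant: even when the update itself is representable, it can be smaller than the FP16 spacing at the weight's magnitude, so an FP16 weight silently absorbs it while the FP32 master copy moves (numbers illustrative):

```python
import numpy as np

lr, grad = np.float32(1e-2), np.float32(1e-2)
update = lr * grad          # 1e-4, below the FP16 ulp near 1.0 (~1e-3)

w16 = np.float16(1.0)
w16 = np.float16(w16 - np.float16(update))   # rounds back to exactly 1.0

w32 = np.float32(1.0)      # FP32 master weight
w32 = w32 - update         # the update is kept

assert w16 == 1.0          # FP16 weight never changed
assert w32 < 1.0           # master weight did
```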

Common Confusions

Watch Out

Mixed precision does not mean training in FP16

Mixed precision means using both FP16/BF16 and FP32 at different stages. Pure FP16 training (without FP32 master weights) diverges for most models. The "mixed" part is the key.

Watch Out

BF16 usually does not need loss scaling

Because BF16 has the same 8-bit exponent range as FP32, the gradient underflow problem that motivates FP16 loss scaling is essentially gone, and most BF16 training pipelines run without a scaling step. This is a practical advantage but not an unconditional guarantee: BF16's mantissa is much shorter than FP32 (7 bits vs 23), so there are settings — very small learning rates, gradient accumulation, second-order optimizers, FP8 / BF16 mixed precision — where round-off and accumulation issues can still cause silent precision loss. The right takeaway is "FP16 essentially requires loss scaling; BF16 typically does not", not "BF16 has no numerical failure modes".

Watch Out

Memory savings are not 2x

The master weights are still FP32. The memory savings come from FP16 activations (which dominate memory for large models) and FP16 gradients. For a model with $N$ parameters, you need $4N$ bytes for master weights plus $2N$ bytes for the FP16 weights, versus $4N$ bytes for FP32 only. Activation memory savings are model-dependent.
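The weight-memory arithmetic, as a sketch (weights only; gradients, optimizer state, and activations are extra, and the function name is illustrative):

```python
def weight_bytes(n_params: int, mixed_precision: bool) -> int:
    """Bytes for model weights alone, per the 4N + 2N accounting above."""
    fp32_master = 4 * n_params
    fp16_copy = 2 * n_params if mixed_precision else 0
    return fp32_master + fp16_copy

n = 7_000_000_000  # e.g. a 7B-parameter model
assert weight_bytes(n, mixed_precision=False) == 4 * n  # FP32 only
assert weight_bytes(n, mixed_precision=True) == 6 * n   # master + FP16 copy
```

So mixed precision costs more weight memory, not less; the overall savings have to come from activations and gradients.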

Summary

  • Keep FP32 master weights, compute forward and backward in FP16/BF16
  • Loss scaling prevents gradient underflow in FP16 by shifting the gradient distribution
  • BF16 is preferred over FP16 for training because its wider exponent range eliminates most overflow and underflow issues
  • Dynamic loss scaling adapts automatically; start high and halve on overflow
  • Accumulate gradients in FP32 to avoid precision loss

Exercises

ExerciseCore

Problem

A gradient component has value $3 \times 10^{-8}$ in FP32. Will this underflow to zero in FP16? If you apply loss scaling with $s = 2^{10} = 1024$, what is the scaled value, and does it survive in FP16?

ExerciseAdvanced

Problem

In a transformer with hidden dimension $d = 4096$ and sequence length $L = 8192$, the attention logits are $QK^T / \sqrt{d}$. Under the i.i.d. standard-normal baseline for $Q, K$ entries, estimate the typical maximum logit. Then explain why FP16 overflow still occurs in practice despite this baseline being far below 65504.

References

Canonical:

  • Micikevicius et al., "Mixed Precision Training" (2018), ICLR
  • Kalamkar et al., "A Study of BFLOAT16 for Deep Learning Training" (2019)

Current:

  • NVIDIA, "Training with Mixed Precision" documentation (2023)
  • Dehghani et al., "The Efficiency Misnomer" (2022), Section 4

Last reviewed: April 26, 2026
