

Distributed Training Theory

Training frontier models requires thousands of GPUs. Data parallelism, model parallelism, and communication-efficient methods make this possible.

Advanced · Tier 3 · Current · Supporting · ~55 min

Why This Matters

A single GPU cannot train a frontier language model. Modern large models can have tens or hundreds of billions of parameters and train over enormous token budgets. Even if you could fit the model on one device, the wall-clock time would be untenable.

Distributed training splits the work across hundreds or thousands of devices. But parallelism is not free: devices must communicate gradients, activations, or parameters, and this communication can bottleneck the entire system. The theory of distributed training is about understanding these tradeoffs and minimizing wasted computation.

Mental Model

Think of building a house. Data parallelism is hiring $N$ identical crews, each building a copy of the same wall using different bricks, then averaging their work. Model parallelism is splitting the blueprint so each crew builds a different section of the house. Pipeline parallelism is an assembly line: crew 1 pours foundations, passes to crew 2 for framing, crew 3 for roofing.

Each approach has different communication patterns. Data parallelism requires sharing gradients (the "average" step). Tensor parallelism requires sharing intermediate activations within each layer. Pipeline parallelism requires passing activations between stages. The art is choosing the right combination to minimize idle time and communication overhead.

Formal Setup and Notation

Let $f(\theta; x)$ be the loss for parameters $\theta$ on example $x$. We want to minimize $F(\theta) = \mathbb{E}_{x \sim \mathcal{D}}[f(\theta; x)]$ using SGD or its variants, distributed across $K$ workers.

Definition

Data Parallelism

In data parallelism, each of $K$ workers holds a complete copy of the model $\theta$. At each step:

  1. Each worker $k$ samples a minibatch $B_k$ of size $b$
  2. Each worker computes its local gradient $g_k = \frac{1}{b} \sum_{x \in B_k} \nabla f(\theta; x)$
  3. Workers average their gradients: $\bar{g} = \frac{1}{K} \sum_{k=1}^{K} g_k$
  4. Each worker updates: $\theta \leftarrow \theta - \eta \bar{g}$

The effective batch size is $Kb$. Step 3 is typically implemented via AllReduce.
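
The four steps above can be sketched in a few lines of NumPy by simulating the $K$ workers in a loop; a real implementation would replace the explicit mean in step 3 with an AllReduce collective (e.g., `torch.distributed.all_reduce`). The worker count, batch size, and toy quadratic loss are illustrative assumptions, not part of the definition.

```python
import numpy as np

rng = np.random.default_rng(0)

K, b, d = 4, 8, 10          # workers, local batch size, parameter dimension
eta = 0.1                   # learning rate
theta = rng.normal(size=d)  # replicated copy of the model, identical on every worker

def per_example_grad(theta, x):
    # Toy quadratic loss f(theta; x) = 0.5 * ||theta - x||^2, so grad = theta - x.
    return theta - x

for step in range(200):
    local_grads = []
    for k in range(K):                            # each worker in the data-parallel group
        B_k = rng.normal(loc=1.0, size=(b, d))    # worker k's local minibatch
        g_k = np.mean([per_example_grad(theta, x) for x in B_k], axis=0)
        local_grads.append(g_k)
    g_bar = np.mean(local_grads, axis=0)          # step 3: AllReduce, here an explicit mean
    theta = theta - eta * g_bar                   # step 4: identical update on every worker

print("parameters converge toward the data mean (1.0):", theta.round(2))
```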

Definition

AllReduce

AllReduce is a collective communication primitive that computes the sum (or average) of vectors across all workers and distributes the result to every worker. In the standard $\alpha$-$\beta$ communication model, suppose each worker holds a tensor of $n$ bytes and we use ring AllReduce across $K$ workers. Then

$$T_{\mathrm{ring}} \approx 2(K-1)\alpha + 2\frac{K-1}{K}\frac{n}{W},$$

where $\alpha$ is the per-message latency cost and $W$ is the link bandwidth in bytes per second.

  • The bandwidth term approaches $2n/W$ for large $K$.
  • The latency term still grows linearly with $K$.

So ring AllReduce is bandwidth-optimal asymptotically, but not literally independent of worker count once latency is accounted for.
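
A minimal sketch of this cost model, purely to show how the two terms scale with $K$; the latency and bandwidth figures are assumed values, not measurements.

```python
def ring_allreduce_time(n_bytes, K, alpha=5e-6, W=12.5e9):
    """T ~ 2(K-1)*alpha + 2*(K-1)/K * n/W  (latency term + bandwidth term)."""
    latency = 2 * (K - 1) * alpha
    bandwidth = 2 * (K - 1) / K * n_bytes / W
    return latency, bandwidth

n = 2e9  # e.g., gradients of a 1B-parameter model in FP16
for K in (2, 8, 64, 512):
    lat, bw = ring_allreduce_time(n, K)
    # Bandwidth cost saturates near 2n/W; latency keeps growing with K.
    print(f"K={K:4d}  latency={lat*1e3:6.2f} ms  bandwidth={bw*1e3:6.1f} ms")
```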

Definition

Tensor Parallelism

Tensor parallelism (TP) splits individual layers across devices. For a linear layer $Y = XW$, one common pattern is to partition $W$ by columns across $K$ devices and let each device compute its part of the matrix product:

$$W = [W_1, W_2, \ldots, W_K]$$

Device $k$ computes $Y_k = XW_k$ (a slice of the output). Depending on the partition pattern, the next collective may be an AllGather, an AllReduce, or a ReduceScatter. The important point is not that tensor parallelism always gathers the full activation immediately, but that splitting one layer across devices requires repeated communication of activations or partial sums within the layer.
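
The column partition can be simulated on a single machine with NumPy: each "device" holds one slice $W_k$, computes its output slice $Y_k = XW_k$, and the slices are gathered to recover the full output. The shapes and tensor-parallel degree are arbitrary assumptions; a real implementation replaces the concatenation with an AllGather collective.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4                                      # tensor-parallel degree
X = rng.normal(size=(16, 512))             # activations: (batch, hidden_in)
W = rng.normal(size=(512, 1024))           # full weight; never materialized on one device in real TP

W_shards = np.split(W, K, axis=1)          # column partition: W = [W_1, ..., W_K]
Y_shards = [X @ W_k for W_k in W_shards]   # each device computes its slice of the output
Y = np.concatenate(Y_shards, axis=1)       # stands in for the AllGather collective

assert np.allclose(Y, X @ W)               # sharded result matches the unsharded layer
print(Y.shape)                             # (16, 1024)
```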

Definition

Pipeline Parallelism

Pipeline parallelism (PP) partitions the model's layers into $S$ sequential stages, each on a different device. A microbatch flows through stage 1, then stage 2, and so on. Multiple microbatches are in flight simultaneously to fill the pipeline, but pipeline bubbles (idle time) occur at startup and shutdown.

For GPipe-style schedules with $M$ microbatches and $S$ stages, the startup/shutdown bubble fraction is approximately:

$$\text{bubble} = \frac{S - 1}{M + S - 1}$$

which shrinks as $M$ grows relative to $S$.
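
A quick numeric check of the bubble formula, with illustrative stage and microbatch counts:

```python
def bubble_fraction(M, S):
    """GPipe-style startup/shutdown bubble: (S - 1) / (M + S - 1)."""
    return (S - 1) / (M + S - 1)

S = 8  # pipeline stages
for M in (1, 8, 32, 128):  # microbatches in flight
    print(f"S={S}, M={M:4d}  bubble={bubble_fraction(M, S):.1%}")
```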

Core Definitions

ZeRO (Zero Redundancy Optimizer) from DeepSpeed eliminates memory redundancy in data parallelism. Standard data parallelism replicates the full optimizer state, gradients, and parameters on every worker. ZeRO partitions these across workers:

With mixed-precision Adam the per-parameter memory budget is $\approx 16$ bytes: 2 bytes of FP16 parameters, 2 bytes of FP16 gradients, and 12 bytes of optimizer state (an FP32 master copy of the parameters, and Adam's $m$ and $v$). ZeRO stages partition these components in order.

  • ZeRO Stage 1: partition the 12 bytes/param of optimizer state across $K$ workers. Each worker now holds $12/K$ bytes/param of optimizer state, down from 12 bytes/param. Parameters (2) and gradients (2) remain replicated. Per-worker total drops from 16 bytes/param to $4 + 12/K$ bytes/param.
  • ZeRO Stage 2: additionally partition the 2 bytes/param of gradients. Workers compute full gradients locally during backward, but reduce-scatter them so each worker stores only its partition. Per-worker total drops to $2 + 14/K$ bytes/param.
  • ZeRO Stage 3: additionally partition the 2 bytes/param of parameters. Per-worker total drops to $16/K$ bytes/param, but each forward and backward pass now needs an AllGather to reconstruct parameters, roughly doubling communication per step.

PyTorch FSDP is a modern, practical implementation of the same core idea: shard parameters, gradients, and optimizer state across data-parallel workers and materialize full parameters only when a module is about to execute.
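
A minimal FSDP usage sketch, assuming a `torchrun` launch with NCCL and the classic `FullyShardedDataParallel` wrapper; the model, optimizer, and dummy objective are placeholders, and recent PyTorch releases also expose the same sharding through a newer `fully_shard` API.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a launch such as: torchrun --nproc_per_node=8 train.py
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(               # placeholder model
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 4096),
).cuda()

# Wrapping shards parameters, gradients, and optimizer state across the
# data-parallel ranks (ZeRO-3 style); full parameters are all-gathered only
# around the wrapped module's forward and backward passes.
model = FSDP(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 4096, device="cuda")
loss = model(x).pow(2).mean()              # dummy objective
loss.backward()                            # gradients are reduce-scattered to their owners
optimizer.step()
optimizer.zero_grad()
```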

Gradient compression reduces communication volume:

  • Gradient quantization: send gradients in low precision (e.g., FP16 or INT8) instead of FP32, halving or quartering bandwidth.
  • Gradient sparsification: send only the top-$k$ gradient components; the error from dropped components is accumulated locally and sent in future rounds (see the sketch after this list).
  • PowerSGD: approximate the gradient matrix with a low-rank decomposition, reducing communication from $O(d)$ to roughly $O(d \cdot r / \min(m, n))$ for a rank-$r$ approximation of an $m \times n$ gradient matrix.
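
Below is a sketch of top-$k$ sparsification with local error feedback, matching the second bullet above; the kept fraction and synthetic gradient are assumptions, and a real system would transmit only the selected indices and values rather than a dense masked vector.

```python
import numpy as np

def sparsify_with_feedback(grad, error, keep_frac=0.01):
    """Send only the largest-magnitude entries; keep the rest as local error."""
    corrected = grad + error                           # replay error from earlier rounds
    k = max(1, int(keep_frac * corrected.size))
    idx = np.argpartition(np.abs(corrected), -k)[-k:]  # indices of the top-k entries
    sent = np.zeros_like(corrected)
    sent[idx] = corrected[idx]                         # the only values put on the wire
    return sent, corrected - sent                      # dropped mass is sent in later rounds

rng = np.random.default_rng(0)
error = np.zeros(10_000)
for step in range(5):
    grad = rng.normal(size=10_000)
    sent, error = sparsify_with_feedback(grad, error)
    print(f"step {step}: sent {np.count_nonzero(sent)} of {grad.size} entries, "
          f"residual norm {np.linalg.norm(error):.1f}")
```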

Main Theorems

Theorem

Ideal Linear Speedup of Synchronous Data-Parallel SGD

Statement

For data-parallel SGD with $K$ workers, each using local batch size $b$, the effective batch size is $B = Kb$. After $T$ steps, the convergence rate for a smooth non-convex objective satisfies:

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\|\nabla F(\theta_t)\|^2 \leq O\left(\frac{F(\theta_0) - F^*}{\eta T} + \eta L \frac{\sigma^2}{Kb}\right)$$

With learning rate $\eta = O(\sqrt{Kb/T})$, this gives:

$$O\left(\sqrt{\frac{(F(\theta_0) - F^*)L\sigma^2}{KbT}}\right)$$

This is a factor of $\sqrt{K}$ better than single-worker SGD with batch size $b$ and the same number of steps. Equivalently, to reach the same accuracy, data-parallel SGD needs $K$ times fewer steps in the idealized variance-limited regime. If communication is negligible, this translates into linear wall-clock speedup.

Intuition

Each worker contributes independent gradient estimates. Averaging $K$ independent estimates reduces variance by a factor of $K$, just like averaging $K$ i.i.d. random variables. Less variance means you can take larger steps or reach the same accuracy faster. The speedup is linear until the batch size becomes so large that the variance reduction saturates (the "critical batch size").
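
A quick NumPy check of this variance argument: averaging $K$ independent gradient estimates shrinks the variance by roughly a factor of $K$. The Gaussian noise here is only a stand-in for minibatch gradient noise.

```python
import numpy as np

rng = np.random.default_rng(0)
true_grad, sigma, trials = 1.0, 4.0, 200_000

for K in (1, 4, 16, 64):
    # Each trial: K workers draw independent noisy gradient estimates; average them.
    estimates = true_grad + sigma * rng.normal(size=(trials, K))
    averaged = estimates.mean(axis=1)
    print(f"K={K:3d}  empirical variance={averaged.var():.3f}  "
          f"predicted sigma^2/K={sigma**2 / K:.3f}")
```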

Proof Sketch

Start from the smoothness inequality: $F(\theta_{t+1}) \leq F(\theta_t) - \eta \nabla F(\theta_t)^T \bar{g}_t + \frac{\eta^2 L}{2} \|\bar{g}_t\|^2$. Take expectations. The averaged gradient $\bar{g}_t$ has variance $\sigma^2 / (Kb)$ (each of $K$ workers averages $b$ independent samples). The bias is zero since each worker computes unbiased gradients. Telescope over $T$ steps and optimize the learning rate.

Why It Matters

Linear speedup is the best you can hope for from parallelism. This theorem says data-parallel SGD achieves the algorithmic version of it, up to a critical batch size beyond which increasing KK gives diminishing returns. The systems caveat is equally important: wall-clock speedup only follows if communication and load imbalance stay small relative to compute.

Failure Mode

Linear speedup breaks down when: (1) the batch size $Kb$ exceeds the critical batch size $B_{\text{crit}} \approx \sigma^2 / \|\nabla F\|^2$, beyond which further variance reduction no longer translates into faster progress; (2) communication overhead exceeds the computation savings; (3) large batch sizes require learning-rate warmup and careful tuning to maintain convergence.

Proposition

Communication-Computation Tradeoff

Statement

In data-parallel training with ring AllReduce, a simple no-overlap step-time model is:

$$T_{\text{step}} \approx C + 2(K-1)\alpha + 2\frac{K-1}{K}\frac{n}{W}$$

where $C$ is the local computation time (forward + backward pass), $n$ is the gradient payload in bytes, $\alpha$ is the per-hop latency, and the last term is the ring AllReduce bandwidth cost. The computation-to-communication ratio is therefore:

$$\text{ratio} = \frac{C}{2(K-1)\alpha + 2\frac{K-1}{K}\frac{n}{W}}.$$

Training is efficient when this ratio is much greater than 1 (computation dominates). It is communication-bound when the ratio is less than 1.

Intuition

Every step, you do some math (forward and backward pass, time $C$) and some networking. If the network is slow relative to the GPU, workers spend most of their time waiting. Making each worker do more computation per communication round (larger local batch, gradient accumulation, more overlap) improves the ratio.

Proof Sketch

Ring AllReduce has two phases (reduce-scatter and all-gather). In each phase, every worker participates in $K-1$ communication rounds, which contributes the latency term $2(K-1)\alpha$. The payload moved per worker is $2\frac{K-1}{K}n$ bytes total, giving the bandwidth term. Overlapping communication with the backward pass layer by layer can hide part of this cost, but the formula is the basic upper bound when overlap is weak or absent.

Why It Matters

This ratio determines whether adding more GPUs actually helps. For a model with $10^9$ parameters in FP16, the gradient payload is about $n = 2 \times 10^9$ bytes, so the ring bandwidth term is about $4 \times 10^9 / W$ seconds. At 100 Gbps bandwidth ($\approx 12.5$ GB/s), this is about 0.32 seconds before latency. If the forward/backward pass takes 0.5 seconds, communication already consumes a large fraction of step time. Faster interconnects (InfiniBand, NVLink) or gradient compression are needed.
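
The same arithmetic, written out as code; the payload, bandwidth, and compute time are the assumed figures from this paragraph, and latency is ignored.

```python
n = 2e9        # gradient payload in bytes (1e9 parameters in FP16)
W = 12.5e9     # 100 Gbps link, in bytes per second
C = 0.5        # assumed forward + backward time in seconds

for K in (8, 64, 512):
    comm = 2 * (K - 1) / K * n / W   # ring AllReduce bandwidth term
    print(f"K={K:3d}  comm={comm:.2f} s  compute/comm ratio={C / comm:.2f}")
```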

Canonical Examples

Example

Scaling from 1 to 256 GPUs

A model with $d = 7 \times 10^9$ parameters (7B) trained with mixed-precision Adam needs 16 bytes/param: 14 GB of FP16 parameters, 14 GB of FP16 gradients, and 84 GB of optimizer state (FP32 master copy + Adam's $m$ and $v$, 12 bytes/param). Total: 112 GB per worker in standard data parallelism.

With ZeRO Stage 1: the 84 GB of optimizer state is split across $K$ workers, while parameters and gradients remain replicated. Each worker holds $28 + 84/K$ GB. At $K = 8$, this is $28 + 10.5 = 38.5$ GB per worker.

With ZeRO Stage 3: parameters, gradients, and optimizer state are all sharded. Each worker stores $112/K$ GB. At $K = 8$, this is 14 GB per worker, but each forward and backward pass requires AllGather operations to reconstruct parameters.
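
A small helper, under the same mixed-precision Adam assumptions used above (2 + 2 + 12 bytes per parameter), makes the stage-by-stage arithmetic in this example explicit; stage 0 denotes plain data parallelism.

```python
def zero_bytes_per_param(K, stage):
    """Per-parameter bytes with mixed-precision Adam: 2 (FP16 params) +
    2 (FP16 grads) + 12 (FP32 master copy + Adam m and v), sharded by stage."""
    params, grads, opt_state = 2.0, 2.0, 12.0
    if stage >= 1:
        opt_state /= K   # Stage 1: shard optimizer state
    if stage >= 2:
        grads /= K       # Stage 2: also shard gradients
    if stage >= 3:
        params /= K      # Stage 3: also shard parameters
    return params + grads + opt_state

d = 7e9  # 7B parameters
for stage in (0, 1, 2, 3):
    gb = d * zero_bytes_per_param(K=8, stage=stage) / 1e9
    print(f"ZeRO stage {stage}: {gb:6.2f} GB per worker at K=8")
```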

Common Confusions

Watch Out

Data parallelism does not change the model

In exact arithmetic, synchronized data parallelism produces the same update as single-device training with the same effective batch size $Kb$. In practice it is usually mathematically equivalent but not bitwise identical, because floating-point reduction order, nondeterministic kernels, and RNG handling can differ. Asynchronous data parallelism is a different algorithm again.

Watch Out

Pipeline parallelism is not just splitting layers

Naive pipeline parallelism has the pipeline bubble problem: only one stage is active at a time, giving utilization $1/S$. The key innovation (GPipe, PipeDream) is microbatching: split the minibatch into $M$ microbatches and overlap their execution. With $M \gg S$, utilization approaches 100%, but memory increases because activations for all in-flight microbatches must be stored.

Watch Out

3D parallelism is not optional at scale

Frontier models use all three forms simultaneously: data parallelism across groups, tensor parallelism within a node (fast NVLink), and pipeline parallelism across nodes. The 3D configuration is a system design choice driven by hardware topology, not just model architecture.

Watch Out

ZeRO and FSDP buy memory by spending communication

Sharding optimizers, gradients, and parameters can reduce per-worker memory by nearly a factor of the data-parallel degree, but it is not a free lunch. The more aggressively you shard, the more often you must reduce-scatter or all-gather state. Memory savings and communication cost move together.

Summary

  • Data parallelism gives linear speedup up to a critical batch size; beyond that, adding workers wastes compute
  • The algorithmic speedup story and the wall-clock speedup story are different; communication can erase the gains
  • Ring AllReduce has a bandwidth term of order $n/W$ and a latency term that still grows with worker count
  • Tensor parallelism splits work within individual layers across devices in a node (requires high bandwidth); pipeline parallelism splits the layer stack into stages across nodes (tolerates lower bandwidth)
  • ZeRO eliminates memory redundancy: Stage 1 partitions optimizer states, Stage 2 adds gradients, Stage 3 adds parameters; FSDP is a modern sharded implementation family in the same spirit
  • The computation-to-communication ratio determines whether training is GPU-bound or network-bound

Exercises

ExerciseCore

Problem

A model has $d = 1 \times 10^9$ parameters in FP32. Each worker has 100 Gbps network bandwidth. How long does AllReduce take, assuming the ring algorithm?

ExerciseAdvanced

Problem

With $K = 64$ workers and local batch size $b = 16$, the effective batch size is $1024$. If the critical batch size is $B_{\text{crit}} = 2048$, are we in the linear speedup regime? What happens if we double $K$ to $128$?

ExerciseResearch

Problem

Explain the memory-communication tradeoff in ZeRO Stages 1-3. For each stage, what memory is saved per worker and what additional communication is required?

References

Canonical:

  • Dean et al., "Large Scale Distributed Deep Networks" (NIPS 2012). The DistBelief paper that introduced asynchronous parameter servers for large-scale SGD.
  • Patarasuk and Yuan, "Bandwidth optimal all-reduce algorithms for clusters of workstations" (2009). Ring AllReduce analysis underpinning the $2n/W$ bandwidth asymptotic used above.
  • Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (arXiv:1910.02054, SC 2020). Introduces the three-stage partitioning used in this page.
  • Huang et al., "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism" (arXiv:1811.06965, 2019). Micro-batched pipeline parallelism and the bubble analysis.

Large-batch regime:

  • Goyal et al., "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" (arXiv:1706.02677, 2017). Linear scaling rule and warmup for large-batch SGD.
  • McCandlish et al., "An Empirical Model of Large-Batch Training" (arXiv:1812.06162, 2018). Defines the critical batch size BcritB_{\text{crit}} used above.

Gradient compression:

  • Alistarh et al., "QSGD: Communication-Efficient SGD via Gradient Quantization and Encoding" (NeurIPS 2017).
  • Lin et al., "Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training" (arXiv:1712.01887, 2018). Sparsification with error feedback.
  • Vogels et al., "PowerSGD: Practical Low-Rank Gradient Compression for Distributed Optimization" (NeurIPS 2019).

Current:

  • Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" (arXiv:1909.08053, 2020).
  • Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" (arXiv:2104.04473, SC 2021). 3D parallelism at scale.
  • Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (arXiv:2304.11277, 2023). Practical sharded data parallel training at scale.


Last reviewed: April 26, 2026
