LLM Construction
Distributed Training Theory
Training frontier models requires thousands of GPUs. Data parallelism, model parallelism, and communication-efficient methods make this possible.
Prerequisites
Why This Matters
A single GPU cannot train a frontier language model. Modern large models can have tens or hundreds of billions of parameters and train over enormous token budgets. Even if you could fit the model on one device, the wall-clock time would be untenable.
Distributed training splits the work across hundreds or thousands of devices. But parallelism is not free: devices must communicate gradients, activations, or parameters, and this communication can bottleneck the entire system. The theory of distributed training is about understanding these tradeoffs and minimizing wasted computation.
Mental Model
Think of building a house. Data parallelism is hiring identical crews, each building a copy of the same wall using different bricks, then averaging their work. Model parallelism is splitting the blueprint so each crew builds a different section of the house. Pipeline parallelism is an assembly line: crew 1 pours foundations, passes to crew 2 for framing, crew 3 for roofing.
Each approach has different communication patterns. Data parallelism requires sharing gradients (the "average" step). Tensor parallelism requires sharing intermediate activations within each layer. Pipeline parallelism requires passing activations between stages. The art is choosing the right combination to minimize idle time and communication overhead.
Formal Setup and Notation
Let $f(\theta; x)$ be the loss for parameters $\theta$ on example $x$. We want to minimize $F(\theta) = \mathbb{E}_x[f(\theta; x)]$ using SGD or its variants, distributed across $K$ workers.
Data Parallelism
In data parallelism, each of $K$ workers holds a complete copy of the model parameters $\theta$. At each step:
- Each worker $k$ samples a minibatch $\mathcal{B}_k$ of size $B$
- Each worker computes its local gradient $g_k = \frac{1}{B} \sum_{x \in \mathcal{B}_k} \nabla f(\theta; x)$
- Workers average their gradients: $\bar{g} = \frac{1}{K} \sum_{k=1}^{K} g_k$
- Each worker updates: $\theta \leftarrow \theta - \eta \bar{g}$
The effective batch size is $KB$. Step 3 (gradient averaging) is typically implemented via AllReduce.
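As a sanity check, the averaging step can be simulated on one machine. The sketch below uses a toy quadratic loss and illustrative sizes ($K = 4$ workers, local batch $B = 8$, neither from the text) to confirm that averaging per-worker gradients reproduces the single-device full-batch gradient:

```python
import numpy as np

# Toy loss: f(theta; x) = 0.5 * (theta - x)^2, so grad = theta - x.
# K and B are illustrative, not prescribed by the text.
rng = np.random.default_rng(0)
K, B = 4, 8
theta = 1.5
data = rng.normal(size=(K, B))          # worker k holds row k

# Steps 1-2: each worker computes its local minibatch gradient.
local_grads = np.array([(theta - data[k]).mean() for k in range(K)])

# Step 3: average across workers (what AllReduce computes).
g_bar = local_grads.mean()

# Equivalent single-device gradient over the full batch of K*B examples.
g_full = (theta - data.reshape(-1)).mean()
assert abs(g_bar - g_full) < 1e-12      # same update, effective batch K*B

# Step 4: every worker applies the identical update.
eta = 0.1
theta = theta - eta * g_bar
```

Because every worker ends the step with the same averaged gradient, all replicas stay in sync without ever exchanging parameters.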
AllReduce
AllReduce is a collective communication primitive that computes the sum (or average) of vectors across all workers and distributes the result to every worker. In the standard $\alpha$-$\beta$ communication model, suppose each worker holds a tensor of $M$ bytes and we use ring AllReduce across $K$ workers. Then
$$T_{\text{ring}} = 2(K - 1)\,\alpha + 2\,\frac{K - 1}{K}\,\frac{M}{\beta}$$
where $\alpha$ is the per-message latency cost and $\beta$ is the link bandwidth in bytes per second.
- The bandwidth term approaches $2M/\beta$ for large $K$.
- The latency term $2(K - 1)\alpha$ still grows linearly with $K$.
So ring AllReduce is bandwidth-optimal asymptotically, but not literally independent of worker count once latency is accounted for.
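The $2(K-1)$ communication rounds are easiest to see in a toy simulation. The single-process sketch below (hypothetical chunk assignment: $K$ workers, one chunk per worker) runs the reduce-scatter and all-gather phases and checks that every worker ends with the full sum:

```python
import numpy as np

# Single-process ring AllReduce simulation: K workers, tensor split into
# K chunks, 2*(K-1) rounds total. Illustrative, not a real NCCL trace.
K = 4
rng = np.random.default_rng(1)
buf = rng.integers(0, 10, size=(K, K)).astype(float)  # buf[k] = worker k's tensor
expected = buf.sum(axis=0).copy()

# Phase 1: reduce-scatter, K-1 rounds. In round s, worker k forwards chunk
# (k - s) % K to worker k+1, which accumulates it. Afterwards worker k
# holds the complete sum for chunk (k + 1) % K.
for step in range(K - 1):
    for k in range(K):
        c = (k - step) % K
        buf[(k + 1) % K, c] += buf[k, c]

# Phase 2: all-gather, K-1 rounds: circulate the completed chunks around
# the ring, overwriting stale partial sums.
for step in range(K - 1):
    for k in range(K):
        c = (k + 1 - step) % K
        buf[(k + 1) % K, c] = buf[k, c]

assert np.allclose(buf, expected)  # every worker now holds the full sum
```

Each worker sends $(K-1)/K$ of the tensor per phase, which is where the $2\,\frac{K-1}{K}\,\frac{M}{\beta}$ bandwidth term comes from.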
Tensor Parallelism
Tensor parallelism (TP) splits individual layers across devices. For a linear layer $y = Wx$, one common pattern is to partition the rows of $W$ as $W = [W_1; W_2; \dots; W_P]$ across $P$ devices and let each device compute a partial matrix product:
Device $i$ computes $y_i = W_i x$ (a slice of the output $y$). Depending on the partition pattern, the next collective may be an AllGather, an AllReduce, or a ReduceScatter. The important point is not that tensor parallelism always gathers the full activation immediately, but that splitting one layer across devices requires repeated communication of activations or partial sums within the layer.
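The sharding pattern can be sketched in a few lines of NumPy. Here $W$ is split along the output dimension across $P$ simulated devices (sizes are illustrative), each computes its slice, and an AllGather-style concatenation reconstructs the full activation:

```python
import numpy as np

# Output-dimension sharding of a linear layer across P "devices".
# Shapes are illustrative; with row-vector activations, y = x @ W.
rng = np.random.default_rng(2)
P, d_in, d_out = 4, 8, 12
x = rng.normal(size=(3, d_in))           # batch of 3 activations
W = rng.normal(size=(d_in, d_out))

shards = np.split(W, P, axis=1)          # device i holds a slice of W
partials = [x @ W_i for W_i in shards]   # device i computes its output slice

# AllGather along the feature dimension reconstructs the full activation.
y_gathered = np.concatenate(partials, axis=1)
assert np.allclose(y_gathered, x @ W)
```

Splitting along the input dimension instead would leave each device with a full-shaped partial sum, requiring an AllReduce rather than an AllGather; Megatron-style TP alternates the two patterns to avoid one collective per matrix multiply.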
Pipeline Parallelism
Pipeline parallelism (PP) partitions the model's layers into sequential stages, each on a different device. A microbatch flows through stage 1, then stage 2, etc. Multiple microbatches are in flight simultaneously to fill the pipeline, but pipeline bubbles (idle time) occur at startup and shutdown.
For GPipe-style schedules with $m$ microbatches and $p$ stages, the startup/shutdown bubble fraction is approximately:
$$\text{bubble fraction} \approx \frac{p - 1}{m + p - 1}$$
which shrinks as $m$ grows relative to $p$.
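The bubble formula is a one-liner worth evaluating at a few operating points (the helper name is illustrative):

```python
# Idle fraction of a GPipe-style pipeline: (p - 1) / (m + p - 1),
# with p stages and m microbatches per minibatch.
def bubble_fraction(m: int, p: int) -> float:
    return (p - 1) / (m + p - 1)

# One microbatch through 4 stages: the pipeline is idle 75% of the time.
assert bubble_fraction(m=1, p=4) == 0.75
# With m >> p the bubble shrinks toward zero.
assert abs(bubble_fraction(m=32, p=4) - 3 / 35) < 1e-12
# A single stage has no bubble at all.
assert bubble_fraction(m=8, p=1) == 0.0
```

This is why deep pipelines demand large minibatches: halving the bubble roughly requires doubling $m$, which in turn increases activation memory.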
Core Definitions
ZeRO (Zero Redundancy Optimizer) from DeepSpeed eliminates memory redundancy in data parallelism. Standard data parallelism replicates the full optimizer state, gradients, and parameters on every worker. ZeRO partitions these across workers:
With mixed-precision Adam the per-parameter memory budget is $2 + 2 + 12 = 16$ bytes: 2 bytes of FP16 parameters, 2 bytes of FP16 gradients, and 12 bytes of optimizer state (an FP32 master copy of the parameters, and Adam's $m$ and $v$, 4 bytes each). ZeRO stages partition these components in order.
- ZeRO Stage 1: partition the 12 bytes/param of optimizer state across $K$ workers. Each worker now holds $12/K$ bytes/param of optimizer state, down from 12 bytes/param. Parameters (2) and gradients (2) remain replicated. Per-worker total drops from 16 bytes/param to $4 + 12/K$ bytes/param.
- ZeRO Stage 2: additionally partition the 2 bytes/param of gradients. Workers compute full gradients locally during backward, but reduce-scatter them so each worker stores only its partition. Per-worker total drops to $2 + 14/K$ bytes/param.
- ZeRO Stage 3: additionally partition the 2 bytes/param of parameters. Per-worker total drops to $16/K$ bytes/param, but each forward and backward pass now needs an AllGather to reconstruct parameters, roughly doubling communication per step.
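The three stages can be written as a small accounting helper (a sketch; `zero_bytes_per_param` is an illustrative name, not a DeepSpeed API):

```python
# Per-parameter memory (bytes) under mixed-precision Adam for each ZeRO
# stage, following the 2 (FP16 params) + 2 (FP16 grads) + 12 (optimizer
# state) breakdown. K is the data-parallel degree.
def zero_bytes_per_param(stage: int, K: int) -> float:
    params, grads, opt = 2.0, 2.0, 12.0
    if stage >= 1:
        opt /= K        # Stage 1: shard optimizer state
    if stage >= 2:
        grads /= K      # Stage 2: also shard gradients
    if stage >= 3:
        params /= K     # Stage 3: also shard parameters
    return params + grads + opt

K = 64
assert zero_bytes_per_param(0, K) == 16.0        # plain data parallelism
assert zero_bytes_per_param(1, K) == 4 + 12 / K
assert zero_bytes_per_param(2, K) == 2 + 14 / K
assert zero_bytes_per_param(3, K) == 16 / K      # fully sharded
```

Note how Stages 1 and 2 already capture most of the savings at large $K$: the irreducible 2 bytes/param of replicated parameters is what Stage 3 removes, at the cost of extra AllGathers.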
PyTorch FSDP is the modern practical family of implementations built around the same core idea: shard parameters / gradients / optimizer state across data-parallel workers and materialize full parameters only when a module is about to execute.
Gradient compression reduces communication volume:
- Gradient quantization: send gradients in low precision (e.g., FP16 or INT8) instead of FP32, halving or quartering bandwidth.
- Gradient sparsification: send only the top-$k$ gradient components. The error from dropped components is accumulated locally and sent in future rounds.
- PowerSGD: approximate the $n \times m$ gradient matrix with a low-rank decomposition, reducing communication from $O(nm)$ to $O(r(n + m))$ for rank $r$.
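Top-$k$ sparsification with error feedback, the second item above, fits in a short sketch (in the spirit of Deep Gradient Compression; the function name and shapes are illustrative):

```python
import numpy as np

# Top-k gradient sparsification with a local error-feedback buffer.
def sparsify_with_feedback(grad, residual, k):
    """Send the k largest-magnitude components; carry the rest forward."""
    corrected = grad + residual                # add back previously dropped mass
    idx = np.argsort(np.abs(corrected))[-k:]   # indices of the top-k entries
    sent = np.zeros_like(corrected)
    sent[idx] = corrected[idx]                 # sparse message on the wire
    new_residual = corrected - sent            # error deferred to later rounds
    return sent, new_residual

rng = np.random.default_rng(3)
g = rng.normal(size=100)
residual = np.zeros_like(g)
sent, residual = sparsify_with_feedback(g, residual, k=10)

assert np.count_nonzero(sent) == 10            # 10x less traffic this round
assert np.allclose(sent + residual, g)         # nothing lost, only delayed
```

The error-feedback buffer is what makes aggressive sparsification (often 99%+ of components dropped per round) converge: every component is eventually transmitted once it accumulates enough magnitude.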
Main Theorems
Ideal Linear Speedup of Synchronous Data-Parallel SGD
Statement
For data-parallel SGD with $K$ workers, each using local batch size $B$, the effective batch size is $KB$. After $T$ steps, the convergence rate for an $L$-smooth non-convex objective $F$ with per-sample gradient variance $\sigma^2$ satisfies:
$$\frac{1}{T} \sum_{t=0}^{T-1} \mathbb{E}\,\|\nabla F(\theta_t)\|^2 \;\le\; \frac{2\,(F(\theta_0) - F^*)}{\eta T} + \frac{\eta L \sigma^2}{KB}$$
With learning rate $\eta = \sqrt{\frac{2\,(F(\theta_0) - F^*)\,KB}{L \sigma^2 T}}$, this gives:
$$\frac{1}{T} \sum_{t=0}^{T-1} \mathbb{E}\,\|\nabla F(\theta_t)\|^2 \;\le\; O\!\left( \sqrt{\frac{L \sigma^2 \,(F(\theta_0) - F^*)}{T K B}} \right)$$
This is a factor of $\sqrt{K}$ better than single-worker SGD with batch size $B$ and the same number of steps. Equivalently, to reach the same accuracy, data-parallel SGD needs $K$ times fewer steps in the idealized variance-limited regime. If communication is negligible, this translates into linear wall-clock speedup.
Intuition
Each worker contributes independent gradient estimates. Averaging $K$ independent estimates reduces variance by a factor of $K$, just like averaging i.i.d. random variables. Less variance means you can take larger steps or reach the same accuracy faster. The speedup is linear until the batch size becomes so large that the variance reduction saturates (the "critical batch size").
Proof Sketch
Start from the smoothness inequality: $F(\theta_{t+1}) \le F(\theta_t) + \langle \nabla F(\theta_t), \theta_{t+1} - \theta_t \rangle + \frac{L}{2} \|\theta_{t+1} - \theta_t\|^2$. Take expectations. The averaged gradient $\bar{g}$ has variance $\sigma^2 / (KB)$ (each of $K$ workers averages $B$ independent samples). The bias is zero since each worker computes unbiased gradients. Telescope over $T$ steps and optimize the learning rate.
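The variance step in the sketch is easy to check numerically: averaging $K$ independent estimates shrinks the variance by roughly a factor of $K$ (the estimator here is a plain Gaussian stand-in for a stochastic gradient, with illustrative constants):

```python
import numpy as np

# Monte Carlo check: var(mean of K iid estimates) ~= sigma^2 / K.
rng = np.random.default_rng(4)
sigma2, K, trials = 4.0, 16, 200_000

single = rng.normal(0.0, np.sqrt(sigma2), size=trials)               # 1 worker
averaged = rng.normal(0.0, np.sqrt(sigma2), size=(trials, K)).mean(axis=1)

assert abs(single.var() - sigma2) / sigma2 < 0.05
assert abs(averaged.var() - sigma2 / K) / (sigma2 / K) < 0.05
```

With $B$ samples per worker the same argument gives $\sigma^2/(KB)$, which is the term the learning-rate choice balances against the optimization progress.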
Why It Matters
Linear speedup is the best you can hope for from parallelism. This theorem says data-parallel SGD achieves the algorithmic version of it, up to a critical batch size beyond which increasing $K$ gives diminishing returns. The systems caveat is equally important: wall-clock speedup only follows if communication and load imbalance stay small relative to compute.
Failure Mode
Linear speedup breaks down when: (1) the effective batch size $KB$ exceeds the critical batch size $B_{\text{crit}}$, after which additional workers reduce variance below what matters; (2) communication overhead exceeds computation savings; (3) large batch sizes require learning rate warmup and careful tuning to maintain convergence.
Communication-Computation Tradeoff
Statement
In data-parallel training with ring AllReduce, a simple no-overlap step-time model is:
$$T_{\text{step}} = T_{\text{comp}} + 2(K - 1)\,\alpha + 2\,\frac{K - 1}{K}\,\frac{M}{\beta}$$
where $T_{\text{comp}}$ is the local computation time (forward + backward pass), $M$ is the gradient payload in bytes, $\alpha$ is the per-hop latency, and the last term is the ring AllReduce bandwidth cost. The computation-to-communication ratio is therefore:
$$\rho = \frac{T_{\text{comp}}}{2(K - 1)\,\alpha + 2\,\frac{K - 1}{K}\,\frac{M}{\beta}}$$
Training is efficient when this ratio is much greater than 1 (computation dominates). It is communication-bound when the ratio is less than 1.
Intuition
Every step, you do some math (forward and backward pass, time $T_{\text{comp}}$) and some networking. If the network is slow relative to the GPU, workers spend most of their time waiting. Making each worker do more computation per communication round (larger local batch, gradient accumulation, more overlap) improves the ratio.
Proof Sketch
Ring AllReduce has two phases (reduce-scatter and all-gather). In each phase, every worker participates in $K - 1$ communication rounds, which contributes the latency term $2(K - 1)\alpha$. The payload moved per worker is $2\,\frac{K-1}{K}\,M$ bytes total, giving the bandwidth term. Layerwise overlap can reduce the exposed communication cost, but this is the basic upper bound when overlap is weak or absent.
Why It Matters
This ratio determines whether adding more GPUs actually helps. For a model with $10^9$ parameters (1B) in FP16, the gradient payload is about $M = 2 \times 10^9$ bytes, so the large-$K$ ring bandwidth term is about $2M/\beta = 4 \times 10^9 / \beta$ seconds. At 100 Gbps bandwidth ($\beta = 12.5$ GB/s), this is about 0.32 seconds before latency. If the forward/backward pass takes 0.5 seconds, communication already consumes a large fraction of step time. Faster interconnects (InfiniBand, NVLink) or gradient compression are needed.
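The arithmetic above is worth a back-of-envelope script. The parameter count is an assumption chosen so the bandwidth term matches the quoted 0.32 s; the 0.5 s compute time comes from the example:

```python
# Ring AllReduce bandwidth cost in the large-K limit: 2 * M / beta.
N = 1e9                  # parameters (assumed 1B, consistent with 0.32 s)
M = 2 * N                # FP16 gradient payload in bytes
beta = 12.5e9            # 100 Gbps expressed in bytes/second

t_bw = 2 * M / beta      # bandwidth cost, ignoring latency
assert abs(t_bw - 0.32) < 1e-9

# With a 0.5 s forward/backward pass, the no-overlap compute:comm ratio:
t_comp = 0.5
ratio = t_comp / t_bw    # ~1.56: communication eats ~39% of the step
assert 1.5 < ratio < 1.6
```

A ratio this close to 1 is exactly the regime where gradient/compute overlap and faster interconnects pay off most.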
Canonical Examples
Scaling from 1 to 256 GPUs
A model with $N = 7 \times 10^9$ parameters (7B) trained with mixed-precision Adam needs 16 bytes/param: 14 GB of FP16 parameters, 14 GB of FP16 gradients, and 84 GB of optimizer state (FP32 master copy + Adam's $m$ + $v$, 12 bytes/param). Total: 112 GB per worker in standard data parallelism.
With ZeRO Stage 1: the 84 GB of optimizer state is split across $K$ workers, while parameters and gradients remain replicated. Each worker holds $28 + 84/K$ GB. At $K = 8$, this is 38.5 GB per worker.
With ZeRO Stage 3: parameters, gradients, and optimizer state are all sharded. Each worker stores $112/K$ GB. At $K = 8$, this is 14 GB per worker, but each forward and backward pass requires AllGather operations to reconstruct parameters.
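The figures in this example reduce to a few lines of arithmetic; $K = 8$ is an assumed worker count chosen to match the 14 GB Stage 3 number:

```python
# Memory accounting for the 7B mixed-precision Adam example.
N = 7e9
GB = 1e9
params_gb = 2 * N / GB    # FP16 parameters
grads_gb = 2 * N / GB     # FP16 gradients
opt_gb = 12 * N / GB      # FP32 master copy + Adam m and v

K = 8                     # assumed data-parallel degree
baseline = params_gb + grads_gb + opt_gb        # standard data parallelism
stage1 = params_gb + grads_gb + opt_gb / K      # shard optimizer state only
stage3 = (params_gb + grads_gb + opt_gb) / K    # shard everything

assert baseline == 112.0
assert stage1 == 38.5
assert stage3 == 14.0
```

At $K = 8$, Stage 3 already fits the full training state where Stage 1 still needs a 40 GB-class accelerator; at larger $K$ the per-worker footprint keeps shrinking while communication per step grows.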
Common Confusions
Data parallelism does not change the model
In exact arithmetic, synchronized data parallelism produces the same update as single-device training with the same effective batch size $KB$. In practice it is usually mathematically equivalent but not bitwise identical because floating-point reduction order, nondeterministic kernels, and RNG handling can differ. Asynchronous data parallelism is a different algorithm again.
Pipeline parallelism is not just splitting layers
Naive pipeline parallelism has the pipeline bubble problem: only one stage is active at a time, giving utilization $1/p$. The key innovation (GPipe, PipeDream) is microbatching: split the minibatch into $m$ microbatches and overlap their execution. With $m \gg p$, utilization approaches 100%, but memory increases because activations for all in-flight microbatches must be stored.
3D parallelism is not optional at scale
Frontier models use all three forms simultaneously: data parallelism across groups, tensor parallelism within a node (fast NVLink), and pipeline parallelism across nodes. The 3D configuration is a system design choice driven by hardware topology, not just model architecture.
ZeRO and FSDP buy memory by spending communication
Sharding optimizers, gradients, and parameters can reduce per-worker memory by nearly a factor of the data-parallel degree, but it is not a free lunch. The more aggressively you shard, the more often you must reduce-scatter or all-gather state. Memory savings and communication cost move together.
Summary
- Data parallelism gives linear speedup up to a critical batch size; beyond that, adding workers wastes compute
- The algorithmic speedup story and the wall-clock speedup story are different; communication can erase the gains
- Ring AllReduce has a bandwidth term of order $2M/\beta$ and a latency term that still grows with worker count
- Tensor parallelism splits layers within a node (requires high bandwidth); pipeline parallelism splits layers across nodes (tolerates lower bandwidth)
- ZeRO eliminates memory redundancy: Stage 1 partitions optimizer states, Stage 2 adds gradients, Stage 3 adds parameters; FSDP is a modern sharded implementation family in the same spirit
- The computation-to-communication ratio determines whether training is GPU-bound or network-bound
Exercises
Problem
A model has $N$ parameters in FP32. Each worker has 100 Gbps network bandwidth. How long does AllReduce take, assuming the ring algorithm? Express your answer in terms of $N$, the worker count $K$, and the per-hop latency $\alpha$.
Problem
With $K$ workers and local batch size $B$, the effective batch size is $KB$. If the critical batch size is $B_{\text{crit}}$, under what condition are we in the linear speedup regime? What happens if we double $K$ to $2K$?
Problem
Explain the memory-communication tradeoff in ZeRO Stages 1-3. For each stage, what memory is saved per worker and what additional communication is required?
References
Canonical:
- Dean et al., "Large Scale Distributed Deep Networks" (NIPS 2012). The DistBelief paper that introduced asynchronous parameter servers for large-scale SGD.
- Patarasuk and Yuan, "Bandwidth optimal all-reduce algorithms for clusters of workstations" (2009). Ring AllReduce analysis underpinning the asymptotic used above.
- Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (arXiv:1910.02054, SC 2020). Introduces the three-stage partitioning used in this page.
- Huang et al., "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism" (arXiv:1811.06965, 2019). Micro-batched pipeline parallelism and the bubble analysis.
Large-batch regime:
- Goyal et al., "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" (arXiv:1706.02677, 2017). Linear scaling rule and warmup for large-batch SGD.
- McCandlish et al., "An Empirical Model of Large-Batch Training" (arXiv:1812.06162, 2018). Defines the critical batch size used above.
Gradient compression:
- Alistarh et al., "QSGD: Communication-Efficient SGD via Gradient Quantization and Encoding" (NeurIPS 2017).
- Lin et al., "Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training" (arXiv:1712.01887, 2018). Sparsification with error feedback.
- Vogels et al., "PowerSGD: Practical Low-Rank Gradient Compression for Distributed Optimization" (NeurIPS 2019).
Current:
- Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism" (arXiv:1909.08053, 2020).
- Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" (arXiv:2104.04473, SC 2021). 3D parallelism at scale.
- Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (arXiv:2304.11277, 2023). Practical sharded data parallel training at scale.
Next Topics
Distributed training connects to:
- Scaling laws that determine how many GPUs and how much data you need
- Mixed-precision training for reducing communication and memory further
- Federated learning for what changes when the workers no longer share IID centralized data
Last reviewed: April 26, 2026