A model fits on one GPU until it does not, and the moment it does not, you are no longer training a model — you are operating a distributed system that happens to compute gradients. The decision is never “should I shard” but “what do I shard”: parameters, gradients, optimizer states, activations, or the batch itself. Each choice has a different memory payoff and a different communication bill, and the entire discipline of large-scale training is choosing which bill to pay, on which wire, and whether you can hide it behind compute.
The memory you are actually fighting
Start with the bill, because the bill is what forces every decision downstream. The ZeRO paper splits GPU memory into two buckets, and the distinction is the most useful framing in the whole area. Model states are parameters, gradients, and optimizer states — they exist whether or not you run a single forward pass. Residual states are activations, temporary buffers, and fragmentation — they scale with batch and sequence length.
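For mixed-precision Adam, the per-parameter model-state cost is fixed and larger than people expect:
$$\underbrace{2\Psi}_{\text{16-bit params}} + \underbrace{2\Psi}_{\text{16-bit grads}} + \underbrace{12\Psi}_{\text{fp32 master params, momentum, variance}} = 16\Psi \ \text{bytes}$$
That is 16 bytes per parameter ($\Psi$ = parameter count). A 7.5B model needs roughly 120 GB just for model states, before a single activation. The fat part is the 12 bytes of optimizer state, and that is exactly the part standard data parallelism replicates on every rank.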
Plain DistributedDataParallel splits the batch and nothing else: every GPU holds a full copy of all 16 bytes/param, runs its shard of the batch, and an all-reduce averages gradients before the optimizer step. With $N$ GPUs you are storing $N$ identical copies of the optimizer state. At any nontrivial scale that is the dominant waste, and ZeRO/FSDP exist to delete it.
The residual states are a separate war, fought with different weapons (activation checkpointing, tensor and context parallelism). Keep the two buckets mentally distinct — a technique that helps one often does nothing for the other, and conflating them is how people pick the wrong tool.
ZeRO stages and their FSDP equivalents
ZeRO (Zero Redundancy Optimizer, Rajbhandari et al. 2019) is data parallelism with the replication removed, in three increasingly aggressive stages. PyTorch’s FullyShardedDataParallel is the same idea with different names.
| Stage | What it shards across DP ranks | Per-GPU model state ($\Psi$ params, $N$ ranks) | Comm vs DDP |
|---|---|---|---|
| DDP / NO_SHARD | nothing (full replica) | $16\Psi$ | 1× |
| ZeRO-1 | optimizer states | $4\Psi + 12\Psi/N$ | 1× |
| ZeRO-2 / SHARD_GRAD_OP | + gradients | $2\Psi + 14\Psi/N$ | 1× |
| ZeRO-3 / FULL_SHARD (FSDP) | + parameters | $16\Psi/N$ | ~1.5× |
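The per-GPU column is plain arithmetic, and a few lines make it easy to plug in your own model size. The sketch below assumes the mixed-precision Adam accounting used in this section (2 + 2 + 12 bytes per parameter) and ignores residual states entirely; `model_state_bytes_per_gpu` is a made-up helper name, not a library function.

```python
def model_state_bytes_per_gpu(num_params, num_ranks, stage):
    """Per-GPU model-state bytes for mixed-precision Adam under ZeRO stage 0-3.

    Assumed per-parameter cost: 2 bytes of 16-bit weights, 2 bytes of 16-bit
    gradients, 12 bytes of fp32 optimizer state (master weights, momentum,
    variance). Stage 0 is plain DDP.
    """
    params, grads, optim = 2.0, 2.0, 12.0
    if stage >= 1:
        optim /= num_ranks   # ZeRO-1 shards optimizer state
    if stage >= 2:
        grads /= num_ranks   # ZeRO-2 also shards gradients
    if stage >= 3:
        params /= num_ranks  # ZeRO-3 also shards parameters
    return num_params * (params + grads + optim)


if __name__ == "__main__":
    for stage in range(4):
        gb = model_state_bytes_per_gpu(7.5e9, 64, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per GPU")
```

With 7.5B parameters on 64 ranks this goes from 120 GB at stage 0 to under 2 GB at stage 3, which is the whole argument for sharding in one loop.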
The mechanics that make this work: an all-reduce is just a reduce-scatter followed by an all-gather. ZeRO-1 and ZeRO-2 exploit that you only need the optimizer state and gradients for the parameter shard you own, so they replace the all-reduce with a reduce-scatter (each rank gets its gradient shard, updates its parameters) plus an all-gather (broadcast updated params). Same total bytes on the wire as DDP — you get the memory savings essentially for free.
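The identity behind that paragraph is easy to check in a single process. This is a toy model only: each "rank" is a Python list, the function names mirror the collectives but are not the `torch.distributed` API, and the buffer length is assumed to divide evenly by the number of ranks.

```python
def reduce_scatter(per_rank):
    """Rank r ends up with the element-wise sum of everyone's r-th chunk."""
    n = len(per_rank)
    chunk = len(per_rank[0]) // n  # assumes the length divides evenly
    return [
        [sum(buf[r * chunk + i] for buf in per_rank) for i in range(chunk)]
        for r in range(n)
    ]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_reduce(per_rank):
    """Every rank ends up with the element-wise sum of all buffers."""
    total = [sum(column) for column in zip(*per_rank)]
    return [list(total) for _ in per_rank]


grads = [
    [1.0, 2.0, 3.0, 4.0],      # rank 0's local gradient
    [10.0, 20.0, 30.0, 40.0],  # rank 1's local gradient
]
shards = reduce_scatter(grads)  # rank 0 owns the first half, rank 1 the second
assert all_gather(shards) == all_reduce(grads)
```

ZeRO-1/2 slot the optimizer step between the two halves: each rank updates only the shard it received from the reduce-scatter, and the all-gather then distributes updated parameters instead of summed gradients.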
ZeRO-3 / FSDP FULL_SHARD goes further and shards the parameters themselves. No rank holds a full layer at rest. Just before a layer’s forward, FSDP issues an all-gather to reconstruct that layer’s full parameters, runs the compute, then immediately frees them; backward does the same and reduce-scatters the gradients. That extra parameter all-gather in both forward and backward is the ~1.5× communication cost (forward all-gather + backward all-gather + gradient reduce-scatter ≈ $3\Psi$ of traffic per rank vs DDP’s $2\Psi$, with $\Psi$ the parameter count). In exchange, per-GPU model state drops to $16\Psi/N$ bytes across $N$ ranks, so capacity scales with the number of GPUs rather than the memory of one: a 70B model’s roughly 1.12 TB of model state becomes 140 GB per GPU on 8 ranks and 17.5 GB on 64.
```python
# FSDP2 (torch.distributed.fsdp.fully_shard): illustrative; pin to your torch version's docs.
# Assumes torch.distributed is already initialized and that `model` and `num_nodes` exist.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

mesh = init_device_mesh("cuda", (num_nodes, 8), mesh_dim_names=("dp", "tp"))

for block in model.layers:
    # shard params/grads/opt-state of each block over the dp axis (== ZeRO-3)
    fully_shard(block, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])
# reshard_after_forward=False keeps params gathered between fwd/bwd:
# fewer all-gathers, more memory. Classic time/space dial.
```
Two production details. First, HYBRID_SHARD shards within a node and replicates across nodes, which bounds the expensive parameter all-gather to fast intra-node links while doing only gradient sync over the slower inter-node fabric; it is usually the right default past one node. Second, FSDP2 reworked the internals from FSDP1’s flattened-parameter buffers to per-parameter sharding on DTensor, which composes far more cleanly with tensor parallelism. APIs in this area move quarterly; treat any snippet as a sketch and check current docs.
Tensor parallelism: splitting the matmul
ZeRO/FSDP shard storage but each GPU still computes whole layers. Tensor parallelism (Megatron-LM, Shoeybi et al. 2019) splits the computation — it cuts individual matrix multiplies across GPUs so a layer too big to compute on one device runs in pieces.
The MLP block is the clean case. For $Z = \mathrm{GeLU}(XA)\,B$, Megatron splits the first weight column-wise: $A = [A_1, A_2]$. Because GeLU is element-wise, $\mathrm{GeLU}(X[A_1, A_2]) = [\mathrm{GeLU}(XA_1), \mathrm{GeLU}(XA_2)]$, so each GPU computes its half with no communication. The second weight $B$ is then split row-wise, each GPU multiplies its activation half by its shard, and the partial results are summed with an all-reduce. Attention parallelizes the same way along heads: the QKV projection is column-parallel (each rank owns a subset of heads, softmax is per-head so no cross-rank sync), and the output projection is row-parallel, ending in an all-reduce.
That is two all-reduces per transformer layer in the forward pass (one for attention, one for the MLP) and two more in the backward pass: four blocking collectives per layer per step, each moving an activation-sized tensor (batch × sequence × hidden).
```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Megatron pattern: first matmul column-parallel, second row-parallel.
# The row-parallel modules emit the all-reduce, which sits on the critical path.
# Assumes `block` is a transformer block with these submodule names and
# `tp_mesh` is the tensor-parallel sub-mesh.
plan = {
    "attn.qkv_proj": ColwiseParallel(),
    "attn.out_proj": RowwiseParallel(),     # all-reduce
    "mlp.gate_up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),     # all-reduce
}
parallelize_module(block, tp_mesh, plan)
```
The defining property of TP: that all-reduce is on the critical path. The next operation literally cannot start until the reduced result lands, so unlike data-parallel gradient sync, you cannot just overlap it with compute. This is why TP is conventionally confined to a single node (typically 8 GPUs over NVLink), where bandwidth is high enough that the collective does not dominate. Stretch TP across slower inter-node links and the communication tax swamps the compute you parallelized. (Megatron’s sequence parallelism and the more recent async-TP work claw some of this back by splitting the collective and overlapping its pieces with the matmul, but the constraint stands: TP wants the fastest wire you have.)
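The column-then-row split of the MLP can be checked numerically without any GPUs. The sketch below is a pure-Python toy with two simulated ranks and tiny hand-picked matrices; the helper names (`matmul`, `gelu`, `add`) and the values are illustrative assumptions, and the final element-wise sum stands in for the all-reduce.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gelu(m):
    # Exact (erf-based) GeLU, applied element-wise.
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


X = [[1.0, -2.0], [0.5, 3.0]]                            # (tokens, hidden)
A = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8]]       # first weight, (hidden, 4)
B = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 2.0]]   # second weight, (4, hidden)

reference = matmul(gelu(matmul(X, A)), B)                # unsharded result

# Column-parallel split of A: each rank keeps two of the four columns.
A_shards = [[row[:2] for row in A], [row[2:] for row in A]]
# Row-parallel split of B: each rank keeps the matching two rows.
B_shards = [B[:2], B[2:]]

# Each rank runs matmul -> GeLU -> matmul with no communication...
partials = [matmul(gelu(matmul(X, a)), b) for a, b in zip(A_shards, B_shards)]
# ...and only the final sum (the all-reduce) crosses ranks.
tp_output = add(partials[0], partials[1])
```

The point of pairing a column split with a row split is visible here: the intermediate activation never has to be gathered, so the block needs exactly one collective in the forward pass.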
Pipeline parallelism and the bubble
Pipeline parallelism is the cheapest axis to communicate and the most annoying to schedule. You partition the model by depth into stages — layers 1–8 on group A, 9–16 on group B, and so on — and the only communication is a point-to-point send of activations at each stage boundary (and gradients backward). No collectives, small payloads. That is what makes PP the axis you push across nodes when TP has saturated a single node.
The catch is the bubble. Run one batch through a 4-stage pipeline naively and stage 4 sits idle while stages 1–3 warm up, then stage 1 sits idle while the pipeline drains. The fix is micro-batching: split the batch into $m$ micro-batches and stream them so every stage stays busy. The idle fraction with the classic GPipe schedule is
$$\text{bubble fraction} = \frac{p - 1}{m + p - 1}$$
for $p$ stages, so you want $m \gg p$. Eight stages with 32 micro-batches is about 18% idle; the same eight stages with 4 micro-batches is over 60% idle and a waste of hardware.
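A small helper makes the trade-off easy to tabulate before committing to a schedule. It assumes the standard GPipe idle fraction, (p − 1)/(m + p − 1) for p stages and m micro-batches, which reproduces both figures quoted above; the function name is made up for this example.

```python
def gpipe_bubble_fraction(num_stages, num_microbatches):
    """Idle fraction of a GPipe-style schedule with equal-cost stages."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


if __name__ == "__main__":
    for m in (4, 8, 16, 32, 64):
        print(f"p=8, m={m:>2}: {gpipe_bubble_fraction(8, m):.1%} idle")
```

The formula assumes every stage takes the same time per micro-batch; an unbalanced partition makes the real bubble worse than this estimate.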
The schedule also dictates activation memory. GPipe’s “all-forwards-then-all-backwards” must hold activations for all $m$ in-flight micro-batches before backward begins. 1F1B (one-forward-one-backward, from PipeDream) interleaves the two so each stage starts backward as soon as it can, capping live activations at roughly $p$ micro-batches (the number of stages) instead of $m$: same bubble, much lower peak memory, which is why it is the default. Interleaved 1F1B assigns each device several non-contiguous stages (“virtual stages”) to shrink the bubble further, paying with extra point-to-point traffic.
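The memory difference falls out of the order in which each stage runs its forwards and backwards. The sketch below models only that per-stage order (no timing, no communication) for the non-interleaved 1F1B schedule, assuming each forward stashes one micro-batch of activations and the matching backward frees it; the function names are illustrative.

```python
def gpipe_order(num_microbatches):
    """All forwards, then all backwards; identical on every stage."""
    return ["F"] * num_microbatches + ["B"] * num_microbatches


def one_f_one_b_order(stage, num_stages, num_microbatches):
    """Warm-up forwards, then alternate F/B, then drain the remaining backwards."""
    warmup = min(num_stages - stage - 1, num_microbatches)
    steady = num_microbatches - warmup
    return ["F"] * warmup + ["F", "B"] * steady + ["B"] * warmup


def peak_live_activations(order):
    """Largest number of micro-batches whose activations are held at once."""
    live = peak = 0
    for op in order:
        live += 1 if op == "F" else -1
        peak = max(peak, live)
    return peak


p, m = 4, 16
gpipe_peak = peak_live_activations(gpipe_order(m))
one_f_one_b_peaks = [
    peak_live_activations(one_f_one_b_order(stage, p, m)) for stage in range(p)
]
```

Under these assumptions the first stage carries the most stashed activations and the last stage carries one micro-batch, which is worth remembering when one pipeline rank runs out of memory and the others do not.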
Long sequences and sparse experts
Two more axes show up once your workload leaves the “dense model, normal sequence length” regime.
Sequence and context parallelism attack the residual-state bucket — activation memory that scales with sequence length and that no amount of parameter sharding touches. Megatron’s sequence parallelism splits the operations TP leaves replicated (LayerNorm, dropout) along the sequence dimension, converting some all-reduces into reduce-scatter/all-gather pairs and cutting activation memory; it is a memory optimization layered onto TP. Context parallelism is the heavier tool for genuinely long sequences: shard the sequence itself across devices. Attention then needs every query block to see every key/value block, so Ring Attention passes K/V blocks around a ring of devices while each computes attention incrementally with an online softmax — overlapping the ring’s communication with the attention compute so the sequence length you can train scales with device count. (DeepSpeed-Ulysses is the all-to-all alternative, partitioning along heads instead of sequence.)
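The online softmax is what makes the ring possible: a device can fold in one K/V block at a time and still end up with the exact attention output. Here is a minimal single-query sketch in pure Python, where iterating over `kv_blocks` stands in for the ring delivering one block per step; the function names and the toy data are assumptions for illustration, and no actual communication is modelled.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(q, keys, values):
    """Reference: softmax over all keys at once."""
    scores = [dot(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    denom = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / denom
            for j in range(len(values[0]))]


def blockwise_attention(q, kv_blocks):
    """Fold in one (keys, values) block at a time with a running max and sum."""
    running_max, denom, acc = -math.inf, 0.0, None
    for keys, values in kv_blocks:
        if acc is None:
            acc = [0.0] * len(values[0])
        scores = [dot(q, k) for k in keys]
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # 0.0 on the first block
        weights = [math.exp(s - new_max) for s in scores]
        denom = denom * rescale + sum(weights)
        acc = [a * rescale + sum(w * v[j] for w, v in zip(weights, values))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / denom for a in acc]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.5, -0.3], [-0.7, 2.0], [0.9, 0.9]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

expected = full_attention(q, keys, values)
blocks = [(keys[:2], values[:2]), (keys[2:], values[2:])]
streamed = blockwise_attention(q, blocks)
```

Because the running state is just a max, a denominator, and an accumulator, the order in which blocks arrive does not change the result, which is what lets each device start from whichever block it already holds.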
Expert parallelism is the axis for Mixture-of-Experts. Experts are distributed across devices; each token is routed to its top-k experts, which usually live on other devices. So an MoE layer does an all-to-all to dispatch tokens to the device holding their expert, runs the expert FFN, then a second all-to-all to combine results back. All-to-all is topology-sensitive and bursty in a way all-reduce is not, and at scale it, not the matmuls, is what MoE training spends its time on (GShard and Switch Transformers are the origin points here). Expert parallelism composes with the dense axes rather than replacing them: the attention blocks are still data/tensor parallel; only the expert FFNs ride the EP axis.
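To see why the all-to-all is bursty, it helps to count what each device has to send. The sketch below is a toy router, assuming experts are laid out contiguously (expert `e` lives on device `e // experts_per_device`) and using made-up router scores; it computes only the send-count matrix for the dispatch step, not the exchange itself.

```python
def top_k_experts(router_scores, k):
    """Indices of the k highest-scoring experts for one token."""
    ranked = sorted(range(len(router_scores)),
                    key=lambda e: router_scores[e], reverse=True)
    return ranked[:k]


def dispatch_counts(token_scores_per_device, experts_per_device, k):
    """send[src][dst] = token copies device src must ship to device dst."""
    n = len(token_scores_per_device)
    send = [[0] * n for _ in range(n)]
    for src, tokens in enumerate(token_scores_per_device):
        for scores in tokens:
            for expert in top_k_experts(scores, k):
                send[src][expert // experts_per_device] += 1
    return send


# Two devices, four experts (two per device), top-2 routing.
scores = [
    [[0.1, 0.2, 0.9, 0.8], [0.7, 0.1, 0.1, 0.6]],  # two tokens on device 0
    [[0.9, 0.8, 0.0, 0.1]],                        # one token on device 1
]
send = dispatch_counts(scores, experts_per_device=2, k=2)
print(send)
```

The matrix depends on the data, so it changes every step and is rarely balanced; the combine step is the transpose of the same traffic pattern.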
Composing the axes onto the bandwidth hierarchy
Real frontier-scale runs use 3D or 4D parallelism — data × tensor × pipeline, plus expert and/or context when the workload demands. The art is not in any single axis; it is in mapping each axis to the right tier of the interconnect so the most chatty collective sits on the fastest wire.
The mapping follows directly from each axis’s communication profile:
- Tensor parallelism — per-layer all-reduce on the critical path, highest bandwidth need → innermost, intra-node, over NVLink. Size ≈ 8.
- Expert parallelism — bursty all-to-all → also wants fast links, often sharing the intra-node fabric or the fastest inter-node tier.
- Pipeline parallelism — small point-to-point at stage boundaries → spans nodes over InfiniBand/Ethernet comfortably.
- Context parallelism — ring/all-to-all of K/V, overlappable → placed where bandwidth allows for your sequence length.
- Data parallelism / FSDP — one gradient reduce-scatter per step, fully overlappable → outermost, spanning the most nodes.
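In practice "innermost" means the axis that varies fastest as the global rank increases. The sketch below assumes ranks are numbered consecutively within a node, 8 GPUs per node, and a row-major (dp, pp, tp) layout; the helper names are hypothetical. It checks that every tensor-parallel group then lands on a single node.

```python
GPUS_PER_NODE = 8  # assumption: ranks 0-7 share node 0, ranks 8-15 node 1, ...


def rank_to_coords(rank, pp, tp):
    """Row-major (dp, pp, tp) coordinates with TP as the fastest-varying axis."""
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


def tp_groups(dp, pp, tp):
    """Ranks grouped by (dp, pp) coordinate: each group is one TP group."""
    groups = {}
    for rank in range(dp * pp * tp):
        d, p, _ = rank_to_coords(rank, pp, tp)
        groups.setdefault((d, p), []).append(rank)
    return list(groups.values())


groups = tp_groups(dp=2, pp=4, tp=8)
nodes_per_group = [{rank // GPUS_PER_NODE for rank in group} for group in groups]
assert all(len(nodes) == 1 for nodes in nodes_per_group)
```

Swap the layout so that DP is the fastest-varying axis and the same check fails: each TP group then straddles nodes, and the critical-path all-reduce ends up on the slow fabric.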
The other half of the game is overlap. FSDP prefetches the next layer’s parameter all-gather while the current layer computes, and overlaps gradient reduce-scatter with the backward pass — done well, the communication is nearly free. Pipeline parallelism overlaps one micro-batch’s communication with another’s compute by construction. Tensor parallelism is the stubborn one, which is exactly why it is boxed inside the node. A useful diagnostic: profile a step and look at GPU idle time during collectives. If the SMs are starving on an all-reduce, you have either put an axis on too slow a wire or failed to overlap it.
The memory levers that touch everything
Underneath all five axes sit three orthogonal knobs that change the memory equation regardless of how you shard. The guidance here is deliberately qualitative: the exact APIs churn, so check current docs before copying any flag.
- Activation checkpointing (gradient checkpointing): drop most activations in the forward pass and recompute them in the backward pass. Trades roughly one extra forward of compute for a large cut in activation memory. Selective recomputation (recompute only the cheap-to-redo, memory-heavy ops like attention softmax) is the better default than checkpointing whole blocks.
- Mixed precision: bf16 compute with an fp32 master copy and fp32 optimizer moments is the standard. bf16’s fp32-range exponent removes the loss-scaling dance that fp16 requires; FP8 is increasingly viable on Hopper-class and newer hardware for both training and inference.
- Gradient accumulation: reach a large effective batch by summing gradients over several micro-steps instead of holding one giant batch in memory. Under FSDP it interacts with resharding — accumulating without resharding between micro-steps saves all-gathers at a memory cost, the same time/space dial as everywhere else.
None of these is a parallelism axis, but every parallelism decision shifts where they bite. Activation checkpointing is often what makes a TP-or-not call go away; gradient accumulation is often what lets you drop a parallelism axis entirely.
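Gradient accumulation is the easiest of the three to verify by hand. The sketch below uses a one-weight linear model with a mean-squared-error loss and made-up data, and assumes equal-sized micro-batches, so dividing each micro-batch gradient by the number of micro-batches recovers the full-batch mean.

```python
def mse_gradient(w, batch):
    """d/dw of mean((w*x - y)^2) over the batch, for a one-weight linear model."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0)]
w = 0.5

# One large batch, held in memory all at once.
full_batch_grad = mse_gradient(w, data)

# Same effective batch as two micro-steps, holding half the data each time.
micro_batches = [data[0:2], data[2:4]]
accumulated = 0.0
for micro_batch in micro_batches:
    # Scale so the running sum is a mean over the full batch, not a sum of means.
    accumulated += mse_gradient(w, micro_batch) / len(micro_batches)

assert abs(accumulated - full_batch_grad) < 1e-12
```

With unequal micro-batch sizes the scaling has to weight each micro-batch by its share of the samples instead, a common source of silently wrong effective learning rates.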
What to take away
The reach-for order in practice, cheapest engineering first:
- FSDP / ZeRO-3 first. It is the highest-leverage, lowest-complexity axis: shard the model states across data-parallel ranks, add activation checkpointing, and most “it does not fit” problems disappear with one wrapper and near-free communication.
- Add tensor parallelism when a single layer’s parameters or activations no longer fit comfortably on one GPU, and keep it inside the node, over NVLink, because its all-reduce is on the critical path.
- Add pipeline parallelism to span nodes once TP has filled one; its cheap point-to-point traffic tolerates slower inter-node links, but budget micro-batches ($m \gg p$: many more micro-batches than stages) or the bubble eats your gains.
- Add context parallelism only when sequence length is the binding constraint, and expert parallelism only for MoE, where all-to-all becomes the thing you optimize.
The unifying mental model: data parallelism wastes memory by replicating model states; ZeRO/FSDP delete that replication for almost no extra communication; tensor and pipeline parallelism split the compute when even a single layer is too big, paying in collectives you must place carefully on the interconnect hierarchy; and context/expert parallelism are specialist axes for long sequences and sparse models. You are always trading memory for network traffic — so measure the traffic, put it on the fastest available wire, and hide it behind compute. The frameworks (Megatron-LM, DeepSpeed, torchtitan) encode these decisions; understanding why each axis lives where it does is what lets you read a config and know whether it will be bandwidth-bound before you burn the GPU-hours finding out.