A training step is a fight over three budgets: memory per device, compute per device, and communication between devices. Every parallelism strategy is a different way of paying those budgets for the same arithmetic.
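To keep the strategies comparable, we use a single example. Most of the post comes back to one linear layer, $Y = XW$: a batch of input rows $X$ times a weight matrix $W$.
The cost is roughly two floating point operations (a multiply and an add) for every combination of batch row, input feature, and output feature. That is the train we want to keep moving. The question every strategy answers in its own way: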
What is split, what is replicated, what communication repairs the split, and can that communication be hidden under matmul compute?
The “communication” piece is always one of a small set of collectives: all-gather, reduce-scatter, all-reduce, and all-to-all. See the appendix to get an intuitive feel for these collectives.
Data Parallelism
Replicate the weights, split the batch. Each device sees a different slice of $X$’s rows, multiplies by the same $W$, and produces its slice of $Y$.
Across chips, every copy of $W$ is identical; each chip owns a different chunk of the batch dimension. The forward pass is simply the same matmul run in parallel on different rows. The repair step comes at the end: the gradient of the replicated $W$ must account for the full batch (which we sharded across devices), so an all-reduce of the per-chip weight gradients fixes this.
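The split can be checked on a toy example. The sketch below is not from the original post: it uses plain Python lists, a made-up 4x3 input and 3x2 weight, two simulated chips, and assumes the loss is the plain sum of the outputs (so the upstream gradient is all ones). The helper names `matmul`, `transpose`, and `all_reduce` are illustrative.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def all_reduce(per_chip):
    """Simulated all-reduce: every chip ends up with the elementwise sum."""
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_chip)]
    return [total for _ in per_chip]

X = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]  # 4 batch rows, 3 features
W = [[1, 0], [0, 1], [1, 1]]                         # replicated on every chip
num_chips = 2
rows_per_chip = len(X) // num_chips
X_shards = [X[i * rows_per_chip:(i + 1) * rows_per_chip] for i in range(num_chips)]

# Forward: each chip multiplies its batch slice by the same W, no communication.
Y_shards = [matmul(x, W) for x in X_shards]
Y = [row for shard in Y_shards for row in shard]

# Backward (assumed loss = sum of all outputs, so dY is all ones):
# each chip computes a partial dW = X_shard^T @ dY_shard from its own rows only.
dW_local = [matmul(transpose(x), [[1, 1]] * len(x)) for x in X_shards]

# Repair: the all-reduce sums the partials so every chip holds the full-batch dW.
dW_synced = all_reduce(dW_local)
dW_full = matmul(transpose(X), [[1, 1]] * len(X))
```

Before the all-reduce, each chip's `dW_local` reflects only its own rows; afterwards every chip holds the same gradient a single device would have computed on the whole batch.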
The win is throughput, not memory. Every device still stores the full model, the full gradient, and the full optimizer state.
No cross-device communication is necessary on the forward pass, which is great: every chip just multiplies locally. We’ll revisit the backwards-pass communication, and how to overlap it with compute, in the backwards-pass section.
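The useful ratio is the per-chip batch size: forward pays nothing, but each step ends with an all-reduce of the weight gradient (whose volume scales with the size of $W$) against per-chip compute that scales with the per-chip batch times the size of $W$, so DP wants the per-chip batch large enough to keep the matmul well above that once-per-step collective.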
FSDP
Stop replicating. The batch is still sharded as in DP, but now $W$ is sliced too: each chip stores only a strip of $W$’s rows. In the picture below, $Y$ starts empty: each chip has its $X$ slice and its $W$ shard, but it cannot compute its $Y$ slice yet, because the matmul doesn’t even compose (the chip’s strip of $W$ has too few rows to match the columns of its $X$ slice). We must all-gather $W$ to materialize the full matrix on every chip; only then can each chip produce its batch slice of $Y$.
The mental model that makes this work: at any moment, only one or two layers are materialized (yellow in the picture above). The many other layers stay sharded, so the memory for model state shrinks by roughly a factor of the shard count.
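A toy version of the gather-then-multiply step, as a sketch rather than how any real FSDP implementation is written: two simulated chips, made-up matrices, and an illustrative `all_gather` helper over plain Python lists.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def all_gather(shards):
    """Simulated all-gather: every chip receives the concatenation of all shards."""
    full = [row for shard in shards for row in shard]
    return [full for _ in shards]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]        # 2 batch rows, 4 features
W = [[1, 0], [0, 1], [1, 1], [2, 0]]    # 4 x 2 weight
num_chips = 2
X_shards = [X[0:1], X[1:2]]             # batch split as in DP
W_shards = [W[0:2], W[2:4]]             # each chip stores a strip of W's rows

# Chip 0 holds a 1x4 slice of X but only a 2x2 strip of W: the shapes do not compose.
shapes_compose = len(X_shards[0][0]) == len(W_shards[0])

# All-gather materializes the full W on every chip, only for this layer.
W_gathered = all_gather(W_shards)
Y_shards = [matmul(X_shards[i], W_gathered[i]) for i in range(num_chips)]
Y = [row for shard in Y_shards for row in shard]
# A real implementation would free the gathered copy here and keep only the shard.
```

The persistent storage per chip is `W_shards[i]`; the gathered copy exists only around the matmul, which is where the memory saving comes from.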
The performance question is whether the gather for the next layer can finish while the current layer is still computing. The timeline below starts in the naive schedule: each matmul waits for its all-gather to land, leaving big idle gaps in the compute stream. In the overlapped schedule the gathers slide underneath the matmuls: the compute stream becomes continuous and the step takes roughly half as long.
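A rough timing model shows where "roughly half" comes from. This is a sketch under stated assumptions: one compute stream, one communication stream, a prefetch of exactly one layer ahead, and made-up per-layer times. The function names are illustrative.

```python
def naive_step_time(gather, compute):
    """Each matmul waits for its own all-gather: nothing overlaps."""
    return sum(g + c for g, c in zip(gather, compute))

def overlapped_step_time(gather, compute):
    """Layer l+1's gather is issued when layer l's matmul starts."""
    t = gather[0]  # the first gather has nothing to hide behind
    for layer, c in enumerate(compute):
        nxt = gather[layer + 1] if layer + 1 < len(gather) else 0.0
        # The next layer can start once both this matmul and the next gather are done.
        t += max(c, nxt)
    return t

gather = [1.0, 1.0, 1.0, 1.0]    # assumed all-gather time per layer
compute = [1.0, 1.0, 1.0, 1.0]   # assumed matmul time per layer

naive = naive_step_time(gather, compute)
overlapped = overlapped_step_time(gather, compute)

# When gathers are slower than matmuls, the step becomes communication-bound.
comm_bound = overlapped_step_time([2.0] * 4, [1.0] * 4)
```

With equal gather and matmul times, only the first gather stays exposed, so the ratio approaches one half as the layer count grows. If the gather is slower than the matmul, the `max` picks the gather and the overlap no longer hides it.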
The same per-layer all-gather will return in the backwards pass (computing the input gradient also needs the full $W$), so the backwards pass either repeats the gather or reuses a forward-cached copy. We’ll see that mirroring later.
The useful ratio simplifies to the per-chip batch size: per-layer compute scales like the per-chip batch times the size of $W$, while the all-gather of $W$ scales like the size of $W$ alone, so a larger per-chip batch buys room to hide the gather under the matmul.
Tensor Parallelism
A different shape. Tensor parallelism splits the matrix itself. There is one shared $X$, but each chip holds a slice of its columns (along the contraction dimension), paired with the matching rows of $W$. Each device then computes the product of its $X$ column slice and its $W$ row slice.
Notice the shape: each chip’s product has the shape of the full output, not a slice. But it is a partial sum: it contains only that chip’s share of the terms in the contraction. None of the chips alone has the correct answer.
Reduce-scatter is exactly the collective the matrix shape demands: a sum across devices that re-shards the output. Above, the collective combines the striped partials and each chip is left with one column of the final $Y$.
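The partial-sum structure is easy to see with small numbers. The sketch below assumes two simulated chips and made-up matrices; `reduce_scatter` is an illustrative helper over plain lists, not a library call.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def reduce_scatter(per_chip):
    """Simulated reduce-scatter: sum the per-chip matrices, then chip i keeps
    the i-th block of columns of the sum."""
    n = len(per_chip)
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_chip)]
    width = len(total[0]) // n
    return [[row[i * width:(i + 1) * width] for row in total] for i in range(n)]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]        # 2 batch rows, contraction dim 4
W = [[1, 0], [0, 1], [1, 1], [2, 0]]    # 4 x 2 weight
num_chips = 2
k = len(W) // num_chips                 # contraction slice per chip

X_cols = [[row[i * k:(i + 1) * k] for row in X] for i in range(num_chips)]
W_rows = [W[i * k:(i + 1) * k] for i in range(num_chips)]

# Each chip's local product has the full output shape but is only a partial sum.
partials = [matmul(X_cols[i], W_rows[i]) for i in range(num_chips)]

# The reduce-scatter sums the partials and leaves each chip one column of Y.
Y_shards = reduce_scatter(partials)
Y = [a + b for a, b in zip(Y_shards[0], Y_shards[1])]
```

Here `partials[0]` and `partials[1]` are both 2x2, the same shape as the true output, and neither equals it; only their sum does.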
We cannot overlap between layers here: the next layer’s input is exactly the current layer’s reduce-scattered output, so there is nothing to communicate ahead of time. The overlap happens inside the layer, with the matmul work split into chunks that form a staircase: while one chunk’s partial is being computed, the reduce-scatter of the previous chunk’s partial is in flight.
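The useful ratio is the per-chip slice of the contraction dimension: compute scales like the batch times the output width times that per-chip slice, and the activation collective scales like the batch times the output width alone, so a larger contraction dimension buys room to hide communication and a larger TP degree makes each local matmul thinner.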
Context Parallelism
So far the batch has been our only axis to split. But a training batch is really tokens laid out along a sequence, and a long sequence is its own axis. Context parallelism shards that sequence: each device owns a contiguous span of positions.
For the dense layers, the split is free. A linear layer treats every token independently, so sharding the sequence is just data parallelism wearing a different hat: each device multiplies its own tokens by the replicated $W$ with no communication at all.
Attention is where it stops being free. Attention has every query attend to every key and value across the whole sequence, but each device only holds the K/V for its own span. The missing K/V has to come from the other devices.
The pragmatic answer is ring attention. Arrange the devices in a ring. Each device keeps its query block fixed and passes its K/V block to its neighbor, hop by hop, until every block has visited every device. As each block arrives, the device folds it into a running, online-softmax attention output — and the send of the next block overlaps the attention matmul on the current one. It is the same compute-hides-communication trick, now turned around a ring.
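The ring plus online-softmax bookkeeping can be sketched in plain Python. This is a simplified illustration, not the published ring attention implementation: it assumes a single head, no causal mask, no score scaling, and simulates the devices with a loop, so nothing actually overlaps. All function names are illustrative.

```python
import math

def attention_full(q, keys, values):
    """Reference: softmax(q . k) weighted average of values for one query."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z
            for j in range(len(values[0]))]

def fold_block(state, q, keys, values):
    """Online-softmax update: merge one K/V block into (running max, denominator, numerator)."""
    m, z, acc = state
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m_new = max(m, max(scores))
    scale = math.exp(m - m_new)          # rescale everything accumulated so far
    z, acc = z * scale, [a * scale for a in acc]
    for s, v in zip(scores, values):
        w = math.exp(s - m_new)
        z += w
        acc = [a + w * vj for a, vj in zip(acc, v)]
    return m_new, z, acc

def ring_attention(q_blocks, k_blocks, v_blocks):
    n, dim = len(q_blocks), len(v_blocks[0][0])
    states = [[(-math.inf, 0.0, [0.0] * dim) for _ in qb] for qb in q_blocks]
    held = list(range(n))                # held[d]: which K/V block device d holds now
    for _ in range(n):
        for d in range(n):               # queries stay put; fold in the visiting block
            b = held[d]
            states[d] = [fold_block(s, q, k_blocks[b], v_blocks[b])
                         for s, q in zip(states[d], q_blocks[d])]
        held = [held[(d - 1) % n] for d in range(n)]   # pass blocks one hop around the ring
    return [[[a / z for a in acc] for (_, z, acc) in dev] for dev in states]

q_blocks = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.5, -0.5]]]
k_blocks = [[[1.0, 2.0], [0.0, 1.0]], [[-1.0, 0.5], [2.0, 0.0]]]
v_blocks = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [-1.0, 3.0]]]
out = ring_attention(q_blocks, k_blocks, v_blocks)
```

Each device folds in the blocks in a different order, yet the rescaling by the running maximum makes the final result match attention over the whole sequence up to floating point rounding.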
Ring attention is not the only way. DeepSpeed-Ulysses repairs the split with a different collective: an all-to-all that transposes the sharded axis. Each device projects its slice of positions into queries, keys, and values; then an all-to-all redistributes those projections. Picture them as a grid of sequence positions × attention heads — sequence parallelism shards the rows, and the all-to-all flips it to shard the columns instead, so each device trades “my positions, all heads” for “all positions, my heads.” With the whole sequence in hand for its head group, every device runs full attention locally, and a second all-to-all on the output flips back to sequence-sharded for the FFN.
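The layout flip can be made concrete by labelling every entry with its (position, head) pair. The sketch below is an illustration of the transpose only, not DeepSpeed code: it assumes two devices, four positions, four heads, and an illustrative `all_to_all` helper over nested lists.

```python
def all_to_all(send):
    """Simulated all-to-all: send[src][dst] is the chunk src sends to dst;
    recv[dst][src] is the chunk dst received from src."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

num_devices, seq_len, num_heads = 2, 4, 4
pos_per_dev = seq_len // num_devices
heads_per_dev = num_heads // num_devices

# Before: device d holds "my positions, all heads".
before = [[[(p, h) for h in range(num_heads)]
           for p in range(d * pos_per_dev, (d + 1) * pos_per_dev)]
          for d in range(num_devices)]

# Each device cuts its rows into one chunk per destination: the heads that destination owns.
send = [[[row[dst * heads_per_dev:(dst + 1) * heads_per_dev] for row in before[src]]
         for dst in range(num_devices)]
        for src in range(num_devices)]
recv = all_to_all(send)

# After: device d holds "all positions, my heads" (chunks concatenated in source order).
after = [[row for chunk in recv[d] for row in chunk] for d in range(num_devices)]

# The second all-to-all flips back: chunk by position, then re-join the heads.
send_back = [[after[src][dst * pos_per_dev:(dst + 1) * pos_per_dev]
              for dst in range(num_devices)]
             for src in range(num_devices)]
recv_back = all_to_all(send_back)
restored = [[sum((recv_back[d][src][r] for src in range(num_devices)), [])
             for r in range(pos_per_dev)]
            for d in range(num_devices)]
```

The constraint on head count shows up directly in `heads_per_dev`: with more devices than heads there would be no whole head left to assign to each device.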
The contrast is the point. Ring attention spreads its communication around a ring of point-to-point hops it can hide under compute; Ulysses concentrates it into two all-to-alls — a hard sync on each side of attention, though each moves less volume per link as the device count grows. The catch is combinatorial: Ulysses can shard across at most as many devices as there are key/value heads — one head group per device — so head count, not memory, is what caps its reach.
For the ring, the useful ratio is the per-device sequence length: a device’s attention compute per hop scales like the square of its block length times the head dimension, while the K/V block it forwards scales only like the block length times the head dimension, so a longer per-device sequence buys room to hide the ring hop under the matmul.
Expert Parallelism
A Mixture-of-Experts layer replaces the single FFN with many of them, plus a small router that sends each token to just one or two experts. Only a slice of the experts fire per token, so the layer buys parameters cheaply in compute. The catch: all of those experts will not fit on one device.
Expert parallelism shards the experts: each device hosts an equal share of them. The router runs locally and tags every token with its chosen expert. Now the tokens are in the wrong place: a token sitting on device 0 may need an expert that lives on device 3.
Two collectives bracket the only local work:
- Dispatch — an all-to-all: every device sends each token to the device that owns its expert. Because the routing is arbitrary, every device ends up talking to every other — exactly the all-to-all the appendix primes.
- Expert FFN, local: each device runs its own experts on the tokens it just received. No communication.
- Combine — a second all-to-all: each result is shipped back to the device the token came from.
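The three steps above can be sketched with a toy router. Everything here is hypothetical: two devices, four experts (two per device), top-1 routing decided in advance, and a stand-in `expert_ffn` that just scales its input so the result is easy to read.

```python
def all_to_all(send):
    """Simulated all-to-all: recv[dst][src] is the chunk src sent to dst."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

num_devices = 2
experts_per_device = 2   # experts 0, 1 live on device 0; experts 2, 3 on device 1

def expert_ffn(expert_id, x):
    """Stand-in for an expert: expert e multiplies its input by 10 * (e + 1)."""
    return x * 10 * (expert_id + 1)

# tokens[d] lists (value, routed expert) for the tokens sitting on device d.
tokens = [[(1, 3), (2, 0)], [(3, 1), (4, 2), (5, 3)]]

# Dispatch: every device sends each token to the device that owns its expert.
send = [[[] for _ in range(num_devices)] for _ in range(num_devices)]
for src, local in enumerate(tokens):
    for slot, (x, e) in enumerate(local):
        send[src][e // experts_per_device].append((slot, e, x))
recv = all_to_all(send)

# Expert FFN, local: each device runs its own experts on what it received.
results = [[[(slot, expert_ffn(e, x)) for slot, e, x in chunk] for chunk in recv[dev]]
           for dev in range(num_devices)]
load = [sum(len(chunk) for chunk in recv[dev]) for dev in range(num_devices)]

# Combine: a second all-to-all ships each result back to the token's origin device.
back = all_to_all(results)
outputs = [[None] * len(local) for local in tokens]
for origin in range(num_devices):
    for chunk in back[origin]:
        for slot, y in chunk:
            outputs[origin][slot] = y
```

The `slot` index carried with each token is what lets the combine step put results back in the original order on the origin device.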
The failure mode is imbalance. If the router favors one expert, that device’s queue dominates and the all-to-all stalls behind it; real systems cap each expert’s capacity and drop or reroute the overflow.
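One simple capacity policy, shown here only as an assumed illustration (real systems differ in how they pick which tokens to drop or where to reroute them), is first-come-first-served with a fixed cap per expert. The function name and the cap value are hypothetical.

```python
def apply_capacity(assignments, num_experts, capacity):
    """Keep tokens in arrival order until their expert is full; drop the rest."""
    load = [0] * num_experts
    kept, dropped = [], []
    for token, expert in enumerate(assignments):
        if load[expert] < capacity:
            load[expert] += 1
            kept.append(token)
        else:
            dropped.append(token)   # overflow: this token skips the expert layer
    return kept, dropped, load

# A router that favors expert 0: four of six tokens want it.
assignments = [0, 0, 0, 1, 0, 2]
kept, dropped, load = apply_capacity(assignments, num_experts=3, capacity=2)
```

With the cap in place, no expert's queue can grow past `capacity`, so the slowest device in the all-to-all is bounded at the cost of dropping the overflow tokens.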
The useful ratio is the FFN’s hidden width: each token’s expert FFN costs compute proportional to the model dimension times the hidden dimension, while the two all-to-alls move only its model-dimension-sized activation to its expert and back, so a wide hidden dimension hides the routing, provided the tokens stay balanced across experts.
The Backwards Pass
Forward is only half the step. The backwards pass needs gradients, and getting the gradients into the right shape — sharded, replicated, summed — is where each strategy spends its second round of communication.
The mental flip is simple. Wherever the forward did an all-gather, the backwards pass will reduce-scatter; wherever the forward did a reduce-scatter, the backwards pass will all-gather. This follows from the collectives being linear operators: the backwards pass applies the transpose of each operator, and the transpose of an all-gather is a reduce-scatter (and vice versa).
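The transpose claim can be checked numerically with the defining property of a transpose: the inner product of A(x) with y equals the inner product of x with Aᵀ(y). The sketch below uses made-up vectors on two simulated chips; the helpers are illustrative.

```python
def all_gather(shards):
    """Every chip receives the concatenation of all shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]

def reduce_scatter(per_chip):
    """Sum the per-chip vectors elementwise; chip i keeps the i-th block of the sum."""
    n = len(per_chip)
    total = [sum(vals) for vals in zip(*per_chip)]
    size = len(total) // n
    return [total[i * size:(i + 1) * size] for i in range(n)]

def inner(a, b):
    """Inner product over all chips and all elements."""
    return sum(x * y for va, vb in zip(a, b) for x, y in zip(va, vb))

x = [[1, 2], [3, 4]]                 # forward input: one shard per chip
y = [[1, 0, 2, 1], [0, 3, 1, 1]]     # arbitrary upstream gradient, full-size on each chip

lhs = inner(all_gather(x), y)        # <A x, y> with A = all-gather
rhs = inner(x, reduce_scatter(y))    # <x, A^T y> with A^T = reduce-scatter
```

Both sides come out equal, which is exactly what it means for reduce-scatter to be the transpose of all-gather: the gradient of a gathered tensor is the sum of every chip's contribution to each shard.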
DP — One all-reduce at the end
$W$ is replicated on every chip, so every chip has computed its own weight gradient from its batch slice. To make the gradients agree (so the optimizer steps the replicated parameters identically), DP fires one all-reduce of the weight gradients at the end of the step.
The dependency chain
To see how the all-reduces can run concurrently with the next layer’s backwards pass, we briefly need two layers, with weights $W_1$ and $W_2$, so $H = X W_1$ and $Y = H W_2$. (Activation functions between the layers are elided here; including them would add a local Hadamard factor to step 3 but does not change the dependency structure the overlap argument turns on.) The backwards pass visits four steps in order:
- $dY$: the gradient of the loss with respect to layer 2’s output, handed down from above.
- $dW_2 = H^\top dY$: layer 2’s weight gradient. This is the gradient DP must all-reduce.
- $dH = dY \, W_2^\top$: the gradient with respect to layer 1’s output, computed locally from $dY$ and the chip’s $W_2$ replica.
- $dW_1 = X^\top dH$: layer 1’s weight gradient.
The whole overlap argument is one sentence: Step 3 does not depend on Step 2’s all-reduce. It needs only the local output gradient from Step 1 and the local replica of layer 2’s weights, both already on the chip the moment Step 2 finishes. The all-reduced gradient is needed by the optimizer at the end of the step, not by the next backwards pass.
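The four steps can be written out on a toy two-layer network to make the data dependencies visible. This is a sketch with assumed names (`X`, `W1`, `W2`, `backward_local`), made-up integers, two simulated chips, no activation function, and a loss equal to the sum of the outputs. The code is sequential, so it shows which values each step reads, not the actual concurrency.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def all_reduce(per_chip):
    """Simulated all-reduce: every chip ends up with the elementwise sum."""
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_chip)]
    return [total for _ in per_chip]

X = [[1, 2], [3, 4], [5, 6], [7, 8]]   # 4 batch rows
W1 = [[1, 0], [1, 1]]                  # layer 1 weights, replicated
W2 = [[2], [1]]                        # layer 2 weights, replicated
X_shards = [X[0:2], X[2:4]]

def backward_local(x):
    h = matmul(x, W1)                  # forward activation kept for the backward pass
    dY = [[1] for _ in x]              # step 1 (assumed loss = sum of outputs)
    dW2 = matmul(transpose(h), dY)     # step 2: un-reduced; its all-reduce could start here
    dH = matmul(dY, transpose(W2))     # step 3: reads only local dY and the W2 replica
    dW1 = matmul(transpose(x), dH)     # step 4
    return dW2, dW1

local = [backward_local(x) for x in X_shards]
dW2_synced = all_reduce([g[0] for g in local])[0]
dW1_synced = all_reduce([g[1] for g in local])[0]
```

Nothing in step 3 or step 4 reads `dW2_synced`, which is why a real runtime is free to run that all-reduce on a separate stream while the compute continues.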
How the overlap works
Per layer, three things are issued so that communication and compute run side by side rather than in sequence:
- Compute layer 2’s weight gradient locally. The matmul runs on the Tensor Cores.
- Start all-reducing that weight gradient.
- Concurrently compute layer 1’s un-reduced gradient.
Repeat that schedule for every layer in reverse and you get the staircase shown in the timeline below: as each layer’s weight gradient becomes available, its all-reduce runs on the comm stream while earlier layers are still computing their backwards pass. This is also where DP’s only compute/communication overlap lives, since its forward pass has no communication to hide.
FSDP — Reduce-scatter in place of all-gather
The mirror of the forward FSDP pipeline. Each layer’s full $W$ is gathered (we need it for the input gradient), the weight gradient is computed at full shape, and a reduce-scatter returns each chip to owning just its row strip of that gradient: the transpose of forward’s per-layer all-gather. The per-layer overlap story is the same one the forward overlap timeline tells, just with the collectives swapped.
TP — All-gather in place of reduce-scatter
Forward’s row-parallel TP fired a reduce-scatter on the output inside each layer; the backwards pass fires the transpose, an all-gather on the output gradient, before each chip computes its local weight-gradient shard. The chunked staircase from the forward TP section carries over unchanged in shape.
CP — The K/V ring runs in reverse
Forward sent K/V blocks one way around the ring; the backwards pass sends the K and V gradient blocks back the other way, each device summing the partial contributions for the block it owns. The same per-hop overlap from the forward ring carries over.
EP — The same all-to-all pair, re-routing gradients
Dispatch and combine swap roles: an all-to-all ships each output gradient to the device that ran its expert, the expert gradients compute locally, and a second all-to-all returns each input gradient to the token’s origin device.
The Unified Picture
The strategies differ less by vocabulary than by repair step.
| Strategy | What is split? | What is replicated? | Main repair | Failure mode |
|---|---|---|---|---|
| Data Parallelism | batch rows | weights, gradients, optimizer state | gradient sync | model state does not fit |
| FSDP | batch rows + model state | one layer’s weights, transiently | all-gather, reduce-scatter | exposed all-gather stalls |
| Tensor Parallelism | $X$’s contraction dimension + $W$’s rows | nothing; each chip’s output is a partial sum | reduce-scatter inside the layer | local matmuls too small |
| Context Parallelism | the sequence (K/V across positions) | weights; the dense layer is free | K/V ring around the devices (ring attention) | sequence too short per device |
| Expert Parallelism | the experts (an equal share per device) | nothing; tokens are routed to experts | dispatch + combine all-to-all, every MoE layer | routing imbalance overloads one expert |
Matmul is the train. Collectives lay down the track. If the track arrives before the train needs it, the schedule looks compute-bound. If the track arrives late, the roofline drops to the network roof.
Further Reading
- JAX scaling book — training chapter (the source for much of the roofline framing here)
- PyTorch FSDP2 fully_shard
- PyTorch Tensor Parallelism
- NVIDIA NCCL collectives
- The Roofline Model
- Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM
Appendix: The Collectives
Every collective the post links to lives here, and all four now earn their place above: all-gather, reduce-scatter, and all-reduce drive DP, FSDP, and TP, while the all-to-all is the workhorse of the expert-parallel routing — and the sequence-parallel transpose — you just stepped through.
Step through each panel; the captions describe what changes between frames.
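For readers without the panels, the four collectives can also be written as a few lines of plain Python over per-chip lists. This is a semantic sketch only (what data ends up where), with illustrative helper names; it says nothing about how a real library schedules the transfers.

```python
def all_gather(shards):
    """Every chip ends with the concatenation of all shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]

def reduce_scatter(per_chip):
    """Elementwise sum across chips; chip i keeps the i-th block of the sum."""
    n = len(per_chip)
    total = [sum(vals) for vals in zip(*per_chip)]
    size = len(total) // n
    return [total[i * size:(i + 1) * size] for i in range(n)]

def all_reduce(per_chip):
    """Elementwise sum across chips; every chip keeps the whole sum."""
    total = [sum(vals) for vals in zip(*per_chip)]
    return [list(total) for _ in per_chip]

def all_to_all(send):
    """send[src][dst] goes from src to dst; recv[dst][src] is what dst got from src."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

per_chip = [[1, 2, 3, 4], [10, 20, 30, 40]]   # one full-size vector per chip

scattered = reduce_scatter(per_chip)
gathered = all_gather(scattered)
reduced = all_reduce(per_chip)
swapped = all_to_all([["a0", "a1"], ["b0", "b1"]])
```

One relationship worth noticing: a reduce-scatter followed by an all-gather produces the same result as an all-reduce, and the all-to-all is a transpose of the chip-by-chunk grid.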