A training step is a fight over three budgets: memory per device, compute per device, and communication between devices. Every parallelism strategy is a different way of paying those budgets for the same arithmetic.

To keep the strategies comparable, we use a single example. Most of the post comes back to one linear layer:

$$X[B, E] \; @ \; W[E, E] \rightarrow Y[B, E]$$

The cost is roughly $2 B E^2$ floating-point operations. That is the train we want to keep moving. The question every strategy answers in its own way:

What is split, what is replicated, what communication repairs the split, and can that communication be hidden under matmul compute?
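Where does $2 B E^2$ come from? Here is a minimal plain-Python sketch: a naive triple loop that counts one multiply and one add per inner step. The function name `matmul_counting_flops` is ours, for illustration only.

```python
def matmul_counting_flops(X, W):
    """Naive Y = X @ W for X[B][E_in], W[E_in][E_out], also counting FLOPs.

    Each output element is a dot product of length E_in: E_in multiplies and
    E_in adds, so the total is 2 * B * E_in * E_out (2 * B * E^2 when square).
    """
    B, E_in, E_out = len(X), len(W), len(W[0])
    Y = [[0.0] * E_out for _ in range(B)]
    flops = 0
    for b in range(B):
        for j in range(E_out):
            acc = 0.0
            for k in range(E_in):
                acc += X[b][k] * W[k][j]
                flops += 2  # one multiply, one add
            Y[b][j] = acc
    return Y, flops


B, E = 3, 4
X = [[1.0] * E for _ in range(B)]
W = [[1.0] * E for _ in range(E)]
Y, flops = matmul_counting_flops(X, W)
print(flops, 2 * B * E**2)  # 96 96
```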

The “communication” piece is always one of a small set of collectives: all-gather, reduce-scatter, all-reduce, and all-to-all. See the appendix to get an intuitive feel for these collectives.

Data Parallelism

Replicate the weights, split the batch. Each device sees a different slice of $X$'s rows, multiplies by the same $W$, and produces its slice of $Y$:

[Figure: Data Parallelism. Each chip owns a slice of the batch; W is replicated everywhere.]

Across chips, every $W$ is identical; each chip owns a different chunk of the batch dimension. The forward pass is embarrassingly parallel. The repair step comes at the end: the gradient of the replicated $W$ must account for the full batch (which we sharded across devices), so an all-reduce of $\partial W$ fixes this.

The win is throughput, not memory. Every device still stores the full model, the full gradient, and the full optimizer state.

The forward pass needs no cross-device communication, which is great: every chip just multiplies locally. We'll revisit the backwards-pass communication, and how to overlap it with compute, in the backwards-pass section.

The useful ratio is $B/P$. Forward pays nothing, but each step ends with an all-reduce of $\partial W$ (closer to $E^2$) against per-chip compute of $B E^2 / P$. DP therefore wants $B/P$ large enough to keep the matmul well above that once-per-step collective.

FSDP

Stop replicating. The batch is still sharded as in DP, but now $W$ is sliced too: each chip stores only a strip of $W$'s rows. In the picture below, $Y$ starts empty. Chip $i$ has its $X$ slice and its $W$ shard, but it cannot compute its $Y$ yet, because the matmul $X_i @ W_i$ does not even compose ($[B/P, E]$ against $[E/P, E]$). We must all-gather $W$ to materialize the full $W$ on every chip; only then can each chip produce its batch slice of $Y$.

[Figure: FSDP. The batch is sharded as in DP, but W's rows are partitioned across chips. The full W is gathered just in time, used, and freed.]

The mental model that makes this work: at any moment, only one or two layers are materialized in full, and all other layers stay sharded. Model-state memory shrinks by roughly the shard count.

The performance question is whether the gather for layer $N+1$ can finish while layer $N$ is still computing. The timeline below contrasts two schedules. In the naive schedule, each matmul waits for its all-gather to land, leaving big idle gaps in the compute stream. In the overlapped schedule, the gathers slide underneath the matmuls, so the compute stream becomes continuous. When each gather takes about as long as each matmul, the step takes roughly half as long.

[Figure: FSDP, pipelined all-gather. Layer N+1's weights are gathered during layer N's matmul; as long as the all-gather is shorter than the matmul, compute never waits. Timeline rows: compute (matmul L1 to L4) and communication (AG W1 to AG W4).]
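To make "roughly half as long" concrete, here is a toy timeline model in plain Python. It is our simplification, not a real FSDP runtime, and it assumes:

- fixed per-layer durations;
- one compute stream and one communication stream;
- a prefetch window of one layer, matching the "one or two layers materialized" picture.

`fsdp_forward_time` is a hypothetical helper.

```python
def fsdp_forward_time(matmul, gather, overlapped):
    """Toy timeline for a stack of FSDP layers (durations in arbitrary units).

    matmul[i], gather[i]: time for layer i's matmul and for the all-gather of W_i.
    Naive: gather W_i, then run matmul i, strictly one after the other.
    Overlapped: gathers run on a separate communication stream and prefetch at
    most one layer ahead, so at most two layers' weights are materialized.
    """
    if not overlapped:
        return sum(matmul) + sum(gather)
    gather_done = 0.0    # when the comm stream finished its last gather
    compute_start = 0.0  # when the previous matmul started
    compute_done = 0.0   # when the previous matmul finished
    for m, g in zip(matmul, gather):
        # This gather may start once the comm stream is free and the previous
        # layer's matmul has started (one-layer prefetch window).
        gather_done = max(gather_done, compute_start) + g
        # The matmul needs its weights and a free compute stream.
        compute_start = max(gather_done, compute_done)
        compute_done = compute_start + m
    return compute_done


layers = 4
print(fsdp_forward_time([1.0] * layers, [1.0] * layers, overlapped=False))  # 8.0
print(fsdp_forward_time([1.0] * layers, [1.0] * layers, overlapped=True))   # 5.0
```

In this model only the first gather stays exposed, so the overlapped time approaches half the naive time as the layer count grows. If each gather takes twice as long as its matmul (`gather=[2.0] * 4`), the overlapped schedule takes 9.0 units instead of 12.0, because compute waits on every gather.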

The same per-layer all-gather returns in the backwards pass, because computing $\partial X$ also needs the full $W$. The backwards pass therefore either repeats the gather or reuses a forward-cached copy. We'll see that mirroring later.

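The useful ratio is again the per-chip batch $B/P$. Per-layer compute scales like $B E^2 / P$. The all-gather of $W$ moves roughly $E^2 (P-1)/P \approx E^2$ per chip: each chip owns only $E^2/P$ of $W$, but must receive every other chip's shard. A larger per-chip batch therefore buys room to hide the gather under the matmul.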

Tensor Parallelism

A different shape. Tensor parallelism splits the matrix itself. There is one shared $X$, but each chip holds a slice of its columns (along the contraction dimension $E$), paired with the matching rows of $W$. Each device then computes

$$X_i\,[B, E/P] \; @ \; W_i\,[E/P, E] \rightarrow Y^{(i)}\,[B, E]$$

Notice the shape: $Y^{(i)}$ is the full output, not a slice. But it is a partial sum, only the $i$-th of $P$ terms in the contraction: $Y = \sum_i Y^{(i)}$. No chip alone has the correct answer.

[Figure: Tensor Parallelism (row-parallel). One shared X, sliced along the contraction dimension E; W's rows are partitioned to match. Each chip computes a full-shape partial sum of Y.]

Reduce-scatter is exactly the collective the matrix shape demands: a sum across devices that also re-shards the output. In the figure above, the collective combines the striped partials, and each chip is left with one column block of the final $Y$.
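The row-parallel split and its reduce-scatter can be checked numerically. This single-process sketch simulates $P$ chips with integer matrices. It assumes $P$ divides $E$, and it models the reduce-scatter as a plain sum plus column slicing rather than a real ring. `row_parallel_tp` is a hypothetical helper.

```python
def matmul(a, b):
    """Plain-Python matrix product of lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def row_parallel_tp(X, W, P):
    """Simulate row-parallel TP on P chips, followed by a reduce-scatter.

    Chip i holds X_i (column block i of X) and W_i (row block i of W).
    Assumes P divides both E (W's rows) and W's column count.
    """
    B, E, E_out = len(X), len(W), len(W[0])
    s, c = E // P, E_out // P
    partials = []
    for i in range(P):
        X_i = [row[i * s:(i + 1) * s] for row in X]  # [B, E/P]
        W_i = W[i * s:(i + 1) * s]                   # [E/P, E_out]
        partials.append(matmul(X_i, W_i))            # full-shape [B, E_out] partial sum
    # Reduce-scatter: sum the partials, and leave chip i with column block i.
    Y_shards = [[[sum(p[b][j] for p in partials) for j in range(i * c, (i + 1) * c)]
                 for b in range(B)] for i in range(P)]
    return partials, Y_shards


X = [[1, 2, 3, 4], [5, 6, 7, 8]]
W = [[1, 0, 2, 1], [0, 1, 1, 2], [3, 1, 0, 1], [1, 2, 1, 0]]
partials, Y_shards = row_parallel_tp(X, W, P=2)
full = matmul(X, W)
print(partials[0] == full)                        # False: a partial sum alone is wrong
print(Y_shards[1] == [row[2:4] for row in full])  # True
```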

We cannot overlap across layers here. The next layer's input is exactly the current layer's reduce-scattered output, so there is nothing to communicate ahead of time. The overlap happens inside the layer instead: the matmul is split into a staircase that interleaves computing one partial with reduce-scattering the previous one.

[Figure: TP, building the within-layer staircase. The matmul is split into chunks along the batch dimension; while one chunk is computed, the previous chunk's partial sum is reduce-scattered. Timeline rows: compute and reduce-scatter.]
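A toy timing model of the staircase, under simplifying assumptions of ours:

- every chunk costs the same compute time `c` and reduce-scatter time `r`;
- there is one compute stream and one communication stream.

`tp_layer_time` is a hypothetical helper.

```python
def tp_layer_time(n_chunks, compute_per_chunk, rs_per_chunk, overlapped):
    """Toy timeline for one row-parallel TP layer split into n_chunks along the batch.

    Naive: compute the whole partial sum, then reduce-scatter all of it.
    Overlapped (staircase): chunk k's reduce-scatter runs on the comm stream
    while chunk k+1 is computing.
    """
    if not overlapped:
        return n_chunks * (compute_per_chunk + rs_per_chunk)
    compute_done = 0.0
    rs_done = 0.0
    for _ in range(n_chunks):
        compute_done += compute_per_chunk  # compute never waits on communication
        # A chunk's reduce-scatter needs its partial and a free comm stream.
        rs_done = max(rs_done, compute_done) + rs_per_chunk
    return rs_done


print(tp_layer_time(4, 1.0, 1.0, overlapped=False))  # 8.0
print(tp_layer_time(4, 1.0, 1.0, overlapped=True))   # 5.0
```

With $n$ chunks this two-stage pipeline costs $c + (n-1)\max(c, r) + r$ instead of $n(c + r)$. Only the first chunk's compute and the last chunk's reduce-scatter stay exposed. Finer chunks shrink those exposed ends, but each chunk's matmul gets thinner.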

The useful ratio is $E/P$. Compute scales like $B E^2 / P$ and the activation collective is closer to $B E$, so larger $E$ buys room to hide communication, while larger $P$ makes each local matmul thinner.

Context Parallelism

So far the batch has been our only axis to split. But a training batch is really tokens laid out along a sequence, and a long sequence is its own axis. Context parallelism shards that sequence: each device owns a contiguous span of positions.

For the dense layers, the split is free. A linear layer treats every token independently, so sharding the sequence is just data parallelism wearing a different hat. In the forward pass, each device multiplies its own tokens by the replicated $W$ with no communication at all. As in DP, the replicated $W$'s gradient still needs an all-reduce in the backwards pass.

[Figure: Context Parallelism. The sequence is sharded across chips. For the dense layer this is just DP: W is replicated and no communication is needed; attention is where CP pays.]

Attention is where it stops being free. Attention has every query attend to every key and value across the whole sequence, but each device only holds the K/V for its own span. The missing K/V has to come from the other devices.

The pragmatic answer is ring attention. Arrange the devices in a ring. Each device keeps its query block fixed and passes its K/V block to its neighbor, hop by hop, until every block has visited every device. As each block arrives, the device folds it into a running, online-softmax attention output — and the send of the next block overlaps the attention matmul on the current one. It is the same compute-hides-communication trick, now turned around a ring.

[Figure: CP, ring attention accumulates over the sequence. Four devices (D0 to D3), each owning one chunk of the sequence. Device d keeps its query block Q_d fixed and rotates K/V blocks clockwise around the ring, folding each arriving block into its running attention output O_d. While one block is being attended to, the next K/V block is already hopping to the neighbor.]
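Here is a minimal plain-Python sketch of the ring schedule and the online-softmax fold. It simulates the $P$ devices in a single process and checks the result against ordinary full-sequence attention. It makes three simplifications: one head, no causal mask, and no $1/\sqrt{d}$ scaling. The function names are hypothetical.

```python
import math


def full_attention(Q, K, V):
    """Reference softmax attention (no mask, no 1/sqrt(d) scaling) over the whole sequence."""
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out


def ring_attention(Q_blocks, K_blocks, V_blocks):
    """Simulate ring attention on P devices with an online softmax.

    Device d keeps Q_blocks[d] fixed. At each of the P steps, every device folds the
    K/V block it currently holds into per-query running statistics, then passes
    that block to device d + 1 (in a real system the send overlaps the compute).
    """
    P = len(Q_blocks)
    d_v = len(V_blocks[0][0])
    # Per query row: (running max m, running normalizer l, unnormalized output acc).
    state = [[(-math.inf, 0.0, [0.0] * d_v) for _ in Q_blocks[dev]] for dev in range(P)]
    held = list(zip(K_blocks, V_blocks))  # held[dev]: the K/V block device dev has now
    for _ in range(P):
        for dev in range(P):
            K_blk, V_blk = held[dev]
            for r, q in enumerate(Q_blocks[dev]):
                m, l, acc = state[dev][r]
                scores = [sum(a * b for a, b in zip(q, k)) for k in K_blk]
                m_new = max(m, max(scores))
                scale = math.exp(m - m_new)  # rescale old stats; exp(-inf) == 0.0 on the first block
                l, acc = l * scale, [x * scale for x in acc]
                for s, v in zip(scores, V_blk):
                    w = math.exp(s - m_new)
                    l += w
                    acc = [x + w * vj for x, vj in zip(acc, v)]
                state[dev][r] = (m_new, l, acc)
        held = [held[(dev - 1) % P] for dev in range(P)]  # one ring hop: dev receives from dev - 1
    return [[[x / l for x in acc] for _, l, acc in rows] for rows in state]


def split(M, P):
    """Cut a sequence into P contiguous blocks (assumes P divides its length)."""
    blk = len(M) // P
    return [M[i * blk:(i + 1) * blk] for i in range(P)]


S, P = 8, 4
Q = [[math.sin(i), math.cos(i)] for i in range(S)]
K = [[math.cos(2 * i), 0.1 * i] for i in range(S)]
V = [[float(i), -0.5 * i] for i in range(S)]
ring_out = [row for block in ring_attention(split(Q, P), split(K, P), split(V, P)) for row in block]
ref = full_attention(Q, K, V)
print(max(abs(a - b) for ra, rb in zip(ring_out, ref) for a, b in zip(ra, rb)))
```

The printed maximum difference is floating-point round-off only. Folding the blocks in a different order changes nothing except rounding, which is why each device can start with its own block.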

Ring attention is not the only way. DeepSpeed-Ulysses repairs the split with a different collective: an all-to-all that transposes the sharded axis. Each device projects its slice of positions into queries, keys, and values; then an all-to-all redistributes those projections. Picture them as a grid of sequence positions × attention heads — sequence parallelism shards the rows, and the all-to-all flips it to shard the columns instead, so each device trades “my positions, all heads” for “all positions, my heads.” With the whole sequence in hand for its head group, every device runs full attention locally, and a second all-to-all on the output flips back to sequence-sharded for the FFN.

[Figure: CP, Ulysses transposes the sharded axis with all-to-all. A grid of sequence positions (rows s0 to s3) by attention heads (columns h0 to h3), colored by owning device. The first all-to-all turns sequence-sharded (each device owns a row) into head-sharded (each device owns a column), so each device can run full attention over the whole sequence for its heads; a second all-to-all flips back.]
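A plain-Python sketch of the two transposes, assuming $P$ divides both the sequence length $S$ and the head count $H$. Each grid cell stands in for one head's projection at one position. `all_to_all`, `seq_to_heads`, and `heads_to_seq` are hypothetical helpers, not DeepSpeed's API.

```python
def all_to_all(send):
    """send[src][dst] is the piece device src sends to device dst.
    Returns recv, where recv[dst][src] is the piece dst got from src."""
    P = len(send)
    return [[send[src][dst] for src in range(P)] for dst in range(P)]


def seq_to_heads(shards):
    """Sequence-sharded [S/P, H] blocks -> head-sharded [S, H/P] blocks."""
    P = len(shards)
    c = len(shards[0][0]) // P  # heads per device; assumes P divides H
    # Each device cuts its rows into P head groups and sends group e to device e.
    send = [[[row[e * c:(e + 1) * c] for row in block] for e in range(P)] for block in shards]
    recv = all_to_all(send)
    # Stack the received pieces in source order to recover every position.
    return [[row for piece in recv[dst] for row in piece] for dst in range(P)]


def heads_to_seq(shards):
    """Head-sharded [S, H/P] blocks -> sequence-sharded [S/P, H] blocks (the inverse)."""
    P = len(shards)
    r = len(shards[0]) // P  # positions per device; assumes P divides S
    send = [[block[e * r:(e + 1) * r] for e in range(P)] for block in shards]
    recv = all_to_all(send)
    # Concatenate the head groups side by side, again in source order.
    return [[[x for piece in recv[dst] for x in piece[i]] for i in range(r)] for dst in range(P)]


S, H, P = 4, 4, 2
grid = [[(s, h) for h in range(H)] for s in range(S)]  # cell (s, h): head h at position s
seq_sharded = [grid[d * S // P:(d + 1) * S // P] for d in range(P)]
head_sharded = seq_to_heads(seq_sharded)
print(head_sharded[1])  # [[(0, 2), (0, 3)], [(1, 2), (1, 3)], [(2, 2), (2, 3)], [(3, 2), (3, 3)]]
print(heads_to_seq(head_sharded) == seq_sharded)  # True
```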

The contrast is the point. Ring attention spreads its communication around a ring of point-to-point hops it can hide under compute; Ulysses concentrates it into two all-to-alls — a hard sync on each side of attention, though each moves less volume per link as the device count grows. The catch is combinatorial: Ulysses can shard across at most as many devices as there are key/value heads — one head group per device — so head count, not memory, is what caps its reach.

For the ring, the useful ratio is $S/P$, where $S$ is the sequence length and $d$ the head dimension. A device's attention compute per hop scales like $(S/P)^2 d$, while the K/V block it forwards is only $\sim (S/P) d$. A longer per-device sequence therefore buys room to hide the ring hop under the matmul.

Expert Parallelism

A Mixture-of-Experts layer replaces the single FFN with many (say $N$ of them), plus a small router that sends each token to just one or two experts. Only a few experts fire per token, so the layer buys parameters cheaply in compute. The catch: $N$ experts will not fit on one device.

Expert parallelism shards the experts: each device hosts $N/P$ of them. The router runs locally and tags every token with its chosen expert. Now the tokens are in the wrong place: a token sitting on device 0 may need an expert that lives on device 3.

Two collectives bracket the only local work:

  1. Dispatch (an all-to-all): every device sends each token to the device that owns its expert. Because the routing is arbitrary, every device ends up talking to every other, which is exactly the all-to-all described in the appendix.
  2. Expert FFN (local): each device runs its own experts on the tokens it just received. No communication.
  3. Combine (a second all-to-all): each result is shipped back to the device the token came from.
[Figure: EP, routing tokens to experts with all-to-all. Three devices, each holding one expert (E0, E1, E2). The router assigns every token to an expert; a dispatch all-to-all gathers each expert's tokens onto its device, the expert FFN runs locally, and a combine all-to-all scatters the results back.]
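A single-process sketch of the dispatch, local FFN, and combine sequence. It assumes top-1 routing, one expert per device, and a toy scalar "FFN". `expert_ffn` is hypothetical, as are the other helper names.

```python
def all_to_all(send):
    """send[src][dst] -> recv[dst][src]."""
    P = len(send)
    return [[send[src][dst] for src in range(P)] for dst in range(P)]


def expert_ffn(e, x):
    """Hypothetical stand-in for expert e's FFN on one token's activation."""
    return (e + 1) * x + e


def moe_layer(tokens, P):
    """Expert parallelism with one expert per device (expert e lives on device e).

    tokens[d] is a list of (token_id, x, expert) triples sitting on device d,
    where `expert` is the router's top-1 choice. Returns out[d]: {token_id: y}.
    """
    # 1. Dispatch: bucket each device's tokens by destination expert, then all-to-all.
    send = [[[] for _ in range(P)] for _ in range(P)]
    for src in range(P):
        for tid, x, e in tokens[src]:
            send[src][e].append((tid, x))
    recv = all_to_all(send)
    # 2. Expert FFN, local: device dst runs expert dst, keeping results grouped by origin.
    results = [[[(tid, expert_ffn(dst, x)) for tid, x in recv[dst][src]] for src in range(P)]
               for dst in range(P)]
    # 3. Combine: the second all-to-all ships each result back to its origin device.
    returned = all_to_all(results)
    return [{tid: y for piece in returned[d] for tid, y in piece} for d in range(P)]


tokens = [
    [(0, 1.0, 2), (1, 2.0, 0)],  # device 0
    [(2, 3.0, 1), (3, 4.0, 2)],  # device 1
    [(4, 5.0, 0), (5, 6.0, 0)],  # device 2
]
out = moe_layer(tokens, P=3)
print(out[0])  # {1: 2.0, 0: 5.0}
```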

The failure mode is imbalance. If the router favors one expert, that device’s queue dominates and the all-to-all stalls behind it; real systems cap each expert’s capacity and drop or reroute the overflow.
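Capacity capping is easy to sketch. This minimal version assumes one common convention: a "capacity factor" times the average number of tokens per expert, with tokens admitted in arrival order. The helper name and the admission order are illustrative assumptions.

```python
import math


def apply_capacity(expert_of_token, n_experts, capacity_factor):
    """Cap every expert at `capacity` tokens and drop the overflow.

    Assumes capacity = ceil(capacity_factor * tokens / experts), with tokens
    admitted in arrival order. Returns (kept mask, per-expert load, capacity).
    """
    capacity = math.ceil(capacity_factor * len(expert_of_token) / n_experts)
    load = [0] * n_experts
    kept = []
    for e in expert_of_token:
        if load[e] < capacity:
            load[e] += 1
            kept.append(True)
        else:
            kept.append(False)  # overflow: dropped here; some systems reroute instead
    return kept, load, capacity


# A skewed router: 8 of 12 tokens want expert 0.
routing = [0, 0, 1, 0, 2, 0, 0, 1, 0, 2, 0, 0]
kept, load, capacity = apply_capacity(routing, n_experts=3, capacity_factor=1.25)
print(capacity, load, kept.count(False))  # 5 [5, 2, 2] 3
```

Without the cap, expert 0's device would process 8 tokens against a mean of 4, and the whole all-to-all would wait for it. With the cap, its queue is bounded at 5, at the cost of 3 dropped tokens.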

The useful ratio is $E$. Each token's expert FFN costs $\sim E^2$, while the two all-to-alls move only its $\sim E$-sized activation to its expert and back. A wide hidden dimension therefore hides the routing, provided the tokens stay balanced across experts.

The Backwards Pass

Forward is only half the step. The backwards pass needs gradients, and getting the gradients into the right shape — sharded, replicated, summed — is where each strategy spends its second round of communication.

The mental flip is simple. Wherever the forward pass did an all-gather, the backwards pass does a reduce-scatter; wherever the forward did a reduce-scatter, the backwards pass does an all-gather. This follows because collectives are linear operators. The backwards pass applies the transpose of each forward operator, and the transpose of an all-gather is a reduce-scatter (and vice versa). All-reduce is its own transpose, and the transpose of an all-to-all is the all-to-all that routes each piece back to its sender.
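The transpose claim can be checked numerically. This sketch (plain Python, our own helper names) writes each collective as a 0/1 matrix acting on all chips' buffers stacked into one vector, then compares transposes.

```python
def all_gather_matrix(P, n):
    """All-gather as a 0/1 matrix. Input: P shards of length n, stacked (length P*n).
    Output: every chip's full copy, stacked (length P*P*n)."""
    N = P * n
    return [[1 if col == j else 0 for col in range(N)] for _dst in range(P) for j in range(N)]


def reduce_scatter_matrix(P, n):
    """Reduce-scatter as a 0/1 matrix. Input: every chip's full-length partial, stacked
    (length P*P*n). Output: chip i's reduced slot i, stacked (length P*n)."""
    N = P * n
    # Output row `target` sums element `target` of every chip's partial.
    return [[1 if col % N == target else 0 for col in range(P * N)] for target in range(N)]


def transpose(M):
    return [list(row) for row in zip(*M)]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


P, n = 3, 2
AG, RS = all_gather_matrix(P, n), reduce_scatter_matrix(P, n)
print(transpose(AG) == RS)  # True: the backward of an all-gather is a reduce-scatter
AR = matmul(AG, RS)         # all-reduce = reduce-scatter, then all-gather
print(transpose(AR) == AR)  # True: all-reduce is its own transpose
```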

DP — One all-reduce at the end

$W$ is replicated on every chip, so every chip has computed its own gradient $\partial W^{(i)}$ from its batch slice. To make the gradients agree (so the optimizer steps the replicated parameters identically), DP fires one all-reduce of $\partial W$ at the end of the step.

[Figure: DP backward, one all-reduce at the end of the step. W is replicated, so every chip computes a local ∂W from its batch slice. One all-reduce (really a reduce-scatter followed by an all-gather) combines them. That is the entire backwards-pass communication budget for DP.]
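A small numeric check of why one summing all-reduce is enough: the per-chip gradients $X_i^{\top} \partial Y_i$ add up to the full-batch $X^{\top} \partial Y$. This plain-Python sketch assumes $P$ divides the batch. Frameworks often average instead of sum, which only rescales the result. `dp_weight_grad` is a hypothetical helper.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def dp_weight_grad(X, dY, P):
    """DP backward for Y = X @ W with the batch split across P chips.

    Chip i computes dW_i = X_i^T dY_i from its own rows; the all-reduce (a sum here)
    leaves every chip holding the same total. Assumes P divides the batch size.
    """
    rows = len(X) // P
    local = [matmul(transpose(X[i * rows:(i + 1) * rows]), dY[i * rows:(i + 1) * rows])
             for i in range(P)]
    total = [[sum(g[r][c] for g in local) for c in range(len(local[0][0]))]
             for r in range(len(local[0]))]
    return local, [total] * P  # every chip ends with an identical copy


X = [[1, 2], [3, 4], [5, 6], [7, 8]]    # B = 4, E = 2
dY = [[1, 0], [0, 1], [1, 1], [2, -1]]  # upstream gradient of the loss w.r.t. Y
local, reduced = dp_weight_grad(X, dY, P=2)
print(local[0] == local[1])                    # False: chips disagree before the sync
print(reduced[0] == matmul(transpose(X), dY))  # True: matches the full-batch gradient
```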

The dependency chain

To see how the all-reduces can run concurrently with the rest of the backwards pass, we briefly need two layers: $X \to Y_1 \to Y_2$, with weights $W_1$ and $W_2$, so $Y_1 = X @ W_1$ and $Y_2 = Y_1 @ W_2$. (Nonlinearities between the layers are elided here. Including them would add a local Hadamard factor to step 3 but does not change the dependency structure the overlap argument turns on.) The backwards pass visits four steps in order:

  1. $\partial Y_2$: the gradient of the loss with respect to layer 2's output, handed down from above.
  2. $\partial W_2 = Y_1^{\top} \, \partial Y_2$: layer 2's weight gradient. This is the gradient DP must all-reduce.
  3. $\partial Y_1 = \partial Y_2 \, W_2^{\top}$: the gradient with respect to layer 1's output, computed locally from $\partial Y_2$ and the chip's $W_2$ replica.
  4. $\partial W_1 = X^{\top} \, \partial Y_1$: layer 1's weight gradient.

The whole overlap argument fits in one sentence: step 3 does not depend on step 2's all-reduce. Step 3 needs only the local $\partial Y_2$ and the local replica of $W_2$, both already on the chip the moment step 2 finishes. The all-reduced gradient is needed by the optimizer at the end of the step, not by the rest of the backwards pass.
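Here is a minimal sketch of that dependency structure, simulating $P$ DP chips in one process. A one-worker `concurrent.futures` thread pool stands in for the communication stream; a real system would use the framework's asynchronous collectives. The helper names are hypothetical. The point is the ordering: the `dW2` all-reduce is launched, steps 3 and 4 run, and only then is its result collected.

```python
from concurrent.futures import ThreadPoolExecutor


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def all_reduce_sum(mats):
    """Stand-in for the all-reduce: elementwise sum of every chip's matrix."""
    return [[sum(m[r][c] for m in mats) for c in range(len(mats[0][0]))]
            for r in range(len(mats[0]))]


def dp_backward(X_shards, Y1_shards, dY2_shards, W2):
    """Steps 2-4 of the two-layer backward pass on P simulated DP chips.

    The all-reduce of dW2 is submitted to the comm 'stream' and is NOT waited on
    before step 3: step 3 uses only each chip's local dY2 and its replica of W2.
    """
    P = len(X_shards)
    with ThreadPoolExecutor(max_workers=1) as comm:
        dW2_local = [matmul(transpose(Y1_shards[i]), dY2_shards[i]) for i in range(P)]  # step 2
        dW2 = comm.submit(all_reduce_sum, dW2_local)                                   # in flight
        dY1 = [matmul(dY2_shards[i], transpose(W2)) for i in range(P)]                 # step 3
        dW1_local = [matmul(transpose(X_shards[i]), dY1[i]) for i in range(P)]         # step 4
        dW1 = comm.submit(all_reduce_sum, dW1_local)
        return dW1.result(), dW2.result()  # only the optimizer needs these


def halves(M):
    """Split the batch across two simulated chips."""
    return [M[:2], M[2:]]


X = [[1, 0], [0, 1], [2, 1], [1, 3]]
W1, W2 = [[1, 2], [0, 1]], [[2, 0], [1, 1]]
Y1 = matmul(X, W1)
dY2 = [[1, 0], [0, 1], [1, -1], [2, 1]]
dW1, dW2 = dp_backward(halves(X), halves(Y1), halves(dY2), W2)
print(dW2 == matmul(transpose(Y1), dY2))                        # True
print(dW1 == matmul(transpose(X), matmul(dY2, transpose(W2))))  # True
```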

How the overlap works

Per layer, the schedule is:

  1. Compute $\partial W_2$ locally: the compute engine runs $Y_1^{\top} \, \partial Y_2$ on the Tensor Cores.
  2. Start all-reducing $\partial W_2$ on the communication stream, without waiting for it to finish.
  3. While that all-reduce is in flight, compute layer 1's un-reduced gradients ($\partial Y_1$, then $\partial W_1$).

Repeat that schedule for every layer in reverse and you get a staircase. As each layer's $\partial W$ becomes available, its all-reduce runs while earlier layers are still computing their backwards pass. This is where DP's only compute/communication overlap lives. The timeline below shows the staircase, with each layer's communication hiding behind the next layer's compute.

[Figure: Data Parallelism, backward-pass overlap. Forward needs no communication at all; in backward, each layer's gradient all-reduce runs while earlier layers' backwards are still computing. Timeline rows: compute and all-reduce.]
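A toy timing model of that staircase, assuming fixed per-layer durations, one compute stream, and one communication stream. These are our simplifications, and `dp_backward_time` is a hypothetical helper.

```python
def dp_backward_time(compute, allreduce, overlapped):
    """Toy timeline for DP's backward pass (arbitrary time units).

    compute[l], allreduce[l]: backward compute and gradient all-reduce time for
    the l-th layer visited (last layer first). Naive: all compute, then all syncs.
    Overlapped: each layer's all-reduce starts as soon as its dW exists and the
    comm stream is free, while compute moves on to the earlier layers.
    """
    if not overlapped:
        return sum(compute) + sum(allreduce)
    t_compute = 0.0
    t_comm = 0.0
    for c, a in zip(compute, allreduce):
        t_compute += c                       # backward compute never waits on comm
        t_comm = max(t_comm, t_compute) + a  # this layer's dW is ready at t_compute
    return max(t_compute, t_comm)


print(dp_backward_time([2.0] * 4, [1.0] * 4, overlapped=False))  # 12.0
print(dp_backward_time([2.0] * 4, [1.0] * 4, overlapped=True))   # 9.0
```

When every all-reduce is shorter than the next layer's backward compute, only the final one (for the network's first layer) stays exposed.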

FSDP — Reduce-scatter in place of all-gather

The mirror of the forward FSDP pipeline. Each layer's full $W$ is gathered again (we need it to compute $\partial X = \partial Y \, W^{\top}$). Each chip then computes a full-shape $\partial W$ from its batch slice. A reduce-scatter sums those contributions and returns each chip to owning just its row shard of $\partial W$: the transpose of the forward's per-layer all-gather. The per-layer overlap story is the same one the forward overlap timeline tells, just with the collectives swapped.

[Figure: FSDP backward, reduce-scatter in place of the forward's all-gather. One layer, end to end, on two chips: W starts sharded and is gathered, ∂W is computed at full shape, and a reduce-scatter leaves each chip with its shard of ∂W.]

[Figure: FSDP, backward-pass overlap. The forward pipeline walked in reverse. Per layer: all-gather W, compute ∂W at full shape, reduce-scatter ∂W. The next layer's all-gather of W overlaps the current layer's reduce-scatter of ∂W.]
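A single-layer numeric sketch of the FSDP backward, assuming $P$ divides both the batch and $E$. The all-gather and reduce-scatter are modeled as list concatenation and a sum plus row slicing, not real collectives. `fsdp_backward_layer` is a hypothetical helper.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def fsdp_backward_layer(X, dY, W_shards):
    """One FSDP layer's backward on P simulated chips (batch split by rows).

    1. All-gather W (needed for dX = dY W^T, not for dW).
    2. Each chip computes a full-shape dW from its batch slice.
    3. Reduce-scatter: sum those and hand chip i only row block i of dW.
    Assumes P divides both the batch size and E.
    """
    P = len(W_shards)
    W = [row for shard in W_shards for row in shard]  # all-gather along rows
    b, s = len(X) // P, len(W) // P
    dX, dW_full = [], []
    for i in range(P):
        X_i, dY_i = X[i * b:(i + 1) * b], dY[i * b:(i + 1) * b]
        dX.append(matmul(dY_i, transpose(W)))         # chip i's slice of dX
        dW_full.append(matmul(transpose(X_i), dY_i))  # full shape, one batch slice's worth
    dW_shards = [[[sum(g[r][c] for g in dW_full) for c in range(len(W[0]))]
                  for r in range(i * s, (i + 1) * s)] for i in range(P)]
    return dX, dW_shards


X = [[1, 2], [3, 4], [0, 1], [2, 2]]
W = [[1, -1], [2, 0]]
dY = [[1, 0], [0, 2], [1, 1], [3, 0]]
dX, dW_shards = fsdp_backward_layer(X, dY, W_shards=[W[:1], W[1:]])
full = matmul(transpose(X), dY)
print(dW_shards[1] == full[1:])  # True: chip 1 owns row 1 of the full-batch dW
```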

TP — All-gather in place of reduce-scatter

Forward's row-parallel TP fired a reduce-scatter on $Y$ inside each layer. The backwards pass fires the transpose, an all-gather on $\partial Y$, before each chip computes its local $\partial W$ shard. The chunked staircase from the forward TP section carries over unchanged in shape.

[Figure: TP backward, all-gather in place of reduce-scatter, on two chips. ∂W = Xᵀ · ∂Y, chunked along the batch dimension, with ∂Y arriving column-sharded. The diagonal blocks of ∂W are fully local (chip i already has X[:, i] and ∂Y[:, i]), so they land immediately with no communication dependency. The off-diagonal blocks need the other chip's ∂Y columns; that all-gather is pipelined row by row, each landed row feeding one chunk of the off-diagonal staircase.]
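A numeric sketch of the TP backward for one layer. It does the all-gather of $\partial Y$ in one shot, omitting the chunked pipelining, and assumes $P$ divides $E$. Each chip then produces its row block of $\partial W$ and its column block of $\partial X$. `tp_backward_layer` is a hypothetical helper.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def tp_backward_layer(X, W, dY_shards):
    """Row-parallel TP backward on P simulated chips.

    Chip i holds X_i (column block i of X), W_i (row block i of W), and dY_i
    (column block i of dY, the layout the forward reduce-scatter produced).
    Assumes P divides E.
    """
    P = len(dY_shards)
    s = len(W) // P
    # All-gather dY: concatenate every chip's column block side by side.
    dY = [[x for shard in dY_shards for x in shard[r]] for r in range(len(dY_shards[0]))]
    dW_shards, dX_shards = [], []
    for i in range(P):
        X_i = [row[i * s:(i + 1) * s] for row in X]   # [B, E/P]
        W_i = W[i * s:(i + 1) * s]                    # [E/P, E]
        dW_shards.append(matmul(transpose(X_i), dY))  # [E/P, E]: row block i of dW
        dX_shards.append(matmul(dY, transpose(W_i)))  # [B, E/P]: column block i of dX
    return dW_shards, dX_shards


X = [[1, 2, 0, 1], [3, 1, 2, 2]]
W = [[1, 0, 2, 1], [0, 1, 1, 0], [2, 1, 0, 1], [1, 1, 1, 1]]
dY = [[1, -1, 0, 2], [0, 1, 1, 1]]
dY_shards = [[row[:2] for row in dY], [row[2:] for row in dY]]
dW_shards, dX_shards = tp_backward_layer(X, W, dY_shards)
full_dW = matmul(transpose(X), dY)
print(dW_shards[0] == full_dW[:2])  # True
```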

CP — The K/V ring runs in reverse

Forward sent K/V blocks one way around the ring. The backwards pass sends the gradient blocks $\partial K$ and $\partial V$ back the other way, each device summing the partial contributions for the block it owns. The same per-hop overlap from the forward ring carries over.

[Figure: CP backward, the K/V ring runs in reverse. The same four-device ring as the forward pass, the other way around. Each block's ∂K/∂V is a sum of contributions from every device that attended to it, so the partials ring back and accumulate.]
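This sketch models only the accumulation pattern. Each device's partial contribution is a placeholder number; computing real $\partial K$/$\partial V$ partials would also need the saved attention probabilities. The reverse-direction schedule and the helper name are our assumptions.

```python
def ring_reduce_kv_grads(contrib):
    """contrib[d][b]: device d's partial gradient for the K/V block owned by device b.

    Accumulators travel the ring in the reverse direction (device d receives from
    device d + 1). At step t, device d carries the accumulator for block (d + 1 + t) % P,
    adds its own contribution, and passes it on. After P - 1 hops every accumulator
    has visited every device and sits on its owner.
    """
    P = len(contrib)
    carried = [0] * P
    for t in range(P):
        for d in range(P):
            carried[d] += contrib[d][(d + 1 + t) % P]
        if t < P - 1:
            carried = [carried[(d + 1) % P] for d in range(P)]  # one reverse hop
    return carried  # carried[b] is the full gradient for block b


contrib = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
print(ring_reduce_kv_grads(contrib))  # [1111, 2222, 3333, 4444]
```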

EP — The same all-to-all pair, re-routing gradients

Dispatch and combine swap roles. An all-to-all ships each output gradient $\partial Y$ to the device that ran its expert, the expert gradients compute locally, and a second all-to-all returns each input gradient $\partial X$ to the token's origin device.

[Figure: EP backward, the same all-to-all pair re-routing gradients. Three devices, each holding one expert. The output gradient ∂Y is sent by all-to-all to the device that owns each token's expert, the expert gradients compute locally, and a second all-to-all returns each input gradient ∂X home.]
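The same structure can be sketched for gradients. This sketch makes three assumptions:

- each expert is a hypothetical scalar-linear map $y = w_e x$;
- the forward pass cached each token's input on its expert's device;
- the forward routing is reused rather than recomputed.

The helper names are illustrative.

```python
def all_to_all(send):
    """send[src][dst] -> recv[dst][src]."""
    P = len(send)
    return [[send[src][dst] for src in range(P)] for dst in range(P)]


def moe_backward(grads, saved_x, expert_of, w):
    """EP backward with one scalar-linear expert per device (y = w[e] * x, hypothetical).

    grads[d]: list of (token_id, dy) for tokens that live on device d.
    saved_x[e]: {token_id: x} cached on expert e's device during the forward pass.
    expert_of: {token_id: e}, the forward routing decision.
    Returns (dx per origin device as {token_id: dx}, dw per expert).
    """
    P = len(grads)
    send = [[[] for _ in range(P)] for _ in range(P)]
    for src in range(P):
        for tid, dy in grads[src]:
            send[src][expert_of[tid]].append((tid, dy))  # dispatch dY to the expert's device
    recv = all_to_all(send)
    dw = [0.0] * P
    back = [[[] for _ in range(P)] for _ in range(P)]
    for e in range(P):
        for src in range(P):
            for tid, dy in recv[e][src]:
                dw[e] += dy * saved_x[e][tid]          # weight grad stays on the expert
                back[e][src].append((tid, w[e] * dy))  # input grad heads home
    returned = all_to_all(back)                        # combine: dX back to its origin
    return [{tid: dx for piece in returned[d] for tid, dx in piece} for d in range(P)], dw


w = [2.0, -1.0]
expert_of = {0: 1, 1: 0, 2: 1}
saved_x = [{1: 2.0}, {0: 1.0, 2: 3.0}]
grads = [[(0, 1.0), (1, 0.5)], [(2, 2.0)]]
dx, dw = moe_backward(grads, saved_x, expert_of, w)
print(dw)  # [1.0, 7.0]
```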

The Unified Picture

The strategies differ less by vocabulary than by repair step.

| Strategy | What is split? | What is replicated? | Main repair | Failure mode |
|---|---|---|---|---|
| Data Parallelism | batch rows | weights, gradients, optimizer state | gradient sync | model state does not fit |
| FSDP | batch rows + model state | one layer's weights, transiently | all-gather, reduce-scatter | exposed all-gather stalls |
| Tensor Parallelism | $X$'s contraction dimension + $W$'s rows | nothing; each chip's output is a partial sum | reduce-scatter $Y$ inside the layer | local matmuls too small |
| Context Parallelism | the sequence (K/V across positions) | weights; the dense layer is free | K/V ring around the devices (ring attention) | sequence too short per device |
| Expert Parallelism | the experts ($N/P$ per device) | nothing; tokens are routed to experts | dispatch + combine all-to-all, every MoE layer | routing imbalance overloads one expert |

Matmul is the train. Collectives lay down the track. If the track arrives before the train needs it, the schedule looks compute-bound. If the track arrives late, the roofline drops to the network roof.


Appendix: The Collectives

This appendix covers every collective used above. All-gather, reduce-scatter, and all-reduce drive DP, FSDP, and TP. All-to-all is the workhorse of expert-parallel routing and of the Ulysses sequence-to-head transpose.


[Figure: The four collectives, in one picture. Each chip is a 4-slot rectangle showing the data it currently holds. Legend: solid = owned data; hatched = partial sum (one or several contributors); solid yellow = fully reduced; gray = empty (sent away).]
[Figure panel: All-gather (ring algorithm), 4 chips, 3 rounds. Each chip starts owning one slot (its own color). At each round, every chip sends a slot one hop clockwise; senders keep their copy, so each chip gains one populated slot per round. After 3 rounds every chip holds all four slots.]
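The ring all-gather panel can be simulated in a few lines of plain Python. Each chip sends one slot per round, for $P-1$ slots in total. `ring_all_gather` is our illustrative helper, not a library API.

```python
def ring_all_gather(shards):
    """Ring all-gather: chip i starts with shards[i]; after P - 1 rounds every chip has all P."""
    P = len(shards)
    slots = [[None] * P for _ in range(P)]
    for i in range(P):
        slots[i][i] = shards[i]
    for r in range(P - 1):
        # Round r: chip i forwards slot (i - r) % P, the one it received last round
        # (its own slot in round 0), one hop clockwise. Senders keep their copy.
        msgs = [((i + 1) % P, (i - r) % P, slots[i][(i - r) % P]) for i in range(P)]
        for dst, k, value in msgs:  # all sends in a round happen "at once"
            slots[dst][k] = value
    return slots


print(ring_all_gather(["a", "b", "c", "d"]))  # four identical rows of ['a', 'b', 'c', 'd']
```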
[Figure panel: Reduce-scatter (ring algorithm), 4 chips, 3 rounds. Every chip starts with the full vector as its own partial sum. At each round, partial sums move one hop clockwise and mix with the receiver's contribution. After 3 rounds each chip holds one fully reduced slot; the rest have been sent away.]
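A plain-Python simulation of the ring reduce-scatter panel. The round schedule shown is one standard choice, and the helper name is ours.

```python
def ring_reduce_scatter(vectors):
    """Ring reduce-scatter: chip c starts with a full length-P partial vector;
    after P - 1 rounds chip i holds the fully reduced slot i."""
    P = len(vectors)
    acc = [list(v) for v in vectors]
    for r in range(P - 1):
        # Round r: chip i sends its running sum for slot (i - r - 1) % P one hop
        # clockwise, and the receiver adds its own contribution to it.
        msgs = [((i + 1) % P, (i - r - 1) % P, acc[i][(i - r - 1) % P]) for i in range(P)]
        for dst, k, value in msgs:
            acc[dst][k] += value
    return [acc[i][i] for i in range(P)]


print(ring_reduce_scatter([[1, 2, 3, 4], [10, 20, 30, 40],
                           [100, 200, 300, 400], [1000, 2000, 3000, 4000]]))
# [1111, 2222, 3333, 4444]
```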
[Figure panel: All-reduce (ring algorithm), 4 chips, 6 rounds. Reduce-scatter (3 rounds) followed by all-gather (3 rounds): partial sums first reduce down to one fully reduced slot per chip, then the reduced slots broadcast back out to fill every chip.]
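The all-reduce panel is the two previous phases back to back. Here they are combined in one self-contained plain-Python sketch (our helper name), which takes $2(P-1)$ rounds.

```python
def ring_all_reduce(vectors):
    """Ring all-reduce = ring reduce-scatter followed by ring all-gather (2 * (P - 1) rounds)."""
    P = len(vectors)
    acc = [list(v) for v in vectors]
    # Phase 1: reduce-scatter. Afterwards chip i holds the full sum for slot i.
    for r in range(P - 1):
        msgs = [((i + 1) % P, (i - r - 1) % P, acc[i][(i - r - 1) % P]) for i in range(P)]
        for dst, k, value in msgs:
            acc[dst][k] += value
    # Phase 2: all-gather. Chips forward reduced slots; receivers overwrite stale partials.
    for r in range(P - 1):
        msgs = [((i + 1) % P, (i - r) % P, acc[i][(i - r) % P]) for i in range(P)]
        for dst, k, value in msgs:
            acc[dst][k] = value
    return acc


print(ring_all_reduce([[1, 2, 3, 4], [10, 20, 30, 40],
                       [100, 200, 300, 400], [1000, 2000, 3000, 4000]]))
# four identical rows of [1111, 2222, 3333, 4444]
```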
[Figure panel: All-to-all (pairwise swap schedule), 4 chips, 3 rounds. All-to-all transposes the sharding axis. Each chip starts holding one shard of dimension A, split into pieces addressed to every destination along dimension B. Each chip ends holding one shard of B, with contributions from every source along A. At each round, pairs of chips swap one piece each.]
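A plain-Python sketch of a pairwise-swap all-to-all. It assumes an XOR pairing schedule, which needs $P$ to be a power of two. That schedule is our choice for illustration; the panel does not specify which pairing it uses.

```python
def pairwise_all_to_all(send):
    """send[i][j]: the piece chip i holds that is destined for chip j.
    Returns recv, where recv[j][i] is the piece chip j received from chip i."""
    P = len(send)
    assert P & (P - 1) == 0, "XOR pairing assumes P is a power of two"
    recv = [[None] * P for _ in range(P)]
    for i in range(P):
        recv[i][i] = send[i][i]  # the piece addressed to yourself never moves
    for r in range(1, P):
        # Round r: chip i and chip i ^ r form a pair and swap one piece each.
        for i in range(P):
            partner = i ^ r
            recv[partner][i] = send[i][partner]
    return recv


P = 4
send = [[f"{i}->{j}" for j in range(P)] for i in range(P)]
print(pairwise_all_to_all(send)[2])  # ['0->2', '1->2', '2->2', '3->2']
```

Because `i ^ r` is a fixed-point-free involution for every `r != 0`, each round is a perfect matching. Every link is busy, and no chip sends or receives more than one piece per round.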