Chapter 4: Ring All-Reduce
The Algorithm That Powers Distributed Training
🍕 The Pizza Party Analogy
4 friends each made a different pizza. Goal: everyone tastes ALL pizzas.
Naive approach: Everyone sends their whole pizza to everyone else.
= 4 × 3 = 12 pizza transfers. Expensive!
Ring approach: Sit in a circle. Pass ONE SLICE to your right neighbor.
After 3 rounds, everyone has tasted all pizzas!
= 4 × 3 = 12 slice transfers, but slices are SMALLER than whole pizzas!
Why Ring All-Reduce?
Q: Why not just send everything to one node, combine, and broadcast back?
A: That's called "parameter server" - the central node becomes a bottleneck. With 100 GPUs sending 1GB each, that's 100GB hitting one poor server!
Q: Why a ring specifically?
A: Ring is bandwidth-optimal! Each peer sends a total of 2×(N-1)/N times its data size, regardless of world size. Trees and stars don't scale as well.
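To make the contrast concrete, here's a back-of-the-envelope comparison in Rust. The helper names are ours, purely illustrative; the numbers match the 100-GPU example above.

```rust
/// GB arriving at the central server in a parameter-server reduce:
/// every worker uploads its full gradient to one node.
fn param_server_gb_at_server(n: u64, gb_per_peer: f64) -> f64 {
    n as f64 * gb_per_peer
}

/// GB each peer sends in a ring all-reduce: 2 * (N-1)/N * S.
fn ring_gb_sent_per_peer(n: u64, gb_per_peer: f64) -> f64 {
    2.0 * (n as f64 - 1.0) * gb_per_peer / n as f64
}

fn main() {
    // 100 GPUs, 1 GB of gradients each
    println!("parameter server hot spot: {} GB", param_server_gb_at_server(100, 1.0));
    println!("ring, sent per peer:      {:.2} GB", ring_gb_sent_per_peer(100, 1.0));
}
```

The ring spreads the same total traffic evenly over every link, so no single node ever sees the 100GB pile-up.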
The Two Phases
💡 Ring All-Reduce = Reduce-Scatter + All-Gather
Phase 1 (Reduce-Scatter): Combine data, each peer ends up "owning" one chunk
Phase 2 (All-Gather): Share owned chunks so everyone has the full result
Phase 1: Reduce-Scatter
In each of N-1 steps, every peer sends one chunk to its right neighbor and adds the chunk arriving from its left neighbor into its own buffer. After N-1 steps, each peer holds the fully reduced sum for exactly one chunk: its "owned" chunk.
Phase 2: All-Gather
The owned chunks now rotate around the ring for another N-1 steps. No arithmetic this time, just copying, until every peer holds every fully reduced chunk.
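One way to see the two phases working together is a single-process simulation. This is a toy sketch, nothing like PCCL's real networking, and every name in it is ours; each "peer" is just a vector, and chunks rotate around the ring in place of actual sends.

```rust
/// Simulate ring all-reduce among `peers.len()` peers, each holding a
/// vector of equal length. Returns the buffers after both phases.
fn ring_all_reduce_sim(mut peers: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
    let n = peers.len();
    let chunk = peers[0].len() / n;

    // Phase 1: Reduce-Scatter. At step s, peer r sends chunk (r - s) mod n
    // and accumulates the incoming chunk (r - s - 1) mod n. After N-1 steps
    // peer r "owns" the fully reduced chunk (r + 1) mod n.
    for step in 0..n - 1 {
        // Snapshot all outgoing chunks first, so every peer reads pre-step values.
        let sends: Vec<Vec<f32>> = (0..n)
            .map(|r| {
                let c = (r + n - step) % n;
                peers[r][c * chunk..(c + 1) * chunk].to_vec()
            })
            .collect();
        for r in 0..n {
            let from = (r + n - 1) % n; // previous peer in the ring
            let c = (from + n - step) % n; // chunk it sent
            for i in 0..chunk {
                peers[r][c * chunk + i] += sends[from][i];
            }
        }
    }

    // Phase 2: All-Gather. Owned chunks rotate; pure copies, no math.
    for step in 0..n - 1 {
        let sends: Vec<Vec<f32>> = (0..n)
            .map(|r| {
                let c = (r + 1 + n - step) % n;
                peers[r][c * chunk..(c + 1) * chunk].to_vec()
            })
            .collect();
        for r in 0..n {
            let from = (r + n - 1) % n;
            let c = (from + 1 + n - step) % n;
            peers[r][c * chunk..(c + 1) * chunk].copy_from_slice(&sends[from]);
        }
    }
    peers
}

fn main() {
    // 4 peers, 8 elements each; peer r starts with every element = r
    let peers: Vec<Vec<f32>> = (0..4).map(|r| vec![r as f32; 8]).collect();
    let out = ring_all_reduce_sim(peers);
    // Every peer ends with the element-wise sum 0 + 1 + 2 + 3 = 6
    for p in &out {
        assert!(p.iter().all(|&x| x == 6.0));
    }
}
```

Note how the `+ n` before each subtraction keeps the unsigned index arithmetic from underflowing, a detail that matters again in the pseudocode below.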
The Math
Bandwidth Analysis
For N peers, each with data of size S:
- Reduce-Scatter: (N-1) steps, each sending S/N data = (N-1)×S/N
- All-Gather: (N-1) steps, each sending S/N data = (N-1)×S/N
- Total per peer: 2×(N-1)×S/N ≈ 2S for large N
Each peer sends ~2× its data, regardless of world size. That's optimal!
✏️ Calculate It!
You have 8 peers, each with a 1GB tensor. How much data does each peer SEND in total?
Hint: Use the formula 2×(N-1)×S/N
Each peer sends 1.75GB to complete the all-reduce!
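You can sanity-check that answer by coding up the formula directly (the helper name is ours):

```rust
/// Data each peer sends in a ring all-reduce: 2 * (N-1) * S / N.
fn gb_sent_per_peer(n: f64, gb: f64) -> f64 {
    2.0 * (n - 1.0) * gb / n
}

fn main() {
    // 8 peers, 1 GB each: 2 * 7 * 1 / 8 = 1.75
    println!("{} GB", gb_sent_per_peer(8.0, 1.0)); // prints "1.75 GB"
}
```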
PCCL's Implementation
// Simplified ring all-reduce pseudocode
fn all_reduce(buffer: &mut [f32], world_size: usize, rank: usize) -> Result<(), Aborted> {
    let chunk_size = buffer.len() / world_size;
    let backup = buffer.to_vec(); // For fault tolerance!

    // Phase 1: Reduce-Scatter
    for step in 0..(world_size - 1) {
        // Add world_size before subtracting so unsigned arithmetic can't underflow
        let send_chunk = (rank + world_size - step) % world_size;
        let recv_chunk = (rank + world_size - step - 1) % world_size;

        // Send my chunk to the next peer
        send_to_next(&buffer[send_chunk * chunk_size..(send_chunk + 1) * chunk_size]);

        // Receive a chunk from the previous peer
        let received = recv_from_prev();

        // Accumulate into my buffer
        for i in 0..chunk_size {
            buffer[recv_chunk * chunk_size + i] += received[i];
        }

        // Check for abort signal (fault tolerance!)
        if master.has_abort() {
            buffer.copy_from_slice(&backup);
            return Err(Aborted);
        }
    }

    // Phase 2: All-Gather (same rotation, copy instead of accumulate)
    // ...
    Ok(())
}
⚠️ The Backup Buffer
Notice we clone the buffer BEFORE starting? That's crucial for fault tolerance!
If any peer fails mid-operation, we restore the backup and retry. Without this, you'd have corrupted half-reduced data.
Multiple Connections: The WAN Trick
TCP's dirty secret: a single TCP connection rarely saturates a WAN link. On high-latency paths, congestion control and receive-window growth can't keep up, so one stream stalls far below line rate.
PCCL's trick: Open MULTIPLE connections and run concurrent all-reduces!
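A sketch of the sharding idea in Rust, with threads standing in for concurrent TCP streams. This is purely illustrative: PCCL's actual connection handling is more involved, and `sharded_transfer` is our name, not its API.

```rust
use std::thread;

/// Split one logical transfer across several "connections" (threads here),
/// each moving its own shard of the buffer concurrently. Assumes connections >= 1.
fn sharded_transfer(buffer: &mut [f32], connections: usize) {
    let shard = (buffer.len() + connections - 1) / connections; // ceiling division
    thread::scope(|s| {
        for piece in buffer.chunks_mut(shard) {
            s.spawn(move || {
                // Stand-in for "run one all-reduce over this shard on its own
                // TCP stream"; here we just double each element in place.
                for x in piece.iter_mut() {
                    *x *= 2.0;
                }
            });
        }
    }); // scope joins all threads before returning
}

fn main() {
    let mut buf = vec![1.0_f32; 10];
    sharded_transfer(&mut buf, 3); // 3 concurrent "connections"
    assert!(buf.iter().all(|&x| x == 2.0));
}
```

Because each shard is independent, the streams never wait on each other, and aggregate throughput scales with the number of connections until the link itself is full.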
Bit-Wise Determinism
💡 Why Ring All-Reduce is Naturally Deterministic
In the all-gather phase, each peer broadcasts its OWNED chunk. No computation, just copying.
Since everyone receives the SAME bytes, everyone ends up with IDENTICAL results!
⚠️ Quantization Breaks This!
If you quantize (compress) data for transmission:
// Problem: D(Q(x)) ≠ x
// Dequantized data isn't exactly the original!
// WRONG: Use local high-precision data
accumulate(my_chunk, received_chunk) // my_chunk has extra precision!
// RIGHT: Quantize your own data too
accumulate(dequant(quant(my_chunk)), received_chunk)
Otherwise you get "lingering precision" that other peers don't have!
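Here's the pitfall in miniature, using a toy symmetric int8 quantizer (illustrative only; PCCL's actual quantization scheme may differ, and all names are ours):

```rust
/// Toy symmetric int8 quantizer: scale, round, clamp.
fn quant(x: f32, scale: f32) -> i8 {
    (x / scale).round().clamp(-127.0, 127.0) as i8
}

/// Inverse of `quant`, up to rounding error: D(Q(x)) != x in general.
fn dequant(q: i8, scale: f32) -> f32 {
    q as f32 * scale
}

fn main() {
    let scale = 0.1_f32;
    let mine = 0.123_f32;                             // my local high-precision value
    let theirs = dequant(quant(0.123, scale), scale); // what a peer reconstructs

    let wrong = mine + theirs;                               // keeps precision no peer has
    let right = dequant(quant(mine, scale), scale) + theirs; // matches every other peer
    assert_ne!(wrong, right); // the difference is the "lingering precision"
}
```

Only the `right` sum is reproducible across peers, because every input to it passed through the same quantize/dequantize round trip.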
Code Magnet Exercise
Arrange these to implement reduce-scatter step:
recv_chunk = recv_from_prev()
send_to_next(my_chunk)
buffer[idx] += recv_chunk[i]
check_abort_signal()
Correct order: ___ → ___ → ___ → ___
(Send and recv can be concurrent for full-duplex!)
"SCATTER then GATHER = Everyone gets the PLATTER"
Reduce-Scatter distributes work, All-Gather collects results!
Chapter Summary
- Ring All-Reduce: Two phases - reduce-scatter + all-gather
- Bandwidth Optimal: Each peer sends ~2× its data, regardless of N
- Fault Tolerant: Backup buffer allows clean abort and retry
- WAN Trick: Multiple concurrent connections = more bandwidth
- Deterministic: All-gather phase ensures identical results
- Quantization Warning: Must quantize your own data too!