Chapter 4: Ring All-Reduce

The Algorithm That Powers Distributed Training

🍕 The Pizza Party Analogy

4 friends each made a different pizza. Goal: everyone tastes ALL pizzas.

Naive approach: Everyone sends their whole pizza to everyone else.
= 4 × 3 = 12 pizza transfers. Expensive!

Ring approach: Sit in a circle. Pass ONE SLICE to your right neighbor.
After 3 rounds, everyone has tasted all pizzas!
= 4 × 3 = 12 slice transfers, but slices are SMALLER than whole pizzas!

Why Ring All-Reduce?

Q: Why not just send everything to one node, combine, and broadcast back?

A: That's called "parameter server" - the central node becomes a bottleneck. With 100 GPUs sending 1GB each, that's 100GB hitting one poor server!

Q: Why a ring specifically?

A: Ring is bandwidth-optimal! Each node sends exactly 2×(N-1)/N times its own data size (just under 2×), no matter how large the cluster gets. Trees and stars concentrate traffic on a few nodes and don't scale as well.

The Two Phases

💡 Ring All-Reduce = Reduce-Scatter + All-Gather

Phase 1 (Reduce-Scatter): Combine data, each peer ends up "owning" one chunk
Phase 2 (All-Gather): Share owned chunks so everyone has the full result

Phase 1: Reduce-Scatter

Setup: 4 peers, each with data split into 4 chunks [A, B, C, D]

Initial State:
  Peer 0: [A₀, B₀, C₀, D₀]
  Peer 1: [A₁, B₁, C₁, D₁]
  Peer 2: [A₂, B₂, C₂, D₂]
  Peer 3: [A₃, B₃, C₃, D₃]

Ring: 0 → 1 → 2 → 3 → 0

Step 1: Each peer sends chunk[(rank - 1) mod 4] to next, receives from prev
────────────────────────────────────────────────────────────
  Peer 0: sends D₀ to Peer 1, receives C₃ from Peer 3 → C₀ + C₃ = C₀₃
  Peer 1: sends A₁ to Peer 2, receives D₀ from Peer 0 → D₁ + D₀ = D₀₁
  Peer 2: sends B₂ to Peer 3, receives A₁ from Peer 1 → A₂ + A₁ = A₁₂
  Peer 3: sends C₃ to Peer 0, receives B₂ from Peer 2 → B₃ + B₂ = B₂₃

Step 2: Continue with accumulated chunks
────────────────────────────────────────────────────────────
  Peer 0: sends C₀₃ to Peer 1, receives B₂₃ from Peer 3 → B₀ + B₂₃ = B₀₂₃
  Peer 1: sends D₀₁ to Peer 2, receives C₀₃ from Peer 0 → C₁ + C₀₃ = C₀₁₃
  ...and so on

After N-1 steps:
  Peer 0 owns: A₀₁₂₃ (fully reduced!)
  Peer 1 owns: B₀₁₂₃ (fully reduced!)
  Peer 2 owns: C₀₁₂₃ (fully reduced!)
  Peer 3 owns: D₀₁₂₃ (fully reduced!)
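The walkthrough above can be checked in a few lines of Rust. This is an illustrative single-process simulation, not PCCL's actual networking code: each "peer" is just a Vec, and one loop plays all four roles per step.

```rust
// Sketch: simulate reduce-scatter for N peers, each holding N chunks
// (chunk size 1 for simplicity). Peer p starts with value (p+1) in every
// chunk; after N-1 steps, peer p's chunk p holds the sum 1 + 2 + ... + N.
fn reduce_scatter_sim(n: usize) -> Vec<Vec<f32>> {
    // buffers[p][c] = peer p's copy of chunk c
    let mut buffers: Vec<Vec<f32>> = (0..n).map(|p| vec![(p + 1) as f32; n]).collect();
    for step in 0..n - 1 {
        // Each peer sends chunk (rank - 1 - step) mod n this step;
        // snapshot all sends first to model a simultaneous exchange.
        let sends: Vec<(usize, f32)> = (0..n)
            .map(|rank| {
                let c = (rank + n - 1 - step) % n;
                (c, buffers[rank][c])
            })
            .collect();
        // Each peer receives from its predecessor and accumulates.
        for rank in 0..n {
            let prev = (rank + n - 1) % n;
            let (c, val) = sends[prev];
            buffers[rank][c] += val;
        }
    }
    buffers
}

fn main() {
    let buffers = reduce_scatter_sim(4);
    // Each peer p now owns a fully reduced chunk p: 1 + 2 + 3 + 4 = 10
    for (p, buf) in buffers.iter().enumerate() {
        assert_eq!(buf[p], 10.0);
    }
    println!("each peer owns one fully reduced chunk");
}
```

Running the simulation confirms the ownership pattern from the diagram: peer 0 ends up with chunk A fully reduced, peer 1 with chunk B, and so on.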

Phase 2: All-Gather

After Reduce-Scatter:
  Peer 0: [A*, ?, ?, ?]
  Peer 1: [?, B*, ?, ?]
  Peer 2: [?, ?, C*, ?]
  Peer 3: [?, ?, ?, D*]
  (* = fully reduced)

Step 1: Each peer sends their owned chunk
────────────────────────────────────────────────────────────
  Peer 0: sends A* to Peer 1, receives D* from Peer 3
  Peer 1: sends B* to Peer 2, receives A* from Peer 0
  Peer 2: sends C* to Peer 3, receives B* from Peer 1
  Peer 3: sends D* to Peer 0, receives C* from Peer 2

After step 1:
  Peer 0: [A*, ?, ?, D*]
  Peer 1: [A*, B*, ?, ?]
  Peer 2: [?, B*, C*, ?]
  Peer 3: [?, ?, C*, D*]

...continue for N-1 steps...

Final Result (all peers identical):
  Peer 0: [A*, B*, C*, D*]
  Peer 1: [A*, B*, C*, D*]
  Peer 2: [A*, B*, C*, D*]
  Peer 3: [A*, B*, C*, D*]
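The all-gather rotation can be simulated the same way (again an illustrative Rust sketch, not PCCL's implementation), starting from the post-reduce-scatter state where each peer knows only its owned chunk:

```rust
// Sketch: simulate all-gather for N peers. Peer p starts knowing only
// chunk p (value 100 + p); after N-1 steps every peer holds every chunk.
fn all_gather_sim(n: usize) -> Vec<Vec<f32>> {
    // f32::NAN marks "chunk not yet received"
    let mut buffers: Vec<Vec<f32>> = (0..n)
        .map(|p| {
            let mut b = vec![f32::NAN; n];
            b[p] = 100.0 + p as f32; // the chunk peer p owns after reduce-scatter
            b
        })
        .collect();
    for step in 0..n - 1 {
        // At each step, peer rank forwards chunk (rank - step) mod n
        let sends: Vec<(usize, f32)> = (0..n)
            .map(|rank| {
                let c = (rank + n - step) % n;
                (c, buffers[rank][c])
            })
            .collect();
        for rank in 0..n {
            let prev = (rank + n - 1) % n;
            let (c, val) = sends[prev];
            buffers[rank][c] = val; // copy, never accumulate
        }
    }
    buffers
}

fn main() {
    let buffers = all_gather_sim(4);
    // Every peer ends with the identical full result
    for buf in &buffers {
        assert_eq!(buf, &vec![100.0, 101.0, 102.0, 103.0]);
    }
    println!("all peers identical");
}
```

Note the inner loop only copies bytes; that pure copying is exactly why the all-gather phase is trivially deterministic, as discussed later in this chapter.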

The Math

Bandwidth Analysis

For N peers, each with data of size S:

  Reduce-Scatter: each peer sends N-1 chunks of size S/N → (N-1)×S/N
  All-Gather:     each peer sends N-1 chunks of size S/N → (N-1)×S/N
  ──────────────────────────────────────────────────────────────────
  Total sent per peer: 2×(N-1)×S/N ≈ 2×S

Each peer sends just under 2× its data, regardless of world size. That's optimal!

✏️ Calculate It!

You have 8 peers, each with a 1GB tensor. How much data does each peer SEND in total?

Hint: Use the formula 2×(N-1)×S/N

Answer: 2 × (8-1) × 1GB / 8 = 2 × 7 × 0.125GB = 1.75GB
Each peer sends 1.75GB to complete the all-reduce!
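The formula is easy to sanity-check in Rust (a hypothetical helper, not part of PCCL's API):

```rust
// Bytes each peer sends in a ring all-reduce: 2 * (N-1) * S / N
fn bytes_sent_per_peer(world_size: u64, tensor_bytes: u64) -> u64 {
    2 * (world_size - 1) * tensor_bytes / world_size
}

fn main() {
    let gb = 1_000_000_000u64; // treating 1 GB as 10^9 bytes

    // The exercise above: 8 peers, 1 GB each -> 1.75 GB sent per peer
    assert_eq!(bytes_sent_per_peer(8, gb), 1_750_000_000);

    // The cost barely grows with world size: it only approaches 2 * S
    assert_eq!(bytes_sent_per_peer(100, gb), 1_980_000_000);
    println!("ok");
}
```

Scaling from 8 to 100 peers raises the per-peer send volume from 1.75 GB to just 1.98 GB, which is the "regardless of world size" property in action.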

PCCL's Implementation

// Simplified ring all-reduce pseudocode
fn all_reduce(buffer: &mut [f32], world_size: usize, rank: usize) -> Result<(), Aborted> {
    let chunk_size = buffer.len() / world_size;
    let backup = buffer.to_vec();  // For fault tolerance!

    // Phase 1: Reduce-Scatter
    for step in 0..(world_size - 1) {
        // Add world_size before subtracting so the index math can't underflow
        let send_chunk = (rank + world_size - 1 - step) % world_size;
        let recv_chunk = (rank + world_size - 2 - step) % world_size;

        // Send my chunk to next peer
        send_to_next(&buffer[send_chunk * chunk_size..(send_chunk + 1) * chunk_size]);

        // Receive chunk from previous peer
        let received = recv_from_prev();

        // Accumulate into my buffer
        for i in 0..chunk_size {
            buffer[recv_chunk * chunk_size + i] += received[i];
        }

        // Check for abort signal (fault tolerance!)
        if master.has_abort() {
            buffer.copy_from_slice(&backup);
            return Err(Aborted);
        }
    }

    // Phase 2: All-Gather (similar structure, no accumulation)
    // ...
    Ok(())
}

⚠️ The Backup Buffer

Notice we clone the buffer BEFORE starting? That's crucial for fault tolerance!

If any peer fails mid-operation, we restore the backup and retry. Without this, you'd have corrupted half-reduced data.

Multiple Connections: The WAN Trick

TCP's dirty secret: a single TCP connection rarely saturates a WAN link. With high latency, the receive window can't cover the bandwidth-delay product, so the sender spends much of its time idle, waiting for ACKs.
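The window limit is simple arithmetic: one stream can have at most one receive window of data in flight per round trip, so throughput ≤ window / RTT. A quick sketch with illustrative numbers (the window size and RTT here are assumptions, not measurements):

```rust
// Max throughput of a single TCP stream, in Gbit/s:
// at most one receive window of data in flight per round trip.
fn max_throughput_gbit(window_bytes: f64, rtt_ms: f64) -> f64 {
    let bits_per_sec = window_bytes * 8.0 / (rtt_ms / 1000.0);
    bits_per_sec / 1e9
}

fn main() {
    // Assumed: a 4 MiB receive window over an 80 ms WAN round trip
    let single = max_throughput_gbit(4.0 * 1024.0 * 1024.0, 80.0);

    // One stream stays well under 1 Gbit/s, far below a fat WAN pipe...
    assert!(single < 1.0);

    // ...but 128 such streams in flight multiply the usable bandwidth
    assert!(128.0 * single > 40.0);

    println!("single stream: {:.2} Gbit/s", single);
}
```

With those numbers a single stream tops out around 0.42 Gbit/s, which is why opening many connections pays off so dramatically.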

PCCL's trick: Open MULTIPLE connections and run concurrent all-reduces!

Single Connection:
┌─────────────────────────────────────────┐
│  Tensor 1  │  Tensor 2  │  Tensor 3     │
│  ────────► │  ────────► │  ────────►    │
│                                         │
│  Sequential, limited by TCP window      │
│  Bandwidth: ~3 Gbit/s                   │
└─────────────────────────────────────────┘

128 Concurrent Connections:
┌─────────────────────────────────────────┐
│  Conn 1:   Tensor 1   ────────►         │
│  Conn 2:   Tensor 2   ────────►         │
│  Conn 3:   Tensor 3   ────────►         │
│  ...                                    │
│  Conn 128: Tensor 128 ────────►         │
│                                         │
│  Parallel! Bandwidth: ~45 Gbit/s        │
└─────────────────────────────────────────┘
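The fan-out pattern above can be sketched with threads standing in for TCP connections (purely illustrative; `send_concurrently` and `split_round_robin` are hypothetical names, and real PCCL runs concurrent all-reduces over real sockets):

```rust
use std::sync::mpsc;
use std::thread;

// Distribute tensors over K worker threads round-robin.
fn split_round_robin(tensors: Vec<Vec<f32>>, k: usize) -> Vec<Vec<Vec<f32>>> {
    let mut out = vec![Vec::new(); k];
    for (i, t) in tensors.into_iter().enumerate() {
        out[i % k].push(t);
    }
    out
}

// Sketch: "send" N tensors over K concurrent connections.
// Returns the total number of floats sent, summed across workers.
fn send_concurrently(tensors: Vec<Vec<f32>>, connections: usize) -> usize {
    let (tx, rx) = mpsc::channel();
    let mut handles = Vec::new();
    for batch in split_round_robin(tensors, connections) {
        let tx = tx.clone();
        handles.push(thread::spawn(move || {
            for t in batch {
                // Stand-in for "transmit tensor over this connection"
                tx.send(t.len()).unwrap();
            }
        }));
    }
    drop(tx); // close the original sender so rx.iter() terminates
    let total: usize = rx.iter().sum();
    for h in handles {
        h.join().unwrap();
    }
    total
}

fn main() {
    // 10 tensors of increasing size, spread over 4 "connections"
    let tensors: Vec<Vec<f32>> = (0..10usize).map(|i| vec![0.0; 100 * (i + 1)]).collect();
    let total = send_concurrently(tensors, 4);
    assert_eq!(total, 5500); // 100 + 200 + ... + 1000 floats, all delivered
    println!("sent {} floats over 4 connections", total);
}
```

The point of the sketch is the shape, not the numbers: no tensor waits behind an unrelated tensor's window stall, because each stream progresses independently.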

Bit-Wise Determinism

💡 Why Ring All-Reduce is Naturally Deterministic

In the all-gather phase, each peer broadcasts its OWNED chunk. No computation, just copying.

Since everyone receives the SAME bytes, everyone ends up with IDENTICAL results!

⚠️ Quantization Breaks This!

If you quantize (compress) data for transmission:

// Problem: D(Q(x)) ≠ x
// Dequantized data isn't exactly the original!

// WRONG: Use local high-precision data
accumulate(my_chunk, received_chunk)  // my_chunk has extra precision!

// RIGHT: Quantize your own data too
accumulate(dequant(quant(my_chunk)), received_chunk)

Otherwise you get "lingering precision" that other peers don't have!
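The effect can be demonstrated with a toy symmetric int8 quantizer (an illustrative simplification, not PCCL's actual scheme): two peers combine the same pair of chunks with roles swapped, and only when each peer round-trips its own chunk through the quantizer do they agree bit-for-bit.

```rust
// Toy symmetric quantizer: f32 -> i8 with a fixed scale, and back.
// Note D(Q(x)) != x in general: that's the whole problem.
fn quant(x: f32, scale: f32) -> i8 {
    (x / scale).round().clamp(-127.0, 127.0) as i8
}
fn dequant(q: i8, scale: f32) -> f32 {
    q as f32 * scale
}

fn main() {
    let scale = 0.1;
    let mine = 0.123_f32;   // my full-precision chunk value
    let theirs = 0.456_f32; // peer's value, which arrives quantized
    let received = dequant(quant(theirs, scale), scale);

    // WRONG: mix my full-precision value with their quantized one
    let wrong = mine + received;
    // The other peer, roles swapped, computes:
    let their_wrong = theirs + dequant(quant(mine, scale), scale);
    assert_ne!(wrong.to_bits(), their_wrong.to_bits()); // peers disagree!

    // RIGHT: round-trip my own value through the same quantizer first
    let right = dequant(quant(mine, scale), scale) + received;
    let their_right =
        dequant(quant(theirs, scale), scale) + dequant(quant(mine, scale), scale);
    assert_eq!(right.to_bits(), their_right.to_bits()); // bit-identical

    println!("deterministic result: {}", right);
}
```

The "wrong" path keeps local precision (0.123 instead of the quantized 0.1) that no other peer can reconstruct; the "right" path sacrifices that precision so every peer computes from identical inputs.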

Code Magnet Exercise

Arrange these to implement reduce-scatter step:

recv_chunk = recv_from_prev()
send_to_next(my_chunk)
buffer[idx] += recv_chunk[i]
check_abort_signal()

Correct order: ___ → ___ → ___ → ___

Answer: send_to_next → recv_from_prev → accumulate → check_abort
(Send and recv can be concurrent for full-duplex!)

"SCATTER then GATHER = Everyone gets the PLATTER"

Reduce-Scatter distributes work, All-Gather collects results!

Chapter Summary