torch-pipeline-parallelism
Torch Pipeline Parallelism
Overview
This skill provides guidance for implementing pipeline parallelism in PyTorch for distributed model training. Pipeline parallelism partitions a model across multiple devices/ranks, where each rank processes a subset of layers and communicates activations/gradients with neighboring ranks.
Key Concepts
Pipeline Parallelism Patterns
- AFAB (All-Forward-All-Backward): Run the forward pass for every microbatch first, caching activations, then run all the backward passes. This is the simplest schedule to implement and a good starting point (see the sketch after this list).
- 1F1B (One-Forward-One-Backward): Interleave forward and backward passes for lower peak activation memory, at the cost of more complex scheduling.
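A minimal sketch of an AFAB training step, assuming the helper functions listed under Code Organization below (`forward_stage`, `recv_activation`, `send_activation`, `recv_gradient`, `send_gradient`) exist and that `recv_activation` marks received tensors with `requires_grad_(True)` as described in Step 4:

```python
import torch

def train_step_afab(microbatches, rank, world_size):
    # Sketch only: the helper functions are assumptions, not real library APIs.
    is_first, is_last = rank == 0, rank == world_size - 1
    input_cache, output_cache, losses = [], [], []

    # Phase 1: all forwards, caching inputs/outputs for the backward phase
    for mb in microbatches:
        x = mb if is_first else recv_activation(rank - 1)
        out = forward_stage(x)          # loss on the last rank, hidden states otherwise
        input_cache.append(x)
        output_cache.append(out)
        if is_last:
            losses.append(out.detach())
        else:
            send_activation(out, rank + 1)

    # Phase 2: all backwards
    for x, out in zip(input_cache, output_cache):
        grad = None if is_last else recv_gradient(rank + 1)
        torch.autograd.backward(out, grad_tensors=grad)
        if not is_first:
            send_gradient(x.grad, rank - 1)   # x.grad exists because x requires grad

    return losses
```

Gradient scaling across microbatches (e.g., dividing each loss by the number of microbatches) is omitted for brevity.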
Critical Components
- Model Partitioning: Divide model layers across ranks
- Activation Communication: Send/receive hidden states between ranks
- Gradient Communication: Send/receive gradients during backward pass
- Activation Caching: Store activations for backward pass computation
Implementation Approach
Step 1: Understand Model Architecture First
Before implementing, thoroughly understand the model being parallelized (a quick inspection sketch follows this list):
- Identify the layer structure (e.g., `model.model.layers` for LLaMA)
- Understand embedding layers (input embeddings, position embeddings)
- Identify the output head (e.g., `lm_head` for language models)
- Note any shared parameters or tied weights
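As a concrete example, a short inspection sketch for a Hugging Face LLaMA-style checkpoint; the attribute names below are those exposed by `transformers`' `LlamaForCausalLM`, the checkpoint path is a placeholder, and other architectures will differ:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path

embed_tokens = model.model.embed_tokens   # input token embeddings (first stage)
layers = model.model.layers               # decoder blocks to split across ranks
final_norm = model.model.norm             # final norm, applied on the last stage
lm_head = model.lm_head                   # output head, applied on the last stage

print(f"{len(layers)} decoder layers, hidden_size={model.config.hidden_size}")
# Tied-weights check: if True, the first and last stage share a parameter
print("tied embeddings:", lm_head.weight is embed_tokens.weight)
```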
Step 2: Plan Tensor Shape Handling
Critical distinction between ranks:
- Rank 0 (first stage): Receives integer token IDs with shape `[batch, seq_len]`
- Intermediate ranks: Receive hidden states with shape `[batch, seq_len, hidden_size]`
- Final rank: Must apply the output head and compute the loss
Create explicit shape handling logic for each case rather than assuming uniform input types.
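A hedged sketch of that per-rank logic, using a hypothetical `forward_stage` signature; the decoder layers are simplified to modules that map hidden states to hidden states (real layers also need attention masks and position embeddings, see the pitfalls below):

```python
import torch
import torch.nn.functional as F

def forward_stage(batch, rank, world_size, embed, layers, norm, lm_head, labels=None):
    # embed/norm/lm_head are None on ranks that do not own them
    if rank == 0:
        hidden = embed(batch)        # batch: int token IDs, [batch, seq_len]
    else:
        hidden = batch               # hidden states, [batch, seq_len, hidden_size]

    for layer in layers:
        hidden = layer(hidden)       # simplified; real layers take masks/positions

    if rank == world_size - 1:
        # Final stage: apply norm + output head and compute the loss
        logits = lm_head(norm(hidden))
        return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    return hidden
```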
Step 3: Design Communication Strategy
Use torch.distributed.P2POp for batched send/receive operations:
# Preferred: batched P2P operations (assumes dist is torch.distributed, already
# initialized, and recv_tensor/send_tensor are pre-allocated on the right device)
ops = []
if rank > 0:
    ops.append(dist.P2POp(dist.irecv, recv_tensor, rank - 1))
if rank < world_size - 1:
    ops.append(dist.P2POp(dist.isend, send_tensor, rank + 1))
if ops:  # batch_isend_irecv requires a non-empty op list
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
Avoid relying on bare `dist.send`/`dist.recv` for the main exchange: they block, are harder to overlap, and mismatched send/recv ordering across ranks can deadlock.
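The snippet above assumes `recv_tensor` already exists; one possible helper that pre-allocates the receive buffer and posts both operations in a single batch (the receive shape is assumed known and fixed here, see Step 5 for communicating shapes explicitly):

```python
import torch
import torch.distributed as dist

def exchange_stage_tensors(send_tensor, recv_shape, rank, world_size,
                           dtype=torch.float32, device="cuda"):
    # Sketch: allocate the receive buffer, then post both P2P ops in one batch
    ops, recv_tensor = [], None
    if rank > 0:
        recv_tensor = torch.empty(recv_shape, dtype=dtype, device=device)
        ops.append(dist.P2POp(dist.irecv, recv_tensor, rank - 1))
    if rank < world_size - 1 and send_tensor is not None:
        ops.append(dist.P2POp(dist.isend, send_tensor.contiguous(), rank + 1))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return recv_tensor
```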
Step 4: Handle Gradient Flow Correctly
Critical pitfall: using `tensor.detach().requires_grad_(True)` severs the computational graph.
Correct approach for maintaining gradient flow:
# Cache the inputs to this stage so their gradients can be read later
input_cache = []

# During forward: mark the received activation as requiring grad and cache
# the same tensor object (do not detach or copy it)
stage_input = received_tensor.requires_grad_(True)
input_cache.append(stage_input)

# During backward: run backward through this stage's output; afterwards
# stage_input.grad holds the gradient to send to the previous rank
Verify gradient connectivity with small test cases before full implementation.
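A minimal sketch of the matching backward step, assuming the cached tensors from the snippet above and a `backward_stage` helper shaped like the one listed under Code Organization:

```python
import torch

def backward_stage(stage_input, stage_output, output_grad, is_first, is_last):
    # stage_input: the cached, non-detached input marked requires_grad_(True)
    # stage_output: the cached output of this stage (the loss on the last rank)
    if is_last:
        stage_output.backward()
    else:
        # output_grad is the gradient received from the next rank
        torch.autograd.backward(stage_output, grad_tensors=output_grad)

    # The gradient w.r.t. this stage's input is now available to send upstream
    # (rank 0 consumes integer token IDs, which have no gradient)
    return None if is_first else stage_input.grad
```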
Step 5: Implement Shape Communication
Communicate tensor shapes before data when shapes vary:
# Efficient: Single tensor for shape
shape_tensor = torch.tensor(list(tensor.shape), dtype=torch.long, device=device)
dist.send(shape_tensor, dst_rank)
# Then send the actual data
dist.send(tensor.contiguous(), dst_rank)
Avoid sending each dimension as a separate tensor.
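The receiving side is a mirror image; a sketch assuming both ranks agree on the number of dimensions (e.g., 3 for `[batch, seq_len, hidden_size]`) and on the dtype:

```python
import torch
import torch.distributed as dist

def recv_with_shape(src_rank, ndim, dtype, device):
    # Receive the shape first, then allocate a buffer of exactly that size
    shape_tensor = torch.empty(ndim, dtype=torch.long, device=device)
    dist.recv(shape_tensor, src_rank)
    data = torch.empty(tuple(shape_tensor.tolist()), dtype=dtype, device=device)
    dist.recv(data, src_rank)
    return data
```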
Verification Strategies
1. Gradient Flow Verification
Create a minimal test to verify gradients flow correctly:
def test_gradient_flow(model):
    # Build a small model partition, run one forward/backward pass, then
    # check that every parameter received a non-trivial gradient
    for name, param in model.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert not torch.all(param.grad == 0), f"Zero gradient for {name}"
2. Activation Shape Verification
Log shapes at each stage boundary:
# Before send
print(f"Rank {rank} sending shape: {tensor.shape}")
# After receive
print(f"Rank {rank} received shape: {tensor.shape}")
3. Single-Rank Testing
Test with world_size=1 to ensure the implementation handles the degenerate case (see the sketch after this list):
- No communication should occur
- Model should function as standard single-device training
- All gradients should flow correctly
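One way to make the degenerate case fall out naturally is to guard the communication helpers so they no-op when a neighbor does not exist; a sketch using blocking calls for brevity (the batched P2P ops from Step 3 can be swapped in):

```python
import torch.distributed as dist

def recv_activation(local_batch, recv_buffer, rank):
    # Rank 0 (the only rank when world_size == 1) has no previous stage,
    # so it uses the local batch and never touches dist
    if rank == 0:
        return local_batch
    dist.recv(recv_buffer, rank - 1)
    return recv_buffer.requires_grad_(True)   # see Step 4

def send_activation(tensor, rank, world_size):
    # The last rank (also rank 0 when world_size == 1) has no next stage
    if rank < world_size - 1:
        dist.send(tensor.contiguous(), rank + 1)
```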
4. End-to-End Loss Comparison
Compare loss values between:
- Pipeline parallel implementation
- Standard single-device training (ground truth)
Values should match within numerical precision.
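A small comparison helper, assuming both runs see identical data, initialization, and dtype; the default tolerances suit float32 and should be loosened for bf16/fp16:

```python
import torch

def compare_losses(pipeline_losses, baseline_losses, rtol=1e-5, atol=1e-6):
    # Compare per-step loss values from the two runs
    for step, (p, b) in enumerate(zip(pipeline_losses, baseline_losses)):
        torch.testing.assert_close(
            torch.as_tensor(p), torch.as_tensor(b), rtol=rtol, atol=atol,
            msg=f"Loss mismatch at step {step}: pipeline={p}, baseline={b}")
```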
Common Pitfalls
1. Truncated Code Edits
When making large code changes:
- Prefer smaller, targeted edits over large rewrites
- Verify edit completeness by reading the file after each edit
- Use full file writes for major structural changes
2. Detach Breaking Gradient Flow
# WRONG: detach() creates a new leaf cut off from the computational graph,
# so backward through the cached copy never reaches the ops that produced it
cached = tensor.detach().requires_grad_(True)

# RIGHT: keep a reference to the original tensor so the graph stays intact
cached = tensor
# For activations received from another rank (which carry no local graph),
# mark them as requiring grad in place instead: tensor.requires_grad_(True)
3. Missing lm_head in Partitions
The output head (`lm_head`) is often separate from the layer list (a partitioning sketch follows this list). Ensure:
- It's included in the final rank's computation
- Its parameters receive gradients
- It's not duplicated across ranks
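A sketch of attaching these extra modules to the last rank, assuming the LLaMA-style attribute names from Step 1 and a hypothetical `build_final_stage` helper:

```python
import torch.nn as nn

def build_final_stage(my_layers, model, rank, world_size):
    # The final norm and lm_head live outside model.model.layers, so attach
    # them only on the last rank; every other rank returns lm_head=None
    modules = nn.ModuleList(my_layers)
    if rank == world_size - 1:
        modules.append(model.model.norm)
        return modules, model.lm_head
    return modules, None
```

After a backward pass on the last rank, asserting that every `lm_head` parameter has a non-None `.grad` catches a missing or disconnected head early.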
4. Position Embeddings Handling
Position embeddings (especially rotary embeddings) require care:
- They may need explicit computation before the first layer
- The API varies between model implementations
- Test with the specific model architecture being used
5. Empty Partitions
When world_size > num_layers, some ranks may have no layers (see the partitioning sketch after this list):
- Add explicit handling for empty partitions
- These ranks still need to forward activations
- Avoid division by zero in layer assignment
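One partitioning scheme that tolerates empty partitions, sketched as a standalone helper (the function name is hypothetical):

```python
def partition_layers(num_layers, world_size, rank):
    # Spread layers as evenly as possible: the first (num_layers % world_size)
    # ranks get one extra layer; ranks past num_layers get an empty range
    base, extra = divmod(num_layers, world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return list(range(start, end))
```

For example, with 4 layers and 6 ranks, ranks 0-3 each own one layer while ranks 4 and 5 get an empty list and should simply pass activations through unchanged.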
6. Variable Sequence Lengths
If microbatches have different sequence lengths:
- Communicate shapes before data
- Don't cache and reuse shapes across microbatches
- Consider padding strategies for efficiency
Code Organization
Structure the implementation clearly:
pipeline_parallel.py
├── partition_model() # Divide layers across ranks
├── get_stage_layers() # Get this rank's layer subset
├── forward_stage() # Single stage forward pass
├── backward_stage() # Single stage backward pass
├── send_activation() # Send tensor to next rank
├── recv_activation() # Receive tensor from prev rank
├── send_gradient() # Send gradient to prev rank
├── recv_gradient() # Receive gradient from next rank
└── train_step_pipeline() # Main training step orchestrator
Testing Checklist
Before considering implementation complete:
- Gradient flow verified for all model parameters
- Shapes correct at each stage boundary
- world_size=1 case works correctly
- Loss matches the non-parallel baseline
- No communication deadlocks
- Memory usage scales appropriately with world_size
- Position embeddings handled correctly for the specific model
- Output head (lm_head) included and receives gradients