torch-tensor-parallelism
Tensor Parallelism Implementation Guide
This skill provides guidance for implementing tensor parallelism patterns in PyTorch, specifically for ColumnParallelLinear and RowParallelLinear layers that distribute computation across multiple devices.
Core Concepts
Tensor Parallelism Overview
Tensor parallelism splits individual layers across multiple devices to parallelize computation within a single forward/backward pass. The two primary patterns are:
- ColumnParallelLinear: Shards weights along the output dimension (columns). Each device computes a portion of the output features, then results are concatenated via all-gather.
- RowParallelLinear: Shards weights along the input dimension (rows). Each device computes partial outputs using its shard of the input, then results are summed via all-reduce.
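Both patterns follow from the same block-matrix identity for a linear layer y = x W^T (PyTorch convention: W has shape (out_features, in_features)). A sketch for world_size=2:

```latex
% ColumnParallelLinear: W split row-wise (output features) into W_1, W_2
x W^\top
  = x \begin{bmatrix} W_1 \\ W_2 \end{bmatrix}^{\top}
  = \begin{bmatrix} x W_1^\top & x W_2^\top \end{bmatrix}
  \quad \text{(concatenate shards: all-gather)}

% RowParallelLinear: W split column-wise (input features) into A, B; x split to match
x W^\top
  = \begin{bmatrix} x_1 & x_2 \end{bmatrix} \begin{bmatrix} A & B \end{bmatrix}^{\top}
  = x_1 A^\top + x_2 B^\top
  \quad \text{(sum partial results: all-reduce)}
```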
Critical Implementation Requirement
When implementing tensor parallelism (especially in simulation or testing contexts), the forward pass must actually perform the collective operations, not just compute local shards:
- ColumnParallelLinear: Must concatenate outputs from all ranks (all-gather semantics)
- RowParallelLinear: Must sum outputs from all ranks (all-reduce semantics)
A common mistake is returning only the local shard and expecting an external framework to handle collective operations. Unless explicitly specified otherwise, the implementation should produce the final, complete output.
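As a minimal single-process illustration (the helper names and the list-of-shards setup are hypothetical, not a real API), the difference looks like this:

```python
import torch

def forward_local_only(x, weight_shards, rank):
    # Incomplete: returns (batch, out_features // world_size) and silently
    # relies on some external framework to gather the shards.
    return x @ weight_shards[rank].T

def forward_with_gather(x, weight_shards):
    # Complete: performs the all-gather semantics itself and returns
    # the full (batch, out_features) output.
    return torch.cat([x @ w.T for w in weight_shards], dim=-1)
```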
Implementation Approach
Step 1: Understand the Parallelism Pattern
Before implementing, clearly identify:
- Which dimension is being sharded (input features vs output features)
- What collective operation combines the results (all-gather vs all-reduce)
- Whether the implementation should simulate distributed execution or prepare for actual distributed execution
- How bias should be handled in the parallel context
Step 2: Weight Sharding
For weight matrix W of shape (out_features, in_features):
ColumnParallelLinear:
- Shard W along dim=0 (output features)
- Each rank gets W_shard of shape (out_features // world_size, in_features)
- Output shape per rank: (batch, out_features // world_size)
RowParallelLinear:
- Shard W along dim=1 (input features)
- Each rank gets W_shard of shape (out_features, in_features // world_size)
- Input to each rank is the corresponding shard of the input (split along the feature dimension)
- Output shape per rank: (batch, out_features) - partial sum
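A minimal sharding sketch using torch.chunk, assuming both dimensions divide evenly by world_size (the sizes below are arbitrary):

```python
import torch

out_features, in_features, world_size = 8, 4, 2
W = torch.randn(out_features, in_features)

col_shards = torch.chunk(W, world_size, dim=0)   # ColumnParallel: split output features
row_shards = torch.chunk(W, world_size, dim=1)   # RowParallel: split input features

assert col_shards[0].shape == (out_features // world_size, in_features)
assert row_shards[0].shape == (out_features, in_features // world_size)
```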
Step 3: Forward Pass Implementation
ColumnParallelLinear Forward:
1. Compute local output: y_local = x @ W_shard.T + bias_shard (if bias per shard)
2. All-gather to concatenate: y = concat([y_0, y_1, ..., y_n], dim=-1)
3. Return complete output of shape (batch, out_features)
RowParallelLinear Forward:
1. Get input shard: x_shard = x[..., start:end] for this rank
2. Compute partial output: y_partial = x_shard @ W_shard.T
3. All-reduce to sum: y = sum([y_0, y_1, ..., y_n])
4. Add bias (only once, not per-rank): y = y + bias
5. Return complete output of shape (batch, out_features)
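Putting Step 3 together in the same single-process simulation style (function names are illustrative; bias is omitted here and covered in Step 4):

```python
import torch

def column_parallel_forward(x, weight_shards):
    # weight_shards[r]: (out_features // world_size, in_features)
    local_outputs = [x @ w.T for w in weight_shards]        # each (batch, out // world_size)
    return torch.cat(local_outputs, dim=-1)                 # all-gather semantics

def row_parallel_forward(x, weight_shards):
    # weight_shards[r]: (out_features, in_features // world_size)
    x_shards = torch.chunk(x, len(weight_shards), dim=-1)   # matching input slices
    partials = [xs @ w.T for xs, w in zip(x_shards, weight_shards)]
    return torch.stack(partials).sum(dim=0)                 # all-reduce semantics
```

In an actual multi-process setup, the concatenation and the sum correspond to torch.distributed.all_gather and torch.distributed.all_reduce, respectively.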
Step 4: Bias Handling
ColumnParallelLinear:
- Bias can be sharded along with output features
- Each rank adds its bias shard to its output shard
- After all-gather, the full bias has been applied
RowParallelLinear:
- Bias must NOT be sharded or added per-rank (would cause N-fold bias)
- Add bias only once after the all-reduce operation
- Typically, either only rank 0 adds the bias, or the bias is added once after the sum
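Continuing the same illustrative sketch, bias handling for both patterns might look like this (names are hypothetical):

```python
import torch

def column_parallel_forward_bias(x, weight_shards, bias):
    # Shard the bias the same way as the output features; each rank adds its
    # own slice, so after the concat the full bias has been applied exactly once.
    bias_shards = torch.chunk(bias, len(weight_shards), dim=0)
    return torch.cat([x @ w.T + bs
                      for w, bs in zip(weight_shards, bias_shards)], dim=-1)

def row_parallel_forward_bias(x, weight_shards, bias):
    # Do NOT add the bias inside the per-rank partial products; add it exactly
    # once after the sum (the all-reduce), or it gets applied world_size times.
    x_shards = torch.chunk(x, len(weight_shards), dim=-1)
    partial = sum(xs @ w.T for xs, w in zip(x_shards, weight_shards))
    return partial + bias
```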
Verification Strategies
Mathematical Verification
When local testing is unavailable, verify implementation correctness through mathematical analysis:
- Simple example: Use a 2x4 weight matrix with world_size=2
- Trace computation: Manually compute what each rank produces
- Verify combination: Confirm all-gather/all-reduce produces correct final output
- Compare to baseline: Verify parallel output matches non-parallel computation
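For example, a hand-traceable check with a 2x4 weight and world_size=2, using values chosen so every intermediate can be verified mentally:

```python
import torch

x = torch.tensor([[1., 2., 3., 4.]])                 # (1, 4)
W = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 1.]])                 # (2, 4)
b = torch.tensor([10., 20.])

baseline = x @ W.T + b                               # [[14., 26.]]

# Column parallel: rank 0 holds row 0 of W, rank 1 holds row 1.
# rank 0: [1 + 3] + 10 = [14]; rank 1: [2 + 4] + 20 = [26]; concat -> [14, 26]
col = torch.cat([x @ w.T + bs
                 for w, bs in zip(torch.chunk(W, 2, dim=0),
                                  torch.chunk(b, 2))], dim=-1)

# Row parallel: each rank holds two input columns of W and the matching input slice.
# rank 0 partial: [1, 2]; rank 1 partial: [3, 4]; sum -> [4, 6]; + bias once -> [14, 26]
row = sum(xs @ ws.T for xs, ws in zip(torch.chunk(x, 2, dim=-1),
                                      torch.chunk(W, 2, dim=1))) + b

assert torch.allclose(col, baseline) and torch.allclose(row, baseline)
```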
Shape Verification Checklist
- Input shape matches expected (batch, in_features)
- Weight shard shape matches expected partitioning
- Local output shape is correct for the parallelism type
- Final output shape matches (batch, out_features) - NOT the sharded dimension
Test Cases to Consider
- world_size=1: Trivial case, should match non-parallel implementation exactly
- world_size=2,4,8: Common parallel configurations
- Non-divisible dimensions: What happens when out_features % world_size != 0?
- Different batch sizes: Verify batch dimension is handled correctly
- With and without bias: Test both configurations
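A self-contained check along these lines might look as follows (dimensions chosen to divide evenly by every world size tested; a zero bias stands in for the "without bias" case):

```python
import torch

def check_against_baseline(world_size, batch=3, in_features=8, out_features=8,
                           use_bias=True):
    # Compare simulated parallel forwards against a plain linear computation.
    torch.manual_seed(0)
    x = torch.randn(batch, in_features)
    W = torch.randn(out_features, in_features)
    b = torch.randn(out_features) if use_bias else torch.zeros(out_features)
    baseline = x @ W.T + b

    # Column parallel: concat per-rank outputs, bias sharded per rank.
    col = torch.cat([x @ w.T + bs
                     for w, bs in zip(torch.chunk(W, world_size, dim=0),
                                      torch.chunk(b, world_size))], dim=-1)

    # Row parallel: sum per-rank partials, bias added once after the sum.
    row = sum(xs @ ws.T for xs, ws in zip(torch.chunk(x, world_size, dim=-1),
                                          torch.chunk(W, world_size, dim=1))) + b

    assert torch.allclose(col, baseline, atol=1e-5)
    assert torch.allclose(row, baseline, atol=1e-5)

for ws in (1, 2, 4):
    for bias in (True, False):
        check_against_baseline(ws, use_bias=bias)
```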
Common Pitfalls
Pitfall 1: Returning Local Shards Only
Symptom: Output tensor size is (out_features // world_size) instead of (out_features)
Cause: Implementation computes local shard but doesn't perform all-gather
Fix: Implement the collective operation to combine results from all ranks
Pitfall 2: Incorrect Bias Handling in RowParallelLinear
Symptom: Output is offset by an extra (world_size - 1) * bias; the bias contribution is world_size times larger than expected
Cause: Each rank adds the full bias to its partial output before the partials are summed
Fix: Add bias only once after all-reduce, not per-rank
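A small sketch that quantifies the error (illustrative dimensions):

```python
import torch

torch.manual_seed(0)
world_size, batch, in_f, out_f = 4, 2, 8, 3
x = torch.randn(batch, in_f)
W = torch.randn(out_f, in_f)
b = torch.randn(out_f)

x_sh = torch.chunk(x, world_size, dim=-1)
w_sh = torch.chunk(W, world_size, dim=1)

wrong = sum(xs @ ws.T + b for xs, ws in zip(x_sh, w_sh))   # bias added per rank
right = sum(xs @ ws.T for xs, ws in zip(x_sh, w_sh)) + b   # bias added once

# The wrong version is off by exactly (world_size - 1) copies of the bias.
assert torch.allclose(wrong - right,
                      (world_size - 1) * b.expand(batch, out_f), atol=1e-5)
```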
Pitfall 3: Misinterpreting "Simulation" Requirements
Symptom: Implementation works for world_size=1 but fails for larger world sizes
Cause: Assuming external framework handles collective operations
Fix: Read requirements carefully - "as if using all_gather" means implement the operation
Pitfall 4: Truncated File Writes
Symptom: Implementation has syntax errors or missing code
Cause: File write operation was truncated
Fix: Always read back the complete file after writing to verify integrity
Pitfall 5: Wrong Dimension for Sharding
Symptom: Shape mismatch errors during matrix multiplication
Cause: Sharding along wrong dimension (rows vs columns confusion)
Fix: ColumnParallel shards output features (dim=0 of weight), RowParallel shards input features (dim=1 of weight)
Pre-Implementation Checklist
Before writing code, confirm understanding of:
- Which collective operation is needed (all-gather vs all-reduce)
- What the final output shape should be
- Whether simulation should actually perform collective ops or defer them
- How bias should be handled for this parallelism type
- What happens for edge cases (world_size=1, non-divisible dimensions)
Post-Implementation Checklist
After writing code:
- Read back the complete implementation file to verify no truncation
- Verify output shapes match expected dimensions for all world sizes
- Trace through a simple example manually to verify correctness
- Test trivial case (world_size=1) matches non-parallel baseline
- Test at least one non-trivial case (world_size=2 or 4)