Tensor Parallelism Implementation Guide

This skill provides guidance for implementing tensor parallelism patterns in PyTorch, specifically for ColumnParallelLinear and RowParallelLinear layers that distribute computation across multiple devices.

Core Concepts

Tensor Parallelism Overview

Tensor parallelism splits individual layers across multiple devices to parallelize computation within a single forward/backward pass. The two primary patterns are:

  1. ColumnParallelLinear: Shards weights along the output dimension (columns). Each device computes a portion of the output features, then results are concatenated via all-gather.

  2. RowParallelLinear: Shards weights along the input dimension (rows). Each device computes partial outputs using its shard of the input, then results are summed via all-reduce.

Critical Implementation Requirement

When implementing tensor parallelism (especially in simulation or testing contexts), the forward pass must actually perform the collective operations, not just compute local shards:

  • ColumnParallelLinear: Must concatenate outputs from all ranks (all-gather semantics)
  • RowParallelLinear: Must sum outputs from all ranks (all-reduce semantics)

A common mistake is returning only the local shard and expecting an external framework to handle collective operations. Unless explicitly specified otherwise, the implementation should produce the final, complete output.
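
As a quick single-process illustration of the two combination rules (a sketch only; "ranks" are simulated with torch.chunk rather than real devices), concatenating column-shard outputs or summing row-shard partials reproduces the full linear output:

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features, world_size = 3, 8, 4, 2

x = torch.randn(batch, in_features)
W = torch.randn(out_features, in_features)   # nn.Linear convention: (out_features, in_features)
y_ref = x @ W.T                              # full, non-parallel output

# Column parallel: shard W along dim=0, concatenate local outputs (all-gather semantics).
y_col = torch.cat([x @ w.T for w in W.chunk(world_size, dim=0)], dim=-1)

# Row parallel: shard W along dim=1 and x along its feature dim, sum partials (all-reduce semantics).
y_row = sum(xs @ w.T for xs, w in zip(x.chunk(world_size, dim=-1), W.chunk(world_size, dim=1)))

assert torch.allclose(y_col, y_ref, atol=1e-6)
assert torch.allclose(y_row, y_ref, atol=1e-6)
```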

Implementation Approach

Step 1: Understand the Parallelism Pattern

Before implementing, clearly identify:

  1. Which dimension is being sharded (input features vs output features)
  2. What collective operation combines the results (all-gather vs all-reduce)
  3. Whether the implementation should simulate distributed execution or prepare for actual distributed execution
  4. How bias should be handled in the parallel context

Step 2: Weight Sharding

For a weight matrix W of shape (out_features, in_features), following the PyTorch nn.Linear convention:

ColumnParallelLinear:

  • Shard W along dim=0 (output features)
  • Each rank gets W_shard of shape (out_features // world_size, in_features)
  • Output shape per rank: (batch, out_features // world_size)

RowParallelLinear:

  • Shard W along dim=1 (input features)
  • Each rank gets W_shard of shape (out_features, in_features // world_size)
  • Each rank receives the corresponding shard of the input along the feature dimension
  • Output shape per rank: (batch, out_features) - a partial sum to be combined via all-reduce
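
A minimal sharding sketch under these conventions (hypothetical helper names; assumes both dimensions divide evenly by world_size):

```python
import torch

def shard_column_parallel(W: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """ColumnParallelLinear: split W (out_features, in_features) along dim=0."""
    return W.chunk(world_size, dim=0)[rank]   # (out_features // world_size, in_features)

def shard_row_parallel(W: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """RowParallelLinear: split W along dim=1 (input features)."""
    return W.chunk(world_size, dim=1)[rank]   # (out_features, in_features // world_size)
```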

Step 3: Forward Pass Implementation

ColumnParallelLinear Forward:

  1. Compute local output: y_local = x @ W_shard.T + bias_shard (if the bias is sharded)
  2. All-gather to concatenate: y = concat([y_0, y_1, ..., y_n], dim=-1)
  3. Return the complete output of shape (batch, out_features)
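
One possible single-process simulation of this forward pass (a hypothetical class, assuming even sharding and a sharded bias; real distributed code would replace the Python loop with torch.distributed.all_gather):

```python
from typing import Optional

import torch

class SimulatedColumnParallelLinear:
    """Each simulated 'rank' computes an output shard; shards are then
    concatenated along the feature dimension (all-gather semantics)."""

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor], world_size: int):
        self.w_shards = weight.chunk(world_size, dim=0)   # shard output features
        self.b_shards = bias.chunk(world_size, dim=0) if bias is not None else None
        self.world_size = world_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = []
        for rank in range(self.world_size):
            y_local = x @ self.w_shards[rank].T           # (batch, out_features // world_size)
            if self.b_shards is not None:
                y_local = y_local + self.b_shards[rank]   # bias shard applied per rank
            outputs.append(y_local)
        return torch.cat(outputs, dim=-1)                 # (batch, out_features)
```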

RowParallelLinear Forward:

  1. Get input shard: x_shard = x[..., start:end] for this rank
  2. Compute partial output: y_partial = x_shard @ W_shard.T
  3. All-reduce to sum: y = sum([y_0, y_1, ..., y_n])
  4. Add bias (only once, not per-rank): y = y + bias
  5. Return complete output of shape (batch, out_features)
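
A matching sketch for the row-parallel case (hypothetical class; a real implementation would replace the explicit sum with torch.distributed.all_reduce):

```python
from typing import Optional

import torch

class SimulatedRowParallelLinear:
    """Each simulated 'rank' computes a partial output from its input shard;
    partials are summed (all-reduce semantics) and the bias is added once."""

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor], world_size: int):
        self.w_shards = weight.chunk(world_size, dim=1)   # shard input features
        self.bias = bias                                  # kept whole, applied once
        self.world_size = world_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_shards = x.chunk(self.world_size, dim=-1)
        partials = [x_shards[r] @ self.w_shards[r].T for r in range(self.world_size)]
        y = torch.stack(partials).sum(dim=0)              # all-reduce (sum) over ranks
        if self.bias is not None:
            y = y + self.bias                             # added once, after the sum
        return y                                          # (batch, out_features)
```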

Step 4: Bias Handling

ColumnParallelLinear:

  • Bias can be sharded along with output features
  • Each rank adds its bias shard to its output shard
  • After all-gather, the full bias has been applied

RowParallelLinear:

  • Bias must NOT be sharded or added per-rank (it would be applied world_size times after the sum)
  • Add bias only once, after the all-reduce operation
  • Typically either only rank 0 adds the bias, or the bias is added once after the sum
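
A small numeric check of this rule (a sketch showing the wrong and the correct placement of the bias under row-parallel summation):

```python
import torch

torch.manual_seed(0)
x, W, bias, world_size = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3), 2
x_shards, w_shards = x.chunk(world_size, dim=-1), W.chunk(world_size, dim=1)

wrong = sum(x_shards[r] @ w_shards[r].T + bias for r in range(world_size))  # bias summed world_size times
right = sum(x_shards[r] @ w_shards[r].T for r in range(world_size)) + bias  # bias added once, after the sum

assert torch.allclose(right, x @ W.T + bias, atol=1e-6)
assert not torch.allclose(wrong, right)   # off by (world_size - 1) * bias
```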

Verification Strategies

Mathematical Verification

When local testing is unavailable, verify implementation correctness through mathematical analysis:

  1. Simple example: Use a 2x4 weight matrix (out_features=2, in_features=4) with world_size=2
  2. Trace computation: Manually compute what each rank produces
  3. Verify combination: Confirm all-gather/all-reduce produces correct final output
  4. Compare to baseline: Verify parallel output matches non-parallel computation
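
A worked version of the 2x4 check (assuming the weight follows the (out_features, in_features) convention used above), comparing both parallel patterns against F.linear as the non-parallel baseline:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4)                      # (batch, in_features)
W = torch.randn(2, 4)                      # (out_features=2, in_features=4)
world_size = 2

baseline = F.linear(x, W)                  # non-parallel reference, shape (1, 2)

# ColumnParallel: each rank holds one output row; all-gather = concatenate.
col = torch.cat([F.linear(x, w) for w in W.chunk(world_size, dim=0)], dim=-1)

# RowParallel: each rank holds half the input columns; all-reduce = sum.
row = sum(F.linear(xs, w) for xs, w in zip(x.chunk(world_size, dim=-1), W.chunk(world_size, dim=1)))

assert torch.allclose(col, baseline, atol=1e-6)
assert torch.allclose(row, baseline, atol=1e-6)
```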

Shape Verification Checklist

  • Input shape matches expected (batch, in_features)
  • Weight shard shape matches expected partitioning
  • Local output shape is correct for the parallelism type
  • Final output shape matches (batch, out_features) - NOT the sharded dimension
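
These checks can be captured in a small helper (hypothetical function; the argument names mirror the checklist and assume even sharding of a ColumnParallelLinear):

```python
import torch

def check_column_parallel_shapes(x: torch.Tensor, w_shard: torch.Tensor,
                                 y_local: torch.Tensor, y: torch.Tensor,
                                 in_features: int, out_features: int, world_size: int) -> None:
    """Raise if any shape deviates from the checklist above."""
    assert x.shape[-1] == in_features
    assert w_shard.shape == (out_features // world_size, in_features)
    assert y_local.shape[-1] == out_features // world_size
    assert y.shape[-1] == out_features        # final output must NOT be sharded
```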

Test Cases to Consider

  1. world_size=1: Trivial case, should match non-parallel implementation exactly
  2. world_size=2,4,8: Common parallel configurations
  3. Non-divisible dimensions: What happens when out_features % world_size != 0?
  4. Different batch sizes: Verify batch dimension is handled correctly
  5. With and without bias: Test both configurations
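
A sketch of such a sweep for the column-parallel case (not tied to any test framework; the row-parallel case would follow the same pattern with a sum in place of the concatenation):

```python
import torch
import torch.nn.functional as F

def test_column_parallel_matches_baseline():
    """Sweep world sizes, batch sizes, and bias settings against the non-parallel baseline."""
    torch.manual_seed(0)
    in_features, out_features = 8, 8                       # divisible by every world_size below
    for world_size in (1, 2, 4, 8):
        for batch in (1, 3):
            for use_bias in (False, True):
                x = torch.randn(batch, in_features)
                W = torch.randn(out_features, in_features)
                b = torch.randn(out_features) if use_bias else None
                baseline = F.linear(x, W, b)
                # Simulated ColumnParallelLinear: per-rank shard + concatenation.
                b_shards = b.chunk(world_size, dim=0) if b is not None else [None] * world_size
                parts = [F.linear(x, w_s, b_s)
                         for w_s, b_s in zip(W.chunk(world_size, dim=0), b_shards)]
                assert torch.allclose(torch.cat(parts, dim=-1), baseline, atol=1e-5)

test_column_parallel_matches_baseline()
```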

Common Pitfalls

Pitfall 1: Returning Local Shards Only

Symptom: Output tensor size is (out_features / world_size) instead of (out_features)

Cause: Implementation computes local shard but doesn't perform all-gather

Fix: Implement the collective operation to combine results from all ranks

Pitfall 2: Incorrect Bias Handling in RowParallelLinear

Symptom: The bias contribution is applied world_size times, so outputs are offset by (world_size - 1) * bias

Cause: Each rank adds the full bias, then values are summed

Fix: Add bias only once after all-reduce, not per-rank

Pitfall 3: Misinterpreting "Simulation" Requirements

Symptom: Implementation works for world_size=1 but fails for larger world sizes

Cause: Assuming external framework handles collective operations

Fix: Read requirements carefully - "as if using all_gather" means implement the operation

Pitfall 4: Truncated File Writes

Symptom: Implementation has syntax errors or missing code

Cause: File write operation was truncated

Fix: Always read back the complete file after writing to verify integrity

Pitfall 5: Wrong Dimension for Sharding

Symptom: Shape mismatch errors during matrix multiplication

Cause: Sharding along wrong dimension (rows vs columns confusion)

Fix: ColumnParallel shards output features (dim=0 of weight), RowParallel shards input features (dim=1 of weight)

Pre-Implementation Checklist

Before writing code, confirm understanding of:

  • Which collective operation is needed (all-gather vs all-reduce)
  • What the final output shape should be
  • Whether simulation should actually perform collective ops or defer them
  • How bias should be handled for this parallelism type
  • What happens for edge cases (world_size=1, non-divisible dimensions)

Post-Implementation Checklist

After writing code:

  • Read back the complete implementation file to verify no truncation
  • Verify output shapes match expected dimensions for all world sizes
  • Trace through a simple example manually to verify correctness
  • Test trivial case (world_size=1) matches non-parallel baseline
  • Test at least one non-trivial case (world_size=2 or 4)