
Skill: Use PyTorch FSDP2 (fully_shard) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

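FSDP2 in PyTorch is exposed primarily via torch.distributed.fsdp.fully_shard and the FSDPModule methods (such as unshard(), reshard(), and set_requires_gradient_sync()) that fully_shard adds by swapping each sharded module's class in place. See: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.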


When to use this skill

Use FSDP2 when:

  • Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
  • You want eager-mode, DTensor-based per-parameter sharding, which is easier to inspect and yields simpler sharded state dicts than FSDP1's flat-parameter approach.
  • You may later compose DP with Tensor Parallel using DeviceMesh.

Avoid (or be careful) if:

  • You need checkpoints that stay loadable across PyTorch versions (the DCP docs do not guarantee backward compatibility across versions).
  • You’re forced onto older PyTorch versions without the FSDP2 stack.

Alternatives (when FSDP2 is not the best fit)

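  • DistributedDataParallel (DDP): use the standard data-parallel wrapper when the full model, gradients, and optimizer state fit on each GPU. Every rank holds a full replica and gradients are all-reduced.
  • FullyShardedDataParallel (FSDP1): use the original FSDP wrapper, which shards flattened parameter groups (FlatParameter) across data-parallel workers, e.g., when maintaining existing FSDP1 code.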

Reference: references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.


Contract the agent must follow

  1. Launch with torchrun and set the CUDA device per process (usually via LOCAL_RANK).
  2. Apply fully_shard() bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
  3. Call model(input), not model.forward(input), so the FSDP2 hooks run (unless you explicitly call unshard() or register the method with register_fsdp_forward_method()).
  4. Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-fully_shard).
  5. Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve torch.save(model.state_dict()) unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)


Step-by-step procedure

0) Version & environment sanity

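  • Prefer a recent stable PyTorch release. fully_shard is public as torch.distributed.fsdp.fully_shard in recent releases; older releases exposed it under torch.distributed._composable.fsdp.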
  • Use torchrun --nproc_per_node <gpus_per_node> ... and ensure RANK, WORLD_SIZE, LOCAL_RANK are visible.

Reference: references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).


1) Initialize distributed and set device

Minimal, correct pattern:

  • dist.init_process_group(backend="nccl")
  • torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
  • Optionally create a DeviceMesh to describe the data-parallel group(s)

Reference: references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists & how it manages process groups).


2) Build model on meta device (recommended for very large models)

For big models, initialize on meta, apply sharding, then materialize weights on GPU:

  • with torch.device("meta"): model = ...
  • apply fully_shard(...) on submodules, then fully_shard(model)
  • model.to_empty(device="cuda")
  • initialize weights, e.g., call reset_parameters() on each submodule that defines it (or run your own init routine)

Reference: references/pytorch_fsdp2_tutorial.md (migration guide shows this flow explicitly).
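A minimal sketch of this flow, assuming every parameterized submodule defines reset_parameters() (nn.Linear, nn.Embedding, and nn.LayerNorm do). build_model_meta, make_model, and shard_model are hypothetical names; shard_model stands in for whatever bottom-up fully_shard routine you use. to_empty() also leaves buffers uninitialized, so buffers computed in __init__ (e.g., rotary-embedding tables) must be recomputed too.

```python
from typing import Callable

import torch
import torch.nn as nn


def init_weights(model: nn.Module) -> None:
    """Re-run each submodule's own reset_parameters() after to_empty()."""
    for module in model.modules():
        # Only modules that define the method (nn.Linear, nn.LayerNorm, ...) are touched.
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()


def build_model_meta(
    make_model: Callable[[], nn.Module],
    shard_model: Callable[[nn.Module], None],
    device: str = "cuda",
) -> nn.Module:
    # 1) Construct on meta: parameters get shapes/dtypes but no storage.
    with torch.device("meta"):
        model = make_model()
    # 2) Shard while still on meta (e.g., fully_shard on each block, then on the root).
    shard_model(model)
    # 3) Allocate uninitialized storage for the local shards on the target device...
    model.to_empty(device=device)
    # 4) ...and overwrite it with real initial values.
    init_weights(model)
    return model
```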


3) Apply fully_shard() bottom-up (wrapping policy = “apply where needed”)

Do not call fully_shard only on the topmost module.

Recommended sharding pattern for transformer-like models:

  • iterate modules, if isinstance(m, TransformerBlock): fully_shard(m, ...)
  • then fully_shard(model, ...)

Why:

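  • Each fully_shard call puts that module's parameters into one parameter group, which is all-gathered and reduce-scattered as a single collective, and skips parameters already claimed by earlier calls. Sharding blocks first lets each block be gathered just before it runs and freed afterward, which enables communication/computation overlap and lowers peak memory. Sharding only the root puts every parameter in one group that is gathered all at once.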

Reference: references/pytorch_fully_shard_api.md (bottom-up requirement and why).


4) Configure reshard_after_forward for memory/perf trade-offs

Default behavior:

  • None means True for non-root modules and False for root modules (good default).

Heuristics:

  • If you’re memory-bound: keep defaults or force True on many blocks.
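  • If you’re throughput-bound and can afford memory: consider reshard_after_forward=False on some blocks, which skips their all-gather in backward at the cost of holding their unsharded parameters until then (the root already defaults to False).
  • Advanced: pass an int to reshard to a smaller world size after forward (e.g., 8 for intra-node resharding); it must be a non-trivial divisor of the shard mesh dimension size.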

Reference: references/pytorch_fully_shard_api.md (full semantics).


5) Mixed precision & offload (optional but common)

FSDP2 uses:

  • mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)
  • offload_policy=CPUOffloadPolicy() if you want CPU offload

Rules of thumb:

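  • Start with BF16 parameters (param_dtype=torch.bfloat16) on H100/A100-class GPUs, if your model is numerically stable in BF16.
  • Choose reduce_dtype deliberately: BF16 reductions halve gradient communication, while FP32 (reduce_dtype=torch.float32) is more robust when averaging gradients over many ranks.
  • If you use CPU offload, sharded parameters and gradients live on CPU and the optimizer step runs there, so budget for host-device (PCIe) transfers and slower optimizer steps.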

Reference: references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).


6) Optimizer, gradient clipping, accumulation

  • Create the optimizer after sharding so it holds DTensor params.
  • If you need gradient accumulation / no_sync:
    • use the FSDP2 mechanism (set_requires_gradient_sync) instead of FSDP1’s no_sync().

Gradient clipping:

  • Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.

Reference: references/pytorch_fsdp2_tutorial.md.
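A minimal sketch of build_optimizer() and train_step(), matching the names in this skill's implementation outline. It assumes microbatches is a list of (inputs, targets) pairs and that model is the sharded root; AdamW is an arbitrary choice. Per the FSDP2 tutorial's DTensor section, torch.nn.utils.clip_grad_norm_ accepts DTensor gradients and computes the norm across shards.

```python
import torch
import torch.nn as nn


def build_optimizer(model: nn.Module, lr: float = 1e-4) -> torch.optim.Optimizer:
    # Call only after every fully_shard(...) so model.parameters() yields sharded DTensors.
    return torch.optim.AdamW(model.parameters(), lr=lr)


def train_step(model, optimizer, microbatches, loss_fn, max_norm=1.0):
    """One optimizer step over several microbatches with FSDP2 gradient accumulation."""
    num_micro = len(microbatches)
    for i, (inputs, targets) in enumerate(microbatches):
        # Skip the gradient reduce-scatter on all but the last microbatch.
        model.set_requires_gradient_sync(i == num_micro - 1)
        # model(...), not model.forward(...), so the FSDP2 hooks run.
        loss = loss_fn(model(inputs), targets) / num_micro
        loss.backward()
    # Clip the accumulated (reduced, sharded) gradients, then step.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    optimizer.zero_grad()
    return grad_norm
```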


7) Checkpointing: prefer DCP or distributed state dict helpers

Two recommended approaches:

A) Distributed Checkpoint (DCP) — best default

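  • DCP saves and loads from all ranks in parallel, each rank writing only its own shards, and supports resharding at load time (e.g., loading into a different world size).
  • DCP produces multiple files (often at least one per rank) plus a metadata file, and loads in place: you pass a state dict already laid out for the current model and DCP fills its tensors.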

B) Distributed state dict helpers

  • get_model_state_dict / set_model_state_dict with StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)
  • For optimizer: get_optimizer_state_dict / set_optimizer_state_dict

Avoid:

  • Saving DTensor state dicts with plain torch.save unless you intentionally convert with DTensor.full_tensor() and manage memory carefully.

References:

  • references/pytorch_dcp_overview.md (DCP behavior and caveats)
  • references/pytorch_dcp_recipe.md and references/pytorch_dcp_async_recipe.md (end-to-end usage)
  • references/pytorch_fsdp2_tutorial.md (DTensor vs DCP state-dict flows)
  • references/pytorch_examples_fsdp2.md (working checkpoint scripts)

Workflow checklists (copy-paste friendly)

Workflow A: Retrofit FSDP2 into an existing training script

  • Launch with torchrun and initialize the process group.
  • Set the CUDA device from LOCAL_RANK; create a DeviceMesh if you need multi-dim parallelism.
  • Build the model (use meta if needed), apply fully_shard bottom-up, then fully_shard(model).
  • Create the optimizer after sharding so it captures DTensor parameters.
  • Use model(inputs) so hooks run; use set_requires_gradient_sync for accumulation.
  • Add DCP save/load via torch.distributed.checkpoint helpers.

Reference: references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.

Workflow B: Add DCP save/load (minimal pattern)

  • Wrap state in Stateful or assemble state via get_state_dict.
  • Call dcp.save(...) from all ranks to a shared path.
  • Call dcp.load(...) and restore with set_state_dict.
  • Validate any resharding assumptions when loading into a different mesh.

Reference: references/pytorch_dcp_recipe.md.

Debug checklist (what the agent should check first)

  1. All ranks on distinct GPUs?
    If not, verify torch.cuda.set_device(LOCAL_RANK) and your torchrun flags.
  2. Did you accidentally call forward() directly?
    Use model(input), or call unshard() explicitly / register the method with register_fsdp_forward_method().
  3. Is fully_shard() applied bottom-up?
    If only the root is sharded, all parameters form a single group that is all-gathered at once, so expect higher peak memory and no communication overlap.
  4. Optimizer created at the right time?
    Must be built on DTensor parameters after sharding.
  5. Checkpointing path consistent?
    • If using DCP, don’t mix with ad-hoc torch.save unless you understand conversions.
    • Be mindful of PyTorch-version compatibility warnings for DCP.

Common issues and fixes

  • Forward hooks not running → Call model(inputs) (or unshard() explicitly) instead of model.forward(...).
  • Optimizer sees non-DTensor params → Create optimizer after all fully_shard calls.
  • Only root module sharded → Apply fully_shard bottom-up on submodules before the root.
  • Memory spikes after forward → Set reshard_after_forward=True for more modules.
  • Gradient accumulation desync → Use set_requires_gradient_sync instead of FSDP1’s no_sync().

Reference: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.


Minimal reference implementation outline (agent-friendly)

The coding agent should implement a script with these labeled blocks:

  • init_distributed(): init process group, set device
  • build_model_meta(): model on meta, apply fully_shard, materialize weights
  • build_optimizer(): optimizer created after sharding
  • train_step(): forward/backward/step with model(inputs) and DTensor-aware patterns
  • checkpoint_save/load(): DCP or distributed state dict helpers

Concrete examples live in references/pytorch_examples_fsdp2.md and the official tutorial reference.


References

  • references/pytorch_fsdp2_tutorial.md
  • references/pytorch_fully_shard_api.md
  • references/pytorch_ddp_notes.md
  • references/pytorch_fsdp1_api.md
  • references/pytorch_device_mesh_tutorial.md
  • references/pytorch_tp_tutorial.md
  • references/pytorch_dcp_overview.md
  • references/pytorch_dcp_recipe.md
  • references/pytorch_dcp_async_recipe.md
  • references/pytorch_examples_fsdp2.md
  • references/torchtitan_fsdp_notes.md (optional, production notes)
  • references/ray_train_fsdp2_example.md (optional, integration example)