A single‑node mini‑FSDP wrapper for PyTorch that implements the core ideas of Fully Sharded Data Parallel in ~200 lines:
- Per‑module parameter flattening
- All‑gather full params right before each forward; free back down to the local shard afterwards
- Reduce‑scatter gradients during backward
- Optional AMP fp16
Built for clarity. No CPU offload, no optimizer sharding, no elastic tricks.
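At its core the wrapper is three collectives per flat parameter. The helpers below are an illustrative sketch against plain `torch.distributed` (the function names are made up for the example; this is not the repo's actual code):

```python
import torch
import torch.distributed as dist

def gather_full_param(local_shard: torch.Tensor) -> torch.Tensor:
    """Before forward: all-gather every rank's shard into one full flat parameter."""
    world_size = dist.get_world_size()
    full = torch.empty(local_shard.numel() * world_size,
                       dtype=local_shard.dtype, device=local_shard.device)
    dist.all_gather_into_tensor(full, local_shard)
    return full

def free_full_param(full: torch.Tensor, shard_numel: int) -> torch.Tensor:
    """After forward (and again after backward): keep only this rank's slice."""
    rank = dist.get_rank()
    return full[rank * shard_numel:(rank + 1) * shard_numel].clone()

def reduce_scatter_grad(full_grad: torch.Tensor, shard_numel: int) -> torch.Tensor:
    """During backward: sum gradients across ranks, keep this rank's slice, average."""
    grad_shard = torch.empty(shard_numel, dtype=full_grad.dtype,
                             device=full_grad.device)
    dist.reduce_scatter_tensor(grad_shard, full_grad)  # default op is SUM
    return grad_shard / dist.get_world_size()
```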
```bash
pip install -e .
GPUS=2 bash scripts/launch.sh
```
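If you'd rather launch by hand, the script should be roughly equivalent to a plain single‑node `torchrun` call like the one below (an assumption about the script, and `train.py` is a placeholder for your own entry point; check `scripts/launch.sh` for the exact flags):

```bash
# Roughly equivalent manual launch (assumed; train.py is a placeholder entry point)
torchrun --standalone --nproc_per_node="${GPUS:-2}" train.py
```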
```python
from shardspark import FSDP, FSDPConfig

wrapped = FSDP(model, FSDPConfig(fp16=True))
```
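A slightly fuller end‑to‑end sketch: only `FSDP` and `FSDPConfig(fp16=True)` come from the snippet above. The toy model, optimizer, and loop are illustrative, and it assumes the wrapper exposes the local shards via `.parameters()` and handles the fp16 cast internally.

```python
import torch
import torch.distributed as dist
from shardspark import FSDP, FSDPConfig

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())  # single node, so rank == local GPU index

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 1024)
).cuda()

wrapped = FSDP(model, FSDPConfig(fp16=True))
opt = torch.optim.AdamW(wrapped.parameters(), lr=3e-4)  # optimizer only sees local shards

for step in range(10):
    x = torch.randn(8, 1024, device="cuda")
    loss = wrapped(x).pow(2).mean()   # all-gather happens in the forward hooks
    loss.backward()                   # reduce-scatter happens in the backward hooks
    opt.step()
    opt.zero_grad(set_to_none=True)
```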
- FSDP papers and blog posts are great, but the reference implementation is complex. This repo bridges the gap between the idea and the code.
- Use it to prototype or to teach your team the mechanics before moving to production FSDP.
Known limitations:
- Single‑node NCCL only
- Requires that the total parameter count is divisible by `world_size` (a quick pre‑wrap check is sketched after this list)
- No optimizer state sharding (OSS)
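The divisibility constraint is easy to check before wrapping; a small illustrative helper (plain PyTorch, not part of the package):

```python
import torch.distributed as dist

def assert_divisible(model, world_size=None):
    """Raise early if the model can't be evenly sharded across ranks."""
    world_size = world_size or dist.get_world_size()
    total = sum(p.numel() for p in model.parameters())
    if total % world_size != 0:
        raise ValueError(
            f"total param count {total} is not divisible by world_size {world_size}; "
            "adjust layer sizes or add a small dummy parameter as padding"
        )
```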
Planned next:
- Optimizer state sharding
- Activation checkpoint helper
- Async all‑gather prefetch (conceptual sketch below)
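None of these exist yet. As a rough idea of what the prefetch would look like, the sketch below kicks off the next module's all‑gather asynchronously so the collective overlaps with the current module's compute (purely conceptual, not the repo's code):

```python
import torch
import torch.distributed as dist

def prefetch_full_param(local_shard: torch.Tensor):
    """Start an async all-gather for the *next* module's flat parameter."""
    world_size = dist.get_world_size()
    full = torch.empty(local_shard.numel() * world_size,
                       dtype=local_shard.dtype, device=local_shard.device)
    work = dist.all_gather_into_tensor(full, local_shard, async_op=True)
    return full, work  # call work.wait() right before running that module's forward
```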
License: MIT