Shardspark

A single‑node mini‑FSDP wrapper for PyTorch that implements the core ideas of Fully Sharded Data Parallel in ~200 lines:

  • Per‑module parameter flattening
  • All‑gather full params before forward; discard back to local shard after
  • Reduce‑scatter gradients during backward
  • Optional fp16 mixed precision (AMP)

Built for clarity. No CPU offload, no optimizer sharding, no elastic tricks.
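
The bullets above boil down to two collectives. Here is a minimal sketch of those calls, assuming a recent PyTorch with torch.distributed initialized; the helper names are illustrative, not the repo's actual internals:

import torch
import torch.distributed as dist

# Illustrative helpers (not the repo's internals): each rank owns a flat
# 1/world_size shard of a module's parameters.

def all_gather_flat_param(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # Materialize the full flat parameter right before the module's forward.
    full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(full, shard)   # every rank contributes its shard
    return full                                # discarded again after forward/backward

def reduce_scatter_grad(full_grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # Sum the full gradient across ranks and keep only this rank's shard.
    shard = torch.empty(full_grad.numel() // world_size,
                        dtype=full_grad.dtype, device=full_grad.device)
    dist.reduce_scatter_tensor(shard, full_grad, op=dist.ReduceOp.SUM)
    return shard

In an FSDP-style wrapper, the first call typically runs from a pre-forward hook and the second from a gradient hook on the flat parameter, which is how the full weights stay materialized only around each module's compute.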

Quickstart

pip install -e .
GPUS=2 bash scripts/launch.sh
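
The launch script starts one process per GPU. As a hedged sketch (not taken from scripts/launch.sh), each worker in a single-node NCCL job needs roughly this setup, assuming a torchrun-style launcher that exports LOCAL_RANK:

import os
import torch
import torch.distributed as dist

# Hypothetical per-worker setup for a single-node NCCL job; assumes the
# launcher sets LOCAL_RANK. Not copied from scripts/launch.sh.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")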

API

from shardspark import FSDP, FSDPConfig
wrapped = FSDP(model, FSDPConfig(fp16=True))
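
A hedged end-to-end sketch of how the wrapped module might sit in a training loop. It assumes the wrapper behaves like a regular nn.Module, that an ordinary optimizer over wrapped.parameters() sees the local shards, and that any fp16 loss scaling happens inside the wrapper; TinyGPT and train_loader are placeholders, and only FSDP, FSDPConfig, and fp16=True come from this README:

import torch
import torch.nn.functional as F
from shardspark import FSDP, FSDPConfig

# Hedged sketch: TinyGPT and train_loader are placeholders, not the repo's examples.
model = TinyGPT().cuda()
wrapped = FSDP(model, FSDPConfig(fp16=True))           # fp16=True: the optional AMP path
optimizer = torch.optim.AdamW(wrapped.parameters(), lr=3e-4)

for inputs, targets in train_loader:
    optimizer.zero_grad(set_to_none=True)
    logits = wrapped(inputs.cuda())                    # forward: full params all-gathered
    loss = F.cross_entropy(logits, targets.cuda())
    loss.backward()                                    # backward: grads reduce-scattered
    optimizer.step()                                   # each rank steps on its own shard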

Why this exists

  • FSDP papers and blog posts are great, but the reference implementation is complex. This repo bridges the mental gap.
  • Use it to prototype or to teach your team the mechanics before moving to production FSDP.

Safety & Limits

  • Single‑node NCCL only
  • Requires the total parameter count to be divisible by world_size (see the check sketched after this list)
  • No optimizer state sharding (OSS)
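
A quick pre-flight check for the divisibility requirement, as an illustrative sketch rather than part of the library:

import torch.distributed as dist

# Illustrative pre-flight check for the divisibility limit listed above.
total_params = sum(p.numel() for p in model.parameters())
world_size = dist.get_world_size()
assert total_params % world_size == 0, (
    f"total params ({total_params}) must be divisible by world_size ({world_size}); "
    "adjust the GPU count or model size before wrapping"
)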

Roadmap

  • Optimizer state sharding
  • Activation checkpoint helper
  • Async all‑gather prefetch

License

MIT
