FullyShardedDataParallel (FSDP) is PyTorch’s memory-efficient distributed training strategy. Instead of replicating the full model on each GPU (like DDP), FSDP:
Shards model parameters, gradients, and optimizer states across GPUs
Gathers parameters only when needed during forward/backward passes
Discards parameters after use to save memory
Enables training models much larger than a single GPU’s memory
This pattern reduces the need to fetch parameters from CPU RAM, which is much slower than GPU-to-GPU data transfer.
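To put rough numbers on the savings, here is a back-of-envelope sketch. The model size, GPU count, and the 16-bytes-per-parameter figure (for Adam with full-precision states) are illustrative assumptions, not SF Tensor defaults, and activations plus communication buffers come on top of these figures:

# Rough per-GPU memory estimate (illustrative numbers only).
# With Adam in full precision: 4 B params + 4 B grads + 8 B optimizer states ≈ 16 B/param.
num_params = 7_000_000_000     # e.g. a 7B-parameter model (assumption)
bytes_per_param = 16           # params + grads + Adam moments (approximate)
world_size = 8                 # number of GPUs (assumption)

ddp_per_gpu = num_params * bytes_per_param                 # everything replicated
fsdp_per_gpu = num_params * bytes_per_param / world_size   # sharded across ranks

print(f"DDP-style replication: ~{ddp_per_gpu / 1e9:.0f} GB per GPU")
print(f"FSDP full sharding:    ~{fsdp_per_gpu / 1e9:.0f} GB per GPU")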
Here’s the minimal code to use FSDP with SF Tensor:
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import sf_tensor as sft

if __name__ == "__main__":
    # 1. Initialize distributed training
    sft.initialize_distributed_training()

    # 2. Get device
    device = sft.get_device()

    # 3. Create model on CPU first (important for large models)
    model = YourLargeModel()

    # 4. Wrap model in FSDP (will move to GPU)
    model = FSDP(model, device_id=device)

    # 5. Train as usual
    for epoch in range(num_epochs):
        for batch in train_loader:
            loss = train_step(model, batch, device)
        sft.log(f"Epoch {epoch}, Loss: {loss.item():.4f}")
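The snippet assumes a `train_step` helper plus an optimizer and loss function created elsewhere; a minimal, hypothetical sketch of that helper might look like the following. Note that the optimizer should be constructed after the model is wrapped in FSDP so that it references the sharded parameters.

# Hypothetical sketch of train_step; `optimizer` and `loss_fn` are assumed to exist
# elsewhere and are not part of sf_tensor.
def train_step(model, batch, device):
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()      # FSDP reduce-scatters and shards gradients here
    optimizer.step()
    return loss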
To save a checkpoint, gather a full (unsharded) state dict and write it from rank 0 only. Note that model.state_dict() is a collective call under FSDP, so every rank must enter the context and call it, even though only rank 0 keeps the result:

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
import torch.distributed as dist

# Configure full state dict saving: gather to CPU, keep it only on rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

# All ranks must participate in gathering the full state dict
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()

# Only save on rank 0
if dist.get_rank() == 0:
    torch.save({
        'epoch': epoch,
        'model_state_dict': state_dict,
        # Note: optimizer.state_dict() only holds this rank's shard; for a full
        # optimizer state, see FSDP.optim_state_dict(model, optimizer).
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')
    sft.log("Checkpoint saved")
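Loading mirrors the save path. A sketch under the assumption that the checkpoint was written as above (optimizer state omitted): every rank reads the file to CPU and calls load_state_dict inside the same FULL_STATE_DICT context, and FSDP re-shards the weights.

# Sketch: every rank loads the full checkpoint to CPU memory,
# then FSDP shards the weights across ranks during load_state_dict.
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    model.load_state_dict(checkpoint['model_state_dict'])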
Create the Model on CPU First
For large models, create the model on CPU first, then let FSDP move it to GPU:
# Good (model created on CPU)
model = LargeModel()
model = FSDP(model, device_id=device)

# Bad (may cause OOM)
model = LargeModel().to(device)
model = FSDP(model)
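If the model starts from pretrained weights, a related sketch (the checkpoint path is hypothetical) is to materialize the weights only on rank 0 and let FSDP broadcast them to the other ranks during wrapping via `sync_module_states=True`:

import torch.distributed as dist

# Sketch: only rank 0 pays the CPU-RAM cost of loading the pretrained weights;
# sync_module_states=True broadcasts rank 0's module states to all other ranks
# while FSDP shards the model onto the GPUs.
model = LargeModel()  # created on CPU on every rank
if dist.get_rank() == 0:
    model.load_state_dict(torch.load('pretrained.pth', map_location='cpu'))  # hypothetical path

model = FSDP(model, device_id=device, sync_module_states=True)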
Use BF16 for Large Models
BF16 keeps FP32's exponent range, so it is more numerically stable than FP16 for large models and does not require gradient scaling:
from torch.distributed.fsdp import MixedPrecision

mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,    # Better than FP16
    reduce_dtype=torch.bfloat16,   # Gradient reduction in BF16
    buffer_dtype=torch.bfloat16,   # Buffers (e.g. norm statistics) in BF16
)
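The policy only takes effect once it is passed to the FSDP constructor; a sketch with a hedged capability check, since GPUs older than Ampere lack native BF16 support:

# Sketch: fall back to full precision if the GPU cannot do BF16.
if torch.cuda.is_bf16_supported():
    model = FSDP(model, device_id=device, mixed_precision=mixed_precision)
else:
    model = FSDP(model, device_id=device)  # full-precision fallback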
Choose an Appropriate Auto-Wrap Policy
Wrap at the right granularity for your model:
import functools

from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn import TransformerEncoderLayer

# For transformers, wrap individual layers
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerEncoderLayer},
)

# For other models, wrap based on size
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=100_000_000,
)
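Whichever policy you choose, pass it to the constructor; a short sketch:

# Sketch: the wrap policy controls which submodules become their own FSDP units,
# which bounds how many parameters are gathered at any one time.
model = FSDP(
    model,
    device_id=device,
    auto_wrap_policy=auto_wrap_policy,
)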
Use HYBRID_SHARD for Multi-Node
For multi-node training, HYBRID_SHARD shards within each node and replicates across nodes, which keeps most collective traffic on the faster intra-node interconnect:
from torch.distributed.fsdp import ShardingStrategy

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
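For context, here is a sketch of how the main ShardingStrategy values compare; the summary follows the PyTorch FSDP documentation, and the strategy chosen on the last line is purely illustrative:

# FULL_SHARD    - shards parameters, gradients, and optimizer states (maximum memory savings)
# SHARD_GRAD_OP - shards gradients and optimizer states; full parameters stay resident
#                 between forward and backward (ZeRO-2-like)
# HYBRID_SHARD  - FULL_SHARD within each node, replication across nodes
# NO_SHARD      - no sharding; behaves like DDP
model = FSDP(
    model,
    # Illustrative: useful when parameters fit on one GPU but optimizer state does not
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
)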
Always Use DistributedSampler
Just like DDP, FSDP requires a DistributedSampler so that each rank sees a distinct shard of the data:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, sampler=sampler)

# Set epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
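Putting the sampler together with the training loop for one epoch (a sketch that reuses names from the earlier snippets):

# Sketch: call set_epoch before iterating the loader, otherwise every epoch
# shuffles the data into the same order.
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in loader:
        loss = train_step(model, batch, device)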
When you train on the Tensor Cloud, we launch the training script with torchrun for you. If you're training elsewhere, you'll need to run torchrun yourself; Lei Mao's guide to distributed training has hands-on information about using it.