What is FSDP?

FullyShardedDataParallel (FSDP) is PyTorch’s memory-efficient distributed training strategy. Instead of replicating the full model on each GPU (like DDP), FSDP:
  1. Shards model parameters, gradients, and optimizer states across GPUs
  2. Gathers parameters only when needed during forward/backward passes
  3. Discards parameters after use to save memory
  4. Enables training models much larger than a single GPU’s memory
This pattern reduces the need to fetch parameters from CPU RAM, which is much slower than GPU-to-GPU data transfer.
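To make the pattern concrete, here is a toy, single-process illustration of the shard/gather/discard cycle using plain tensor slicing (this is not FSDP’s actual implementation, just the idea):
import torch

# 1. Shard: split one large weight into 4 "per-GPU" pieces; no worker keeps the full copy
full_weight = torch.randn(1_000_000)
shards = list(torch.chunk(full_weight, 4))
del full_weight

# 2. Gather: reassemble the full tensor only when it is needed for compute
gathered = torch.cat(shards)
output = gathered[:10] * 2.0  # stand-in for a forward computation

# 3. Discard: free the full tensor right after use, keeping only the shards
del gathered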

When to Use FSDP

Use FSDP for:
  • Models too large to fit on a single GPU (see the memory estimate below)
Don’t use FSDP for:
  • Small models that fit in GPU memory (use DDP instead - it’s faster)
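To get a feel for “too large”, here is a back-of-the-envelope estimate for a hypothetical 7B-parameter model trained in FP32 with Adam (activations and temporary buffers not included):
num_params = 7e9
bytes_per_param = 4                                # FP32

params_gb = num_params * bytes_per_param / 1e9     # ~28 GB of parameters
grads_gb = params_gb                               # ~28 GB of gradients
adam_gb = 2 * params_gb                            # ~56 GB for Adam's momentum and variance

total_gb = params_gb + grads_gb + adam_gb          # ~112 GB before activations
print(f"~{total_gb:.0f} GB, far more than a single 80 GB GPU")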

FSDP vs. DDP

Aspect         | DDP                    | FSDP
Model Size     | Must fit in single GPU | Can exceed single GPU memory
Memory per GPU | Full model + gradients | Model shard + gradients shard

Basic Setup with SF Tensor

Here’s the minimal code to use FSDP with SF Tensor:
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import sf_tensor as sft

if __name__ == "__main__":
    # 1. Initialize distributed training
    sft.initialize_distributed_training()

    # 2. Get device
    device = sft.get_device()

    # 3. Create model on CPU first (important for large models)
    model = YourLargeModel()

    # 4. Wrap model in FSDP (will move to GPU)
    model = FSDP(model, device_id=device)

    # 5. Train as usual
    for epoch in range(num_epochs):
        for batch in train_loader:
            loss = train_step(model, batch, device)
            sft.log(f"Epoch {epoch}, Loss: {loss.item():.4f}")
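The train_step helper above is not provided by SF Tensor; here is a minimal sketch of what it might look like, assuming batch is an (inputs, targets) pair and that an optimizer and loss_fn exist alongside the model:
def train_step(model, batch, device):
    inputs, targets = (t.to(device) for t in batch)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()      # FSDP reduce-scatters gradients to their owning shards here
    optimizer.step()     # each rank updates only its local parameter shard
    return loss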

FSDP Configuration Options

FSDP has many configuration options. Here are the most important ones:

1. Sharding Strategy

Controls how model parameters are distributed across GPUs:
from torch.distributed.fsdp import ShardingStrategy

# Option 1: Full Sharding (most memory efficient)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD  # Default
)

# Option 2: Shard Grad Op (shard gradients and optimizer states only)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
)

# Option 3: No Sharding (equivalent to DDP)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.NO_SHARD
)

# Option 4: Hybrid Sharding (shard within node, replicate across nodes)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD
)
FULL_SHARD (the default) shards parameters, gradients, and optimizer states.
  • Memory: Most efficient
  • Speed: Slower (the most communication)
  • Use: Models that don’t fit on a single GPU

SHARD_GRAD_OP shards only gradients and optimizer states; parameters are replicated.
  • Memory: Medium efficiency
  • Speed: Faster than FULL_SHARD
  • Use: When the model fits but gradients and optimizer states don’t

HYBRID_SHARD shards within each node and replicates across nodes.
  • Memory: Good efficiency
  • Speed: Fast (less cross-node communication)
  • Use: Multi-node training

NO_SHARD does no sharding and is equivalent to DDP.
  • Memory: Least efficient
  • Speed: Fastest
  • Use: Debugging or when you want DDP-like behavior

2. Mixed Precision

Use mixed precision to reduce memory and speed up training:
from torch.distributed.fsdp import MixedPrecision

# FP16 mixed precision
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.float16,      # Store params in FP16
    reduce_dtype=torch.float16,      # Reduce gradients in FP16
    buffer_dtype=torch.float16,      # Store buffers in FP16
)

model = FSDP(
    model,
    mixed_precision=mixed_precision_policy
)

# BF16 mixed precision (better for large models)
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    mixed_precision=mixed_precision_policy
)
BF16 (bfloat16) is often better than FP16 for large models because it has the same exponent range as FP32, reducing numerical issues.
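If you do use FP16, you typically also need gradient scaling to avoid underflow, and with FSDP the sharded variant of the scaler is the one to reach for. A minimal sketch, assuming an optimizer and the train_loader from the setup above (compute_loss is a hypothetical helper):
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()

for batch in train_loader:
    optimizer.zero_grad()
    loss = compute_loss(model, batch)   # hypothetical loss helper
    scaler.scale(loss).backward()       # scale the loss before backward to avoid FP16 underflow
    scaler.step(optimizer)              # unscale gradients and step the optimizer
    scaler.update()                     # adjust the scale factor for the next iteration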

3. Auto-Wrap Policy

Controls which modules get wrapped by FSDP:
import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

# Option 1: Size-based (wrap modules with > N parameters)
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=100_000_000  # 100M params
)

model = FSDP(model, auto_wrap_policy=auto_wrap_policy)

# Option 2: Transformer-based (wrap specific layer types)
from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}
)

model = FSDP(model, auto_wrap_policy=auto_wrap_policy)

# Option 3: Manual wrapping (wrap specific modules yourself)
# Don't pass auto_wrap_policy, manually wrap sub-modules
model = MyModel()
model.encoder = FSDP(model.encoder)
model.decoder = FSDP(model.decoder)
model = FSDP(model)

4. CPU Offloading

Offload parameters to CPU when not in use (extreme memory savings):
from torch.distributed.fsdp import CPUOffload

model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True)
)
CPU offloading saves memory but significantly slows down training due to CPU-GPU transfers. Only use when absolutely necessary.

Saving and Loading FSDP Models

FSDP models require special handling for checkpointing:

Saving Full State Dict

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
import torch.distributed as dist

# Configure full state dict saving (gather to rank 0, offload to CPU)
save_policy = FullStateDictConfig(
    offload_to_cpu=True,
    rank0_only=True
)

# state_dict() is a collective operation under FULL_STATE_DICT,
# so every rank must enter this block, not just rank 0
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()

# Only rank 0 writes the checkpoint to disk
if dist.get_rank() == 0:
    torch.save({
        'epoch': epoch,
        'model_state_dict': state_dict,
        # Note: optimizer.state_dict() here is the rank-local shard;
        # see the consolidated version below
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')

    sft.log("Checkpoint saved")

Loading Full State Dict

from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Load checkpoint
checkpoint = torch.load('checkpoint.pth', map_location='cpu')

# Create model and wrap in FSDP
model = YourModel()
model = FSDP(model, device_id=device)

# Configure full state dict loading
load_policy = FullStateDictConfig(
    offload_to_cpu=True,
    rank0_only=False
)

# Load state dict
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, load_policy):
    model.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

sft.log("Checkpoint loaded")

Saving Sharded State Dict

For very large models, save sharded checkpoints:
from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType
from torch.distributed.checkpoint import FileSystemWriter, save

# Configure sharded state dict (each rank keeps only its own shards)
save_policy = ShardedStateDictConfig()

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, save_policy):
    state_dict = model.state_dict()

# Save sharded checkpoint (creates a directory with multiple files);
# all ranks call save(), each writing its own shards
save(
    state_dict={"model": state_dict},
    storage_writer=FileSystemWriter("checkpoint_dir"),
)
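Loading a sharded checkpoint mirrors the save path: each rank reads back its own shards in place. A sketch of the counterpart load, assuming the same "checkpoint_dir" and an already FSDP-wrapped model:
from torch.distributed.checkpoint import FileSystemReader, load

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # Start from the model's current sharded state dict as the target layout
    state_dict = {"model": model.state_dict()}

    # Read this rank's shards from disk into state_dict in place
    load(state_dict=state_dict, storage_reader=FileSystemReader("checkpoint_dir"))

    # Copy the loaded shards back into the model
    model.load_state_dict(state_dict["model"])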

FSDP with Activation Checkpointing

For even larger models, combine FSDP with activation checkpointing:
import functools

from torch.nn import TransformerEncoderLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

# Wrap model in FSDP
model = FSDP(model, device_id=device)

# Apply activation checkpointing to specific layers
def check_fn(module):
    # Checkpoint transformer layers
    return isinstance(module, TransformerEncoderLayer)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    ),
    check_fn=check_fn,
)
Activation checkpointing trades compute for memory by recomputing activations during the backward pass instead of storing them.

Best Practices

For large models, create the model on CPU first, then let FSDP move it to GPU:
# Good (model created on CPU)
model = LargeModel()
model = FSDP(model, device_id=device)

# Bad (may cause OOM)
model = LargeModel().to(device)
model = FSDP(model)
BF16 is more stable than FP16 for large models:
mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,  # Better than FP16
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
Wrap at the right granularity for your model:
# For transformers, wrap individual layers
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerEncoderLayer}
)

# For other models, wrap based on size
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=100_000_000
)
For multi-node training, HYBRID_SHARD reduces cross-node communication:
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD
)
Just like DDP, FSDP requires a DistributedSampler so each rank sees its own slice of the data:
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, sampler=sampler)

# Set epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)

Launching FSDP Training

When you train on the Tensor Cloud, we handle launching your training script with torchrun for you. If you’re training elsewhere, you’ll need to run torchrun yourself. Lei Mao’s guide to distributed training has hands-on information about using torchrun.