
What is Distributed Training?

Distributed training allows you to train deep learning models across multiple GPUs and machines, reducing training time. PyTorch provides two widely used strategies for distributed training:
  1. DDP (DistributedDataParallel) - Data parallelism for general models
  2. FSDP (FullyShardedDataParallel) - Sharded data parallelism for large models
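
Both strategies run on top of torch.distributed with one process per GPU, typically launched with torchrun. Below is a minimal sketch of that shared setup; the script name and GPU count are illustrative.

import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process it spawns
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

# Example launch on one machine with 4 GPUs:
#   torchrun --nproc_per_node=4 train.py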

When to Use Each Strategy

Use DDP When:

  • Your model fits in a single GPU’s memory
  • You want simple, efficient data parallelism, where each GPU processes a different slice of the batch

Use FSDP When:

  • Your model is too large for a single GPU
  • You’re training very large models (GPT, LLaMA, large vision models)
  • You want to maximize memory efficiency
  • You need to scale to billions of parameters

Quick Comparison

Feature              DDP                                FSDP
Model Size           Fits in a single GPU               Can exceed single GPU memory
Memory Efficiency    Model replicated on every GPU      Model sharded across GPUs
Speed                Fastest for small/medium models    Better for very large models
Gradient Sync        After the backward pass            During the backward pass

Basic Architecture

DDP Architecture

┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
│   GPU 0     │  │   GPU 1     │  │   GPU 2     │  │   GPU 3     │
│             │  │             │  │             │  │             │
│   Model     │  │   Model     │  │   Model     │  │   Model     │
│  (Copy 1)   │  │  (Copy 2)   │  │  (Copy 3)   │  │  (Copy 4)   │
│             │  │             │  │             │  │             │
│   Data 1    │  │   Data 2    │  │   Data 3    │  │   Data 4    │
└─────────────┘  └─────────────┘  └─────────────┘  └─────────────┘
       ↓                ↓                ↓                ↓
       └────────────────┴────────────────┴────────────────┘
                    Gradient Synchronization
Each GPU has:
  • Full copy of the model
  • Different slice of the data
  • Gradients that are averaged across all GPUs after the backward pass
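
As a concrete illustration, the sketch below wraps a toy model in DistributedDataParallel and uses a DistributedSampler so each GPU sees a different slice of the data. The model, dataset, and hyperparameters are placeholders, not part of any official example.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dist.init_process_group(backend="nccl")          # one process per GPU, launched with torchrun
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Toy model and dataset as placeholders
model = DDP(nn.Linear(128, 10).cuda(local_rank), device_ids=[local_rank])  # full model copy on each GPU
dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
sampler = DistributedSampler(dataset)            # each rank gets a different slice of the data
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(2):
    sampler.set_epoch(epoch)                     # reshuffle across ranks every epoch
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()                          # gradients are averaged across all GPUs here
        optimizer.step()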

FSDP Architecture

┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
│   GPU 0     │  │   GPU 1     │  │   GPU 2     │  │   GPU 3     │
│             │  │             │  │             │  │             │
│ Model Shard │  │ Model Shard │  │ Model Shard │  │ Model Shard │
│  (Shard 1)  │  │  (Shard 2)  │  │  (Shard 3)  │  │  (Shard 4)  │
│             │  │             │  │             │  │             │
│   Data 1    │  │   Data 2    │  │   Data 3    │  │   Data 4    │
└─────────────┘  └─────────────┘  └─────────────┘  └─────────────┘
       ↑                ↑                ↑                ↑
       └────────────────┴────────────────┴────────────────┘
              Gather params across GPUs during Forward/Backward
Each GPU has:
  • Shard of the model (only part of parameters)
  • Different slice of the data
  • Parameters gathered on demand from the other GPUs during the forward and backward passes; high-bandwidth GPU-to-GPU communication makes this far faster than offloading parameters to CPU RAM and reloading them
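
For comparison, here is a minimal sketch of wrapping a model with FullyShardedDataParallel. The layer sizes and the size-based auto-wrap threshold are illustrative choices, not prescriptions.

import functools
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group(backend="nccl")          # one process per GPU, launched with torchrun
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Placeholder model; in practice this is the model that does not fit on a single GPU
model = nn.Sequential(
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 10),
)

# Parameters are sharded across all ranks; large submodules become their own FSDP units
model = FSDP(
    model,
    device_id=local_rank,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
)

# Build the optimizer after wrapping, so it tracks the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 4096, device=f"cuda:{local_rank}")
loss = model(x).sum()     # parameters are all-gathered per FSDP unit for the forward pass
loss.backward()           # gathered again for backward, then gradients are reduce-scattered
optimizer.step()          # each rank updates only its own shard of the parameters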

What You’ll Learn