
What is DDP?

DistributedDataParallel (DDP) is PyTorch’s recommended module for multi-GPU training. It works as follows:
  1. Replicates the model on each GPU
  2. Feeds each process a different shard of the batch (e.g., via DistributedSampler; see the sketch after this list)
  3. Computes gradients independently on each GPU
  4. Averages gradients across all GPUs with an all-reduce during backward
  5. Updates model parameters synchronously, keeping every replica identical
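
Step 2 is typically handled by DistributedSampler, which gives each process a disjoint slice of the dataset. A minimal sketch (the dataset name and batch size are placeholders; the process group must already be initialized):

from torch.utils.data import DataLoader, DistributedSampler

# Each process gets a different, non-overlapping shard of the dataset
train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=128,         # per-process batch size
    sampler=train_sampler,  # replaces shuffle=True on the DataLoader
)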

When to Use DDP

Use DDP for:
  • Models that fit in a single GPU’s memory
Don’t use DDP for:
  • Models too large for a single GPU (use FSDP instead)

Basic Setup with SF Tensor

Here’s minimal code to use DDP with the convenience functions provided by the Python library:
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import sf_tensor as sft

if __name__ == "__main__":
    # 1. Initialize distributed training
    sft.initialize_distributed_training()

    # 2. Get device for this process
    device = sft.get_device()

    # 3. Create model and move to device
    model = YourModel()
    model = model.to(device)

    # 4. Wrap model in DDP
    model = DDP(model, device_ids=[device])

    # 5. Train as usual
    for epoch in range(num_epochs):
        for batch in train_loader:
            loss = train_step(model, batch, device)
            sft.log(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Complete Training Example

Lei Mao has a good full example of distributed training. Note that, when using the Tensor Cloud, we run the torchrun commands for you automatically, and you can use our convenience functions if you wish.
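
If you launch training yourself with torchrun instead, a typical setup with plain PyTorch (standard torch.distributed calls rather than the sf_tensor helpers) looks roughly like this:

import os
import torch
import torch.distributed as dist

# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)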

Key DDP Parameters

When wrapping your model in DDP, you can configure several parameters:
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(
    model,
    device_ids=[device],           # GPU to use for this process
    output_device=device,          # Where to place output (default: device_ids[0])
    broadcast_buffers=True,        # Sync buffers (e.g., BatchNorm) at forward
    find_unused_parameters=False,  # Set True if some params don't receive gradients
    gradient_as_bucket_view=False  # Memory optimization (experimental)
)

Important Parameters

device_ids

Specifies which GPU(s) to use for this process. Usually a single device obtained from sft.get_device().
device = sft.get_device()
model = DDP(model, device_ids=[device])

broadcast_buffers

Controls whether to synchronize buffers (like BatchNorm running statistics) during the forward pass.
  • True (default): Buffers synced at every forward pass
  • False: Buffers not synced (faster, but may affect BatchNorm)
# Keep default for most cases
model = DDP(model, device_ids=[device], broadcast_buffers=True)

find_unused_parameters

Set to True if some model parameters don’t receive gradients (e.g., in conditional architectures).
  • False (default): Assumes all params get gradients
  • True: Handles unused params (slower)
# Use for models with conditional paths
model = DDP(model, device_ids=[device], find_unused_parameters=True)
Only set find_unused_parameters=True if you actually have unused parameters. It slows down training.

Saving and Loading Models

When using DDP, the model is wrapped, so you need to access the underlying module:

Saving

import torch
import torch.distributed as dist

# Save only on rank 0 to avoid multiple writes
if not dist.is_initialized() or dist.get_rank() == 0:
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.module.state_dict(),  # Note: model.module
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, 'checkpoint.pth')
Use model.module.state_dict() instead of model.state_dict() to save the underlying model without DDP wrapper.

Loading

import torch

# Load checkpoint
checkpoint = torch.load('checkpoint.pth', map_location=device)

# Create model and wrap in DDP
model = YourModel().to(device)
model = DDP(model, device_ids=[device])

# Load state dict into the underlying module
model.module.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
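
If other ranks need to read a checkpoint that rank 0 just wrote (for example, right after saving), a barrier ensures the file is complete before any process loads it:

import torch.distributed as dist

# Wait until rank 0 has finished writing before any rank reads the file
if dist.is_initialized():
    dist.barrier()
checkpoint = torch.load('checkpoint.pth', map_location=device)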

Performance Tips

1. Use Gradient Accumulation

For larger effective batch sizes without increasing memory:
accumulation_steps = 4

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)

    output = model(data)
    loss = criterion(output, target)
    loss = loss / accumulation_steps  # Scale loss

    loss.backward()  # Accumulate gradients

    if (batch_idx + 1) % accumulation_steps == 0:
        optimizer.step()  # Update every N steps
        optimizer.zero_grad()
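
Each backward() under DDP also triggers a gradient all-reduce. During accumulation-only steps you can skip that communication with DDP's no_sync() context manager and synchronize only on the step that updates the weights. A sketch combining the two:

from contextlib import nullcontext

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)

    is_update_step = (batch_idx + 1) % accumulation_steps == 0

    # Skip the gradient all-reduce on accumulation-only steps
    sync_context = nullcontext() if is_update_step else model.no_sync()
    with sync_context:
        loss = criterion(model(data), target) / accumulation_steps
        loss.backward()

    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()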

2. Use Mixed Precision Training

Reduce memory usage and speed up training:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for data, target in train_loader:
    data, target = data.to(device), target.to(device)

    optimizer.zero_grad()

    # Forward with autocast
    with autocast():
        output = model(data)
        loss = criterion(output, target)

    # Scaled backward
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
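
The torch.cuda.amp spellings above still work but are being migrated to torch.amp in recent PyTorch releases; depending on your version, the equivalent is roughly:

import torch

# Newer-style AMP API (PyTorch 2.x); behaves like the example above
scaler = torch.amp.GradScaler("cuda")
with torch.amp.autocast(device_type="cuda"):
    output = model(data)
    loss = criterion(output, target)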

3. Pin Memory and Non-Blocking Transfer

# DataLoader with pin_memory
train_loader = DataLoader(
    dataset,
    batch_size=128,
    sampler=train_sampler,
    num_workers=4,
    pin_memory=True  # Faster CPU-to-GPU transfer
)

# Non-blocking transfer in training loop
for data, target in train_loader:
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)

Common Issues and Solutions

Parameters not receiving gradients

Cause: DDP expects all parameters to receive gradients in every backward pass.
Solution: Set find_unused_parameters=True if some parameters don’t always get gradients:
model = DDP(model, device_ids=[device], find_unused_parameters=True)

Training hangs

Cause: Not calling backward() on all processes, or using different computational graphs on different processes.
Solution: Ensure all processes call backward() with the same model structure:
# All processes must execute
loss = criterion(output, target)
loss.backward()

Inconsistent results across processes

Cause: Different random seeds, or not using DistributedSampler.
Solution: Set seeds and use DistributedSampler:
import random
import numpy as np

# Set seeds
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Use DistributedSampler
sampler = DistributedSampler(dataset, shuffle=True, seed=42)
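
When shuffling with DistributedSampler, also call set_epoch() at the start of each epoch so every epoch reshuffles the data consistently across all processes:

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle identically on every process
    for data, target in train_loader:
        ...  # training step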

BatchNorm statistics differ across GPUs

Cause: BatchNorm statistics are not synchronized across GPUs.
Solution: Use SyncBatchNorm for synchronized batch normalization:
from torch.nn import SyncBatchNorm

# Convert all BatchNorm layers to SyncBatchNorm
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(device)
model = DDP(model, device_ids=[device])
Prefer DistributedDataParallel (DDP) over nn.DataParallel for multi‑GPU training — PyTorch recommends DDP and it’s significantly faster and more scalable. (nn.DataParallel remains available but is not recommended.)