Overview

In distributed training, multiple processes run simultaneously. When downloading datasets, you need to ensure the download happens only once per node to avoid:
  • Race conditions and file corruption
  • Wasted bandwidth from duplicate downloads
  • Slower initialization times
SF Tensor provides the @dataDownload decorator to solve this problem.

The @dataDownload Decorator

The @dataDownload decorator ensures that data-loading operations execute only on the primary process of each node (LOCAL_RANK=0). All other processes skip the function entirely, which is why you should pair the call with a barrier before accessing the downloaded data (see Synchronization with Barrier below).
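
To make this concrete, the sketch below shows the general rank-guard pattern that such a decorator typically implements. It is illustrative only (the actual sf_tensor implementation may differ) and assumes the LOCAL_RANK environment variable is set by your launcher, e.g. torchrun:
import functools
import os

def local_rank_zero_only(fn):
    """Illustrative sketch: run fn only on each node's primary process (LOCAL_RANK=0)."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Launchers such as torchrun export LOCAL_RANK for every process
        if int(os.environ.get("LOCAL_RANK", "0")) == 0:
            return fn(*args, **kwargs)
        # All other processes skip the call and return None
        return None
    return wrapper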

Basic Usage

from sf_tensor.persist import dataDownload
import torchvision.datasets as datasets
import torchvision.transforms as transforms

@dataDownload
def download_dataset():
    """Download CIFAR-10 dataset"""
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # This download only happens once per node
    train_dataset = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )

    test_dataset = datasets.CIFAR10(
        root='./data',
        train=False,
        download=True,
        transform=transform
    )

    print("Dataset downloaded successfully!")

# Call the decorated function
download_dataset()

Synchronization with Barrier

After downloading data, use a barrier to ensure all processes wait before accessing the data:
import torch.distributed as dist

# Download on rank 0 only
@dataDownload
def download_data():
    # Download logic here
    pass

download_data()

# Wait for download to complete on all processes
if dist.is_initialized():
    dist.barrier()

# Now all processes can safely access the data
Always use a barrier after @dataDownload: Without a barrier, non-rank-0 processes might try to load data before the download completes, causing errors.

DistributedSampler Setup

After loading data, use DistributedSampler to partition it across GPUs:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist

# Create dataset (after download and barrier)
dataset = YourDataset()

# Create sampler for distributed training
sampler = DistributedSampler(
    dataset,
    num_replicas=dist.get_world_size() if dist.is_initialized() else 1,
    rank=dist.get_rank() if dist.is_initialized() else 0,
    shuffle=True,
    seed=42  # For reproducibility
)

# Create data loader
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,  # Use sampler instead of shuffle
    num_workers=4,
    pin_memory=True,
    drop_last=True  # Recommended for distributed training
)

# In your training loop, set epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # Ensures a different, deterministic shuffle each epoch

    for batch in dataloader:
        # Training step...
        pass
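
To confirm that the data is actually being partitioned rather than duplicated, you can add an optional, purely illustrative check that each rank sees roughly len(dataset) / world_size samples:
# Optional sanity check: each rank holds about len(dataset) / world_size samples
rank = dist.get_rank() if dist.is_initialized() else 0
print(f"rank {rank}: {len(sampler)} of {len(dataset)} samples")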

Best Practices

Always call dist.barrier() after the @dataDownload-decorated function to ensure all processes wait for the download to complete:
download_data()
if dist.is_initialized():
    dist.barrier()
Always use DistributedSampler to partition data across GPUs. Each process should see different data:
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, sampler=sampler)
Call sampler.set_epoch(epoch) at the start of each epoch to ensure proper shuffling:
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    for batch in train_loader:
        # training...
        pass
After the initial download, set download=False when loading datasets to avoid re-checking:
# First call (with decorator)
@dataDownload
def download():
    datasets.CIFAR10(root='./data', download=True)

# Subsequent loads (all processes)
dataset = datasets.CIFAR10(root='./data', download=False)
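
Putting these practices together, a typical per-node download flow looks like the following sketch (the init_process_group call is included for completeness and assumes a launcher such as torchrun has already set the usual environment variables):
import torch.distributed as dist
import torchvision.datasets as datasets
from sf_tensor.persist import dataDownload

dist.init_process_group(backend="nccl")  # usually done once at startup

@dataDownload
def download():
    datasets.CIFAR10(root='./data', download=True)

download()  # executes only on LOCAL_RANK=0 of each node
if dist.is_initialized():
    dist.barrier()  # every process waits until the files exist

dataset = datasets.CIFAR10(root='./data', download=False)  # now safe on all ranks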

Common Pitfalls

Forgetting the Barrier: If you don’t call dist.barrier() after downloading, non-rank-0 processes might try to load data before it’s ready, causing FileNotFoundError.
Not Using DistributedSampler: Without DistributedSampler, all processes will see the same data, effectively wasting compute resources and not achieving true data parallelism.
Using shuffle=True with DistributedSampler: Don’t use shuffle=True in DataLoader when using DistributedSampler. The sampler handles shuffling:
# Wrong: raises ValueError (sampler is mutually exclusive with shuffle)
DataLoader(dataset, sampler=sampler, shuffle=True)

# Correct: shuffling is handled by the sampler
DataLoader(dataset, sampler=sampler)
Batch Size Scales with Number of GPUs: When using distributed training, your effective batch size is batch_size × num_gpus. If you set batch_size=32 in your DataLoader and train on 8 GPUs, your effective global batch size is 256. This affects:
  • Learning rate scaling: You typically need to scale your learning rate proportionally (e.g., if batch size increases 8×, scale LR by 8×)
  • Memory usage: Each GPU processes its local batch size (32 in this example)
  • Convergence behavior: Larger effective batch sizes change training dynamics
# If training on 8 GPUs with batch_size=32
# Effective global batch size = 32 × 8 = 256

dataloader = DataLoader(
    dataset,
    batch_size=32,  # Per-GPU batch size
    sampler=sampler
)

# Adjust learning rate accordingly
base_lr = 0.1
num_gpus = dist.get_world_size() if dist.is_initialized() else 1
lr = base_lr * num_gpus  # Linear scaling rule