In part I of this series, we saw how we could

  • launch different processes with mp.spawn or through multiple terminal windows
  • group processes in process groups for communication and synchronization
  • implement a distributed training loop by manually synchronizing gradients or by hooking them into the backward pass

In this part, we will follow up on this and implement our own simplified version of DistributedDataParallel (DDP) along with other helpers for distributed training – in particular, distributed sampling of the data. Initially, I also meant to include a custom SyncBatchNorm implementation; however, this requires knowing how to extend PyTorch’s Autograd engine, which I would first like to cover separately. The helpers and our DDP implementation will allow us to structure our code better and abstract away some of the boilerplate that we previously had to write inside the training loop. We will then verify the correctness of our implementation by training a ResNet18 [1] on the FashionMNIST [2] dataset and comparing the results to training it on a single GPU.

Our own DDP

You can find the documentation of the original class here: DistributedDataParallel. We will implement a significantly simplified version of it, but the main ideas will be the same.

We can start by copying over our context manager from the previous part that sets up and destroys the process group.

import os
import contextlib

import torch.distributed as dist

MASTER_ADDR = "localhost"
MASTER_PORT = "12355"
DIST_BACKEND = "nccl"

@contextlib.contextmanager
def setup_dist(rank: int, world_size: int):
    try:
        os.environ['MASTER_ADDR'] = MASTER_ADDR
        os.environ['MASTER_PORT'] = MASTER_PORT
        dist.init_process_group(DIST_BACKEND, rank=rank, world_size=world_size)
        yield
    finally:
        dist.destroy_process_group()

Since we plan on actually training a model this time, we will use the "nccl" backend, which is optimized for GPUs. If you don’t have multiple GPUs available, you can also use the "gloo" backend like we did in the previous part. However, this will require you to modify some of the calls to all_reduce since the "gloo" backend does not support the ReduceOp.AVG operation.
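
What that modification could look like is sketched below. Note that all_reduce_avg is a hypothetical helper of my own (not part of PyTorch) that emulates ReduceOp.AVG by summing across all ranks and then dividing by the world size:

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def all_reduce_avg(tensor: torch.Tensor) -> None:
    """Emulate ReduceOp.AVG on backends that only support SUM (e.g. "gloo")."""
    # Sum the tensor across all ranks in place, then divide by the number of ranks.
    dist.all_reduce(tensor, op=ReduceOp.SUM)
    tensor /= dist.get_world_size()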

Next, we can start implementing our DistributedDataParallel class. This class will simply wrap an existing nn.Module and take care of synchronizing the gradients and ensuring that all models (on all ranks) start from the same weights. Other than that, we want this wrapper to be really lightweight and behave as closely as possible to the original module, which is why we simply propagate the forward, state_dict, and load_state_dict methods of the wrapped module. Since we are using the "nccl" backend, we can use the ReduceOp.AVG operation. This makes the code a bit simpler, since we do not need to explicitly pass the world size to the gradient averaging hook – it simply averages the gradients across all processes. If you want to use the "gloo" backend, you can use functools.partial to pass the world size to the hook through the CustomDDP constructor – the previous post already showed how to do this.

from torch.nn import Module
from torch import Tensor
import torch.distributed as dist
from torch.distributed import ReduceOp


class CustomDDP(Module):
    def __init__(self, module: Module):
        super().__init__()
        self.module = module
        for param in self.module.parameters():
            # Average this parameter's gradient across all ranks as soon as it
            # has been accumulated during the backward pass.
            param.register_post_accumulate_grad_hook(grad_avg_hook)
            # Ensure every rank starts from the same weights as rank 0.
            dist.broadcast(param.data, src=0)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """Make the state_dict compatible with the unwrapped module."""
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.module.load_state_dict(*args, **kwargs)


def grad_avg_hook(param: Tensor) -> None:
    """Average the gradient of `param` across all ranks (in place)."""
    dist.all_reduce(param.grad, op=ReduceOp.AVG)
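
To give an idea of how the wrapper is used, here is a condensed sketch (the full training script follows further below; rank is assumed to be the local GPU index inside a spawned worker whose process group has already been initialized):

# Sketch: inside a spawned worker, after setup_dist(rank, world_size) has
# initialized the process group and `rank` is the local GPU index.
import torch
from torchvision import models

model = CustomDDP(models.resnet18(num_classes=10).to(rank))
outputs = model(torch.randn(8, 3, 224, 224, device=rank))  # forwarded to the wrapped ResNet
state = model.state_dict()  # same keys as the unwrapped model, so checkpoints stay compatible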

Distributing the Samples across the Ranks

The next thing we have to take care of is the distributed sampling of the dataset. If we just use our regular dataloader on all ranks, each rank will end up processing all the data, which is not what we want – we want each rank to process only a subset of the data, since that is what actually lets us get through an epoch faster.

PyTorch’s Sampler interface is documented here and requires each sampler to have at least an __iter__ method that returns an iterator over the indices of the samples, optionally also implementing __len__. So, while the usual sampler for a single GPU would simply return an iterator over [0, 1, 2, ..., n-1], our distributed sampler will return an iterator over disjoint subsets of the dataset for each rank.
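
To make the interface concrete, here is a minimal, non-distributed sampler that just yields every index in order (the class name is mine, purely for illustration):

from typing import Iterator

from torch.utils.data import Dataset


class SimpleSequentialSampler:
    """Minimal example of the sampler interface for a single process."""

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __iter__(self) -> Iterator[int]:
        # Yield the indices 0, 1, ..., n-1 in order.
        return iter(range(len(self.dataset)))

    def __len__(self) -> int:
        return len(self.dataset)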

When we don’t want to shuffle the dataset, i.e. during the validation loop, we can simply generate these indices as indices = range(len(self.dataset))[rank::world_size]. However, when we want to shuffle the dataset, we need to make sure that we shuffle it in exactly the same way across all ranks. There are several ways of solving this:

  1. Similar to how we handled broadcasting of the model weights: randomly shuffle the indices on rank 0 and broadcast them to the other ranks. The issue with this is that broadcasting with the "nccl" backend requires the index tensor to be allocated on the GPU first, taking away precious GPU memory that we actually want to use for the model optimization. It is generally discouraged to use the GPU during the data loading phase, especially for something that needs no computation and only consumes memory.
  2. We seed a pseudo-random number generator on every rank with a value derived from a fixed seed and the current epoch number. This ensures that all ranks produce exactly the same shuffled indices, while the shuffling still changes from epoch to epoch.

Because of the above-mentioned limitations of the first approach, I will show how to implement the second one:

from typing import Iterator

import torch
from torch import Generator
import torch.distributed as dist
from torch.utils.data import Dataset


class CustomDistSampler:
    def __init__(self, dataset: Dataset, shuffle: bool = True, seed: int = 0):
        self.dataset = dataset
        self.shuffle = shuffle
        self.num_samples = len(self.dataset) // dist.get_world_size()
        self.seed = seed
        self.epoch = 0

    def __iter__(self) -> Iterator[int]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        if not self.shuffle:
            indices = range(len(self.dataset))[rank::world_size]
            yield from iter(indices)
        else:
            gen = Generator()
            gen.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=gen)[rank::world_size]
            yield from iter(indices.tolist())

    def __len__(self):
        return self.num_samples

Relatively straightforward, right? We just need to make sure that we increment the sampler’s epoch attribute at the end of each epoch in order to get different shufflings in each epoch.
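
Condensed, the usage pattern looks like this (a sketch only – train_set and num_epochs stand in for your dataset and epoch count; the full training script in the next section follows the same structure):

from torch.utils.data import DataLoader

train_sampler = CustomDistSampler(train_set, shuffle=True)
train_loader = DataLoader(train_set, batch_size=64, sampler=train_sampler)

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        ...  # forward pass, backward pass, optimizer step
    train_sampler.epoch += 1  # different shuffling in the next epoch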

Training with our Custom DDP

We will demonstrate the correctness of our implementation by training a ResNet18 on the FashionMNIST dataset. Compared to a normal PyTorch training loop there is not much that changes:

  • We wrap our model with CustomDDP after moving it to the correct GPU.
  • As mentioned earlier, we keep PyTorch’s SyncBatchNorm implementation and don’t implement our own version of it in this post. Any PyTorch model can simply be wrapped with SyncBatchNorm.convert_sync_batchnorm in order to convert all BatchNorm layers to SyncBatchNorm layers. This ensures that the batch statistics are calculated over the whole (global) batch instead of just the local sub-batch on each rank.
  • Instead of specifying shuffle=True or shuffle=False in the dataloader, we pass our custom sampler through the sampler argument.
  • Before printing the losses we also make sure to average them across all ranks, so that we can directly compare the losses printed by rank 0 to the losses of the single-GPU training.
  • Finally, we need to make sure that we increment the epoch attribute of the samplers at the end of each epoch, so that we get a different shuffling in every epoch.
# custom_ddp.py
import os
import contextlib
from argparse import ArgumentParser
from typing import Iterator
import time

import torch
from torch import Tensor, Generator
from torch.nn import Module, SyncBatchNorm
import torch.distributed as dist
from torch.distributed import ReduceOp
import torch.multiprocessing as mp
from torch.optim import SGD
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchvision.transforms import ToTensor, Grayscale, Compose
from torchvision import models


MASTER_ADDR = "localhost"
MASTER_PORT = "12355"
DIST_BACKEND = "nccl"


@contextlib.contextmanager
def setup_dist(rank: int, world_size: int):
    try:
        os.environ["MASTER_ADDR"] = MASTER_ADDR
        os.environ["MASTER_PORT"] = MASTER_PORT
        dist.init_process_group(DIST_BACKEND, rank=rank, world_size=world_size)
        yield
    finally:
        dist.destroy_process_group()


class CustomDDP(Module):
    def __init__(self, module: Module):
        super().__init__()
        self.module = module
        for param in self.module.parameters():
            param.register_post_accumulate_grad_hook(grad_avg_hook)
            dist.broadcast(param.data, src=0)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.module.load_state_dict(*args, **kwargs)


def grad_avg_hook(param: Tensor) -> None:
    dist.all_reduce(param.grad, op=ReduceOp.AVG)


class CustomDistSampler:
    def __init__(self, dataset: Dataset, shuffle: bool = True, seed: int = 0):
        self.dataset = dataset
        self.shuffle = shuffle
        self.num_samples = len(self.dataset) // dist.get_world_size()
        self.seed = seed
        self.epoch = 0

    def __iter__(self) -> Iterator[int]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        if not self.shuffle:
            indices = range(len(self.dataset))[rank::world_size]
            yield from iter(indices)
        else:
            gen = Generator()
            gen.manual_seed(self.seed + self.epoch)
            indices_ = torch.randperm(len(self.dataset), generator=gen)[
                rank::world_size
            ]
            yield from iter(indices_.tolist())

    def __len__(self):
        return self.num_samples


def train_dist(
    rank: int, world_size: int, num_epochs: int, batch_size: int, num_workers: int
) -> None:
    # Setup the process group.
    with setup_dist(rank, world_size):
        # Model initialization and training loop.
        model = CustomDDP(
            SyncBatchNorm.convert_sync_batchnorm(
                models.resnet18(num_classes=10).to(rank)
            )
        )
        learning_rate = 0.001
        optimizer = SGD(model.parameters(), lr=learning_rate)

        # Load FashionMNIST dataset.
        train_set = FashionMNIST(
            root="./data",
            train=True,
            download=True,
            transform=Compose([Grayscale(num_output_channels=3), ToTensor()]),
        )
        train_sampler = CustomDistSampler(train_set, shuffle=True)
        train_loader = DataLoader(
            train_set,
            batch_size=batch_size // world_size,
            sampler=train_sampler,
            num_workers=num_workers,
            pin_memory=True,
        )

        val_set = FashionMNIST(
            root="./data",
            train=False,
            download=True,
            transform=Compose([Grayscale(num_output_channels=3), ToTensor()]),
        )
        val_sampler = CustomDistSampler(val_set, shuffle=False)
        val_loader = DataLoader(
            val_set,
            batch_size=batch_size // world_size,
            sampler=val_sampler,
            num_workers=num_workers,
            pin_memory=True,
        )

        for epoch in range(num_epochs):
            epoch_train_losses = []
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(rank), targets.to(rank)

                # Forward pass.
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)

                # Backward pass and optimization.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_train_losses += [loss.item()]

            epoch_val_losses = []
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(rank), targets.to(rank)

                    # Forward pass.
                    outputs = model(inputs)
                    loss = F.cross_entropy(outputs, targets)
                    epoch_val_losses += [loss.item()]

            epoch_train_loss = torch.tensor(
                sum(epoch_train_losses) / len(epoch_train_losses), device=rank
            )
            epoch_val_loss = torch.tensor(
                sum(epoch_val_losses) / len(epoch_val_losses), device=rank
            )
            dist.all_reduce(epoch_train_loss, op=ReduceOp.AVG)
            dist.all_reduce(epoch_val_loss, op=ReduceOp.AVG)

            train_sampler.epoch += 1
            val_sampler.epoch += 1

            # Print loss for the current iteration.
            if rank == 0:
                print(
                    f"Iteration {epoch + 1}/{num_epochs}\t"
                    f"Train Loss: {epoch_train_loss:.4f}\t"
                    f"Val Loss: {epoch_val_loss:.4f}"
                )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--num-workers", type=int, default=4)
    args = parser.parse_args()

    start_time = time.time()
    mp.spawn(
        train_dist,
        args=(args.world_size, args.num_epochs, args.batch_size, args.num_workers),
        nprocs=args.world_size,
    )
    end_time = time.time()

    print("Completed training in {:.2f} seconds.".format(end_time - start_time))

Verifying the Correctness

Making the training runs on a single GPU and on multiple GPUs match exactly would take a bit more engineering effort, since we would need to make sure that all sources of randomness are controlled. However, we can check how the training curves behave when we train on a single GPU versus multiple GPUs. For this, I ran the training loop from above on 2 and 4 GPUs and created a “standard” single-GPU training loop that does not use any of the components required for distributed training. The outputs below show that the behaviour is consistent across the runs. And if you don’t believe me, you can also download the script below and run it yourself. 😊
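
For reference, the distributed runs can be launched with a command along the lines of python custom_ddp.py --world-size 2 --num-epochs 10; the remaining flags, --batch-size and --num-workers, fall back to their defaults of 128 and 4.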

Running distributed training with 2 GPUs...
Iteration 1/10  Train Loss: 0.5944      Val Loss: 0.4728
Iteration 2/10  Train Loss: 0.3944      Val Loss: 0.4090
Iteration 3/10  Train Loss: 0.3384      Val Loss: 0.3770
Iteration 4/10  Train Loss: 0.3009      Val Loss: 0.3555
Iteration 5/10  Train Loss: 0.2744      Val Loss: 0.3463
Iteration 6/10  Train Loss: 0.2497      Val Loss: 0.3527
Iteration 7/10  Train Loss: 0.2325      Val Loss: 0.3287
Iteration 8/10  Train Loss: 0.2161      Val Loss: 0.3372
Iteration 9/10  Train Loss: 0.2022      Val Loss: 0.3346
Iteration 10/10 Train Loss: 0.1876      Val Loss: 0.3350

Running distributed training with 4 GPUs...
Iteration 1/10  Train Loss: 0.5937      Val Loss: 0.4382
Iteration 2/10  Train Loss: 0.3910      Val Loss: 0.3805
Iteration 3/10  Train Loss: 0.3356      Val Loss: 0.3603
Iteration 4/10  Train Loss: 0.3019      Val Loss: 0.3261
Iteration 5/10  Train Loss: 0.2735      Val Loss: 0.3098
Iteration 6/10  Train Loss: 0.2505      Val Loss: 0.3067
Iteration 7/10  Train Loss: 0.2317      Val Loss: 0.3118
Iteration 8/10  Train Loss: 0.2144      Val Loss: 0.2965
Iteration 9/10  Train Loss: 0.2006      Val Loss: 0.2921
Iteration 10/10 Train Loss: 0.1860      Val Loss: 0.2902

Running single-GPU training...
Iteration 1/10  Train Loss: 0.5893      Val Loss: 0.4605
Iteration 2/10  Train Loss: 0.3940      Val Loss: 0.3993
Iteration 3/10  Train Loss: 0.3376      Val Loss: 0.3635
Iteration 4/10  Train Loss: 0.3005      Val Loss: 0.3473
Iteration 5/10  Train Loss: 0.2735      Val Loss: 0.3341
Iteration 6/10  Train Loss: 0.2520      Val Loss: 0.3263
Iteration 7/10  Train Loss: 0.2330      Val Loss: 0.3237
Iteration 8/10  Train Loss: 0.2176      Val Loss: 0.3175
Iteration 9/10  Train Loss: 0.2005      Val Loss: 0.3121
Iteration 10/10 Train Loss: 0.1881      Val Loss: 0.3189

  1. He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016.

  2. Xiao, Han, Kashif Rasul, and Roland Vollgraf. “Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms.” arXiv preprint arXiv:1708.07747 (2017).