Train CIFAR10 with PyTorch

Adapted from pytorch-cifar to make it Horovod-compatible.

I'm playing with PyTorch on the CIFAR10 dataset.

Prerequisites

  • Python 3.6+
  • PyTorch 1.0+
  • Horovod with PyTorch support (see below for one way to install it)
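
If Horovod is not installed yet, one way to build it with the PyTorch extension enabled is sketched below; this assumes a working C++ compiler toolchain, and the official Horovod installation guide remains the authoritative reference:

# Build Horovod with its PyTorch extension enabled
HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]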

Training

# Start training with: 
python main.py

# You can manually resume the training with: 
python main.py --resume --lr=0.01

Accuracy

Model               Acc.
VGG16               92.64%
ResNet18            93.02%
ResNet50            93.62%
ResNet101           93.75%
RegNetX_200MF       94.24%
RegNetY_400MF       94.29%
MobileNetV2         94.43%
ResNeXt29(32x4d)    94.73%
ResNeXt29(2x64d)    94.82%
SimpleDLA           94.89%
DenseNet121         95.04%
PreActResNet18      95.11%
DPN92               95.16%
DLA                 95.47%

Multi GPU training using Horovod

Code changes

This tutorial will take you step by step through the changes required in the existing training code (main.py) to run it across multiple GPUs using Horovod.

The final script (main_horovod.py) with all the changes is included in the repository.

1. Add Horovod import

Add the following code after from utils import progress_bar:

import horovod.torch as hvd

(See line 16 of main_horovod.py.)

2. Initialize Horovod

Add the following code after args = parser.parse_args():

# Horovod: initialize Horovod.
hvd.init()

(See lines 30-31 of main_horovod.py.)
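
To quickly verify that the distributed launch works, you can temporarily print each process's rank right after hvd.init() (a throwaway sanity check, not part of main_horovod.py):

# Each process should print a distinct rank; size equals the number of processes
print('rank %d / size %d, local rank %d' % (hvd.rank(), hvd.size(), hvd.local_rank()))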

3. Pin GPU to be used by each process

With Horovod, one GPU is typically assigned to each process, which simplifies distributed training across processes.

Comment out or remove the following device/GPU allocation code:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

Add the following code after hvd.init():

# Get the local rank of the current process
device = hvd.local_rank()

# Pin the GPU used by this process to its local rank (one GPU per process)
torch.cuda.set_device(device)

(See lines 33-37 of main_horovod.py.)
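
This pinning scheme assumes each node exposes at least as many GPUs as it runs Horovod processes. A defensive check such as the following (an optional addition, not in the original script) makes that assumption explicit:

# Fail fast if more processes than visible GPUs were launched on this node
assert torch.cuda.is_available() and hvd.local_rank() < torch.cuda.device_count(), \
    'Launch at most one process per visible GPU on each node'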

4. Add a distributed sampler to partition data across processes

For distributed training, it is efficient to have each copy of the model (one per process) work with a mutually exclusive subsample of the training dataset.

For this reason, we add a DistributedSampler to sample the training examples. Notice that the sampler is passed as an argument to the DataLoader.

Replace the following lines:

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

by

# Partition dataset among workers using DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(
    trainset, shuffle=True, num_replicas=hvd.size(), rank=hvd.rank())
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, num_workers=2, sampler=train_sampler)

Similarly, we can distribute the evaluation load across processes during the validation phase.

Replace the following lines:

testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

by

# Partition val dataset among workers using DistributedSampler
val_sampler = torch.utils.data.distributed.DistributedSampler(
    testset, shuffle=True, num_replicas=hvd.size(), rank=hvd.rank())
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2, sampler=val_sampler)

(See lines 60-73 of main_horovod.py.)
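
One common addition when shuffling with a DistributedSampler (recommended by the PyTorch documentation, though not part of the original script) is to call set_epoch at the start of every epoch so that each epoch uses a different shuffling order:

# At the top of the per-epoch training loop
train_sampler.set_epoch(epoch)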

5. Read checkpoint only on the first worker

Instead of loading the checkpoint in every worker process, it is more efficient to load it in a single worker process (typically the root) and broadcast it to the others.

This is usually done in tandem with checkpoint writing (use a single process to store checkpoints as well).

Replace the following code:

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    ...

with:

if args.resume and hvd.rank() == 0:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    ...

(See line 102 of main_horovod.py.)
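
For context, the guarded block ends up looking roughly like the sketch below; the checkpoint path and dictionary keys are assumptions carried over from the upstream pytorch-cifar script rather than a verbatim excerpt of main_horovod.py:

if args.resume and hvd.rank() == 0:
    # Only the root process touches the checkpoint file
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('./checkpoint/ckpt.pth')  # assumed path
    net.load_state_dict(checkpoint['net'])            # assumed key names
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']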

6. Broadcast start epoch and model parameters from first worker to all processes

As mentioned in the previous section, the checkpoint state and model parameters are broadcast from the root process so that all processes start synchronized.

Add the following lines of code after the checkpoint-reading code:

start_epoch = hvd.broadcast(torch.Tensor(1).fill_(start_epoch)[0], name="start_epoch", root_rank=0)
start_epoch = int(start_epoch)

Also add:

hvd.broadcast_parameters(net.state_dict(), 0)

(See lines 111-114 of main_horovod.py.)
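
If other scalars restored from the checkpoint (for example best_acc) are needed on every process, they can be broadcast the same way; this mirrors the start_epoch pattern above and is an optional extension rather than part of main_horovod.py:

# Broadcast the best accuracy seen so far from the root process
best_acc = float(hvd.broadcast(torch.Tensor(1).fill_(best_acc)[0], name="best_acc", root_rank=0))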

7. Adjust learning rate and add Distributed Optimizer

Horovod uses an operation that averages gradients across workers. Gradient averaging typically requires a corresponding increase in learning rate to make bigger steps in the direction of a higher-quality gradient.

Replace optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) with:

optimizer = optim.SGD(net.parameters(), lr=args.lr * hvd.size(),
                      momentum=0.9, weight_decay=5e-4)

# Wrap the optimizer in Horovod's DistributedOptimizer
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=net.named_parameters())

(See lines 117-121 of main_horovod.py.)

8. Broadcast optimizer state from first worker to synchronize the optimizer across processes

Add the following line after scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200):

hvd.broadcast_optimizer_state(optimizer, 0)

(See line 125 of main_horovod.py.)

9. Aggregate losses and predictions across processes to calculate overall loss and accuracy

As mentioned in Section 4, every process works on its own subsample of the training set. We don't want every process to report its own training progress; instead, a single process should report the aggregate training metrics.

For reporting aggregate metrics, we need to combine them across all the processes.

Add the following lines after correct += predicted.eq(targets).sum().item():

train_loss_sum_across_batches_multiprocess = hvd.allreduce(torch.Tensor(1).fill_(train_loss)[0],
                                                           name="train_loss_sum_multiprocess", op=hvd.Sum)
total_sum_across_batches_multiprocess = hvd.allreduce(torch.Tensor(1).fill_(total)[0],
                                                      name="total_sum_multiprocess", op=hvd.Sum)
correct_sum_across_batches_multiprocess = hvd.allreduce(torch.Tensor(1).fill_(correct)[0],
                                                        name="correct_sum_multiprocess", op=hvd.Sum)

These lines sum up the per-process loss, the total number of examples, and the number of correct predictions, respectively, and store them in new variables that are later used for reporting.
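
For reference, hvd.allreduce averages tensors across processes by default, which is why op=hvd.Sum is passed explicitly here to accumulate raw counts instead. A minimal sketch of the two behaviours (assuming hvd.init() has already been called):

x = torch.Tensor(1).fill_(hvd.rank())
mean_x = hvd.allreduce(x, name='x_mean')            # average across all processes
sum_x = hvd.allreduce(x, name='x_sum', op=hvd.Sum)  # sum across all processes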

To enable reports from a single (root) process only, replace the following code:

progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

with

if hvd.local_rank() == 0:
    progress_bar(batch_idx, len(trainloader), '[Train] Average(all procs) Loss : %.3f | Average(all procs) Acc: %.3f%% (%d/%d)'
                 % (train_loss_sum_across_batches_multiprocess/((batch_idx+1) * hvd.size()),
                    100.*correct_sum_across_batches_multiprocess/total_sum_across_batches_multiprocess,
                    correct_sum_across_batches_multiprocess, total_sum_across_batches_multiprocess))

(See lines 148-159 of main_horovod.py.)

Note that similar aggregation is also performed for validation (see lines 179-190 inside the test function).

Run

Assuming the libraries listed under Prerequisites are installed in your Python environment:

horovodrun -np <num-gpus> python main_horovod.py <args>
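
For example, on a single machine with 4 GPUs (the GPU count and learning rate below are just placeholders):

horovodrun -np 4 python main_horovod.py --lr 0.1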
