Skip to content

Commit

Permalink
Adding Repeated Augment Sampler (pytorch#5051)
Browse files Browse the repository at this point in the history
* Adding repaeted data-augument sampler

* rebase on top of latest main

* fix formatting

* rename file

* adding coode source
  • Loading branch information
yiwen-song authored Dec 8, 2021
1 parent 47ae092 commit e250db3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
60 changes: 60 additions & 0 deletions references/classification/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that different each augmented version of a sample will be visible to a
different process (GPU).
Heavily based on 'torch.utils.data.DistributedSampler'.
This is borrowed from the DeiT Repo:
https://github.com/facebookresearch/deit/blob/main/samplers.py
"""

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available!")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available!")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle

def __iter__(self):
# Deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))

# Add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(3)]
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size

# Subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices[: self.num_selected_samples])

def __len__(self):
return self.num_selected_samples

def set_epoch(self, epoch):
self.epoch = epoch
7 changes: 6 additions & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torchvision
import transforms
import utils
from references.classification.sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
Expand Down Expand Up @@ -172,7 +173,10 @@ def load_data(traindir, valdir, args):

print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
if args.ra_sampler:
train_sampler = RASampler(dataset, shuffle=True)
else:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
Expand Down Expand Up @@ -481,6 +485,7 @@ def get_args_parser(add_help=True):
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
Expand Down

0 comments on commit e250db3

Please sign in to comment.