Adding Repeated Augment Sampler (pytorch#5051)
* Adding repeated data-augmentation sampler
* Rebase on top of latest main
* Fix formatting
* Rename file
* Add code source
1 parent 47ae092 · commit e250db3
Showing 2 changed files with 66 additions and 1 deletion.
@@ -0,0 +1,60 @@
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed
    training, with repeated augmentation.
    It ensures that each augmented version of a sample is visible to a
    different process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.
    This is borrowed from the DeiT repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Each sample is repeated 3 times, then the repeats are split across replicas
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Per-epoch budget per replica: the dataset length rounded down to a
        # multiple of 256, divided across the replicas
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat each index 3 times, then add extra samples to make the list
        # evenly divisible by the number of replicas
        indices = [ele for ele in indices for _ in range(3)]
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample with stride num_replicas, so the consecutive copies of an
        # index land on different ranks
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
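To see how the striding distributes the repeats: with 3 replicas, a shuffled index list [2, 0, 1] is expanded to [2, 2, 2, 0, 0, 0, 1, 1, 1], and indices[rank::3] gives every rank the same list [2, 0, 1] -- the same underlying images, each augmented independently per process. The following sanity check of the three counts computed in __init__ is a sketch; the ImageNet-sized dataset length and the world size of 8 are assumptions for illustration, not part of this diff.

import math

# Illustrative sizes only (assumed): an ImageNet-scale train split on 8 GPUs
dataset_len, num_replicas = 1_281_167, 8

num_samples = int(math.ceil(dataset_len * 3.0 / num_replicas))  # 480438
total_size = num_samples * num_replicas  # 3843504
num_selected_samples = int(math.floor(dataset_len // 256 * 256 / num_replicas))  # 160128

# Each rank draws ~len(dataset) / num_replicas indices per epoch, so an epoch
# still covers roughly one dataset's worth of samples in aggregate, even though
# every kept sample appears (independently augmented) on 3 different ranks.
print(num_samples, total_size, num_selected_samples)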
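A minimal usage sketch (an assumption, not shown in this diff): RASampler drops into a training loop exactly where DistributedSampler normally goes. num_replicas and rank are passed explicitly here so the snippet runs without torch.distributed being initialized; under DDP you would omit them and let the sampler read the process group.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Tiny synthetic dataset, for illustration only
dataset = TensorDataset(torch.randn(512, 8), torch.randint(0, 10, (512,)))
sampler = RASampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(2):
    # Reseed the sampler so each epoch uses a fresh permutation;
    # without this call every epoch replays the same order
    sampler.set_epoch(epoch)
    for inputs, targets in loader:
        pass  # forward / backward / optimizer step would go here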