Skip to content

Commit

Permalink
Add rotated MNIST dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
santiag0m committed Apr 14, 2022
1 parent 4e2c5c6 commit a2d8631
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__pycache__/
.vscode/
*.pth
scripts/
scripts/
data/
1 change: 1 addition & 0 deletions lib/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.utils.data import random_split, Dataset

from .synthetic import SyntheticDataset
from .mnist import MNIST


def train_val_split(dataset: Dataset, val_percentage: float) -> Tuple[Dataset]:
Expand Down
20 changes: 20 additions & 0 deletions lib/datasets/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torchvision import transforms
from torchvision.datasets import mnist


class MNIST(mnist.MNIST):
def __init__(
self,
rotated: bool = False,
root: str = "./data",
download: bool = True,
*args,
**kwargs
):
transform = [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
if rotated:
transform.append(transforms.RandomRotation(degrees=45))
transform = transforms.Compose(transform)
super().__init__(
root=root, transform=transform, download=download, *args, **kwargs
)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ tqdm==4.63.0
matplotlib==3.5.1
seaborn==0.11.2
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cu102
torch==1.9.0+cu102
torchvision==0.10.0+cu102

0 comments on commit a2d8631

Please sign in to comment.