-
Notifications
You must be signed in to change notification settings - Fork 155
/
data.py
86 lines (74 loc) · 2.69 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import zipfile
import pytorch_lightning as pl
import requests
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from tqdm import tqdm
class CIFAR10Data(pl.LightningDataModule):
    """Lightning data module that builds CIFAR-10 train/val/test dataloaders.

    ``args`` must expose ``data_dir``, ``batch_size`` and ``num_workers``
    attributes; they are read each time a dataloader is built.
    """

    def __init__(self, args):
        """Store the run configuration and the normalization statistics.

        Args:
            args: configuration object providing ``data_dir``, ``batch_size``
                and ``num_workers``.
        """
        super().__init__()
        # NOTE(review): some pytorch_lightning releases expose ``hparams`` as a
        # read-only property -- confirm this assignment against the pinned
        # version before upgrading.
        self.hparams = args
        # Per-channel mean/std handed to ``T.Normalize`` in every dataloader.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)

    @staticmethod
    def download_weights():
        """Download the state-dict archive and unpack it.

        The archive is written to ``state_dicts.zip`` in the current working
        directory and extracted into a ``cifar10_models`` directory next to it.
        The method takes no instance, so it can be called either as
        ``CIFAR10Data.download_weights()`` or on an instance.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            RuntimeError: if the number of bytes received differs from the
                size advertised in the ``Content-Length`` header.
        """
        url = (
            "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip"
        )
        path_to_zip_file = os.path.join(os.getcwd(), "state_dicts.zip")
        directory_to_extract_to = os.path.join(os.getcwd(), "cifar10_models")
        block_size = 2 ** 20  # read the body in 1 MiB chunks

        # Stream the body so it is never held in memory as a whole. The
        # ``with`` block releases the connection even if the download fails;
        # the timeout bounds connect time and the wait between chunks, not the
        # total transfer time.
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail early instead of saving an HTML error page as the archive.
            r.raise_for_status()
            # Content-Length is expressed in bytes; 0 means "not advertised".
            total_size = int(r.headers.get("content-length", 0))
            # The bar is advanced by byte counts, so the base unit is bytes and
            # tqdm scales it with binary (1024-based) prefixes.
            with open(path_to_zip_file, "wb") as f, tqdm(
                total=total_size or None,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as t:
                for data in r.iter_content(block_size):
                    t.update(len(data))
                    f.write(data)
                received = t.n

        # Only verifiable when the server advertised a size.
        if total_size != 0 and received != total_size:
            raise RuntimeError("Error, something went wrong")

        print("Download successful. Unzipping file...")
        with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
            zip_ref.extractall(directory_to_extract_to)
        print("Unzip file successful!")

    def train_dataloader(self):
        """Return a shuffled training loader with crop/flip augmentation."""
        transform = T.Compose(
            [
                T.RandomCrop(32, padding=4),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR10(root=self.hparams.data_dir, train=True, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def val_dataloader(self):
        """Return a loader over the ``train=False`` split without augmentation."""
        transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR10(root=self.hparams.data_dir, train=False, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            # NOTE(review): drop_last=True discards the final partial batch, so
            # evaluation skips some samples whenever the split size is not a
            # multiple of batch_size -- confirm this is intended.
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def test_dataloader(self):
        """Return the same loader configuration as ``val_dataloader``."""
        return self.val_dataloader()