-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_module.py
101 lines (89 loc) · 3.45 KB
/
data_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from . import image_dataset, image_dataset_cbis, image_dataset_cmmd
from . import pretraining_dataset
from .. import builder
class PretrainingDataModule(pl.LightningDataModule):
    """Lightning data module for the multimodal pretraining dataset.

    Every split is backed by ``pretraining_dataset.MultimodalPretrainingDataset``
    and batched with ``pretraining_dataset.multimodal_collate_fn``.  Batch size
    and worker count are read from ``cfg.train`` for all three splits,
    including validation and test.
    """

    def __init__(self, cfg):
        """Store the config and the dataset class / collate function to use."""
        super().__init__()
        self.cfg = cfg
        self.dataset = pretraining_dataset.MultimodalPretrainingDataset
        self.collate_fn = pretraining_dataset.multimodal_collate_fn

    def _make_loader(self, split, **loader_options):
        """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

        ``loader_options`` carries the per-split ``DataLoader`` flags
        (``shuffle`` and ``drop_last``); everything else is shared.
        """
        split_transform = builder.build_transformation(self.cfg, split)
        split_data = self.dataset(
            self.cfg, split=split, transform=split_transform
        )
        train_cfg = self.cfg.train
        return DataLoader(
            split_data,
            pin_memory=True,
            collate_fn=self.collate_fn,
            batch_size=train_cfg.batch_size,
            num_workers=train_cfg.num_workers,
            **loader_options,
        )

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete last batch is dropped."""
        return self._make_loader("train", shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader, in order; incomplete last batch is dropped."""
        return self._make_loader("valid", shuffle=False, drop_last=True)

    def test_dataloader(self):
        """Return the test loader, in order, keeping every sample."""
        return self._make_loader("test", shuffle=False, drop_last=False)
class INBDataModule(pl.LightningDataModule):
    """Lightning data module for the image datasets selected by ``cfg.dataset``.

    ``cfg.dataset`` chooses which module supplies the ``INBImageDataset``
    class: ``'inb'`` -> ``image_dataset``, ``'cbis'`` -> ``image_dataset_cbis``,
    ``'cmmd'`` -> ``image_dataset_cmmd``.  Batch size and worker count are read
    from ``cfg.train`` for all three splits, including validation and test.
    """

    def __init__(self, cfg):
        """Store the config and resolve the dataset class.

        Raises:
            ValueError: if ``cfg.dataset`` is not one of the supported names.
                Failing here avoids a later, less informative
                ``AttributeError`` on ``self.dataset`` when a dataloader is
                first requested.
        """
        super().__init__()
        self.cfg = cfg

        # Map to modules (not classes) so only the selected module's
        # ``INBImageDataset`` attribute is looked up.
        dataset_modules = {
            'inb': image_dataset,
            'cbis': image_dataset_cbis,
            'cmmd': image_dataset_cmmd,
        }
        name = cfg.dataset
        if name not in dataset_modules:
            supported = ", ".join(repr(key) for key in dataset_modules)
            raise ValueError(
                f"Unsupported dataset {name!r}; expected one of: {supported}"
            )
        self.dataset = dataset_modules[name].INBImageDataset

    def _build_loader(self, split, *, shuffle, drop_last):
        """Build the dataset for ``split`` and wrap it in a ``DataLoader``."""
        transform = builder.build_transformation(self.cfg, split)
        dataset = self.dataset(self.cfg, split=split, transform=transform)
        return DataLoader(
            dataset,
            pin_memory=True,
            drop_last=drop_last,
            shuffle=shuffle,
            batch_size=self.cfg.train.batch_size,
            num_workers=self.cfg.train.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader; incomplete last batch is dropped."""
        return self._build_loader("train", shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader, in order.

        NOTE(review): ``drop_last=True`` discards the final partial batch of
        the validation split -- kept as in the original; confirm it is intended.
        """
        return self._build_loader("valid", shuffle=False, drop_last=True)

    def test_dataloader(self) -> DataLoader:
        """Return the test loader, in order, keeping every sample."""
        return self._build_loader("test", shuffle=False, drop_last=False)