Data refactoring
Flegyas committed Mar 2, 2021
1 parent 870cf21 commit c2a546e
Showing 7 changed files with 127 additions and 121 deletions.
4 changes: 3 additions & 1 deletion .env.template
@@ -1 +1,3 @@
-export MY_DATASET_PATH="/home/jonny/datasets/blues"
+export YOUR_TRAIN_DATASET_PATH="/your/project/root/data/blues/train"
+export YOUR_VAL_DATASET_PATH="/your/project/root/data/blues/val"
+export YOUR_TEST_DATASET_PATH="/your/project/root/data/blues/test"
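These variables feed the `${env:...}` interpolations used by `conf/data/default_data.yaml` further down, so they must be exported (for example by sourcing `.env`) before `run.py` is launched. A minimal sketch of the resolution, assuming OmegaConf's built-in `env` resolver (newer OmegaConf versions spell it `oc.env`); the value below is illustrative:

```python
# Sketch (not part of the commit): how OmegaConf resolves the ${env:...}
# interpolations used in conf/data/default_data.yaml.
import os

from omegaconf import OmegaConf

# Assumes the variables from .env have been exported, e.g. via `source .env`.
os.environ.setdefault("YOUR_TRAIN_DATASET_PATH", "/your/project/root/data/blues/train")

cfg = OmegaConf.create({"path": "${env:YOUR_TRAIN_DATASET_PATH}"})
print(cfg.path)  # -> /your/project/root/data/blues/train
```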
6 changes: 3 additions & 3 deletions README.md
@@ -24,9 +24,9 @@ Generic template to bootstrap your [PyTorch](https://pytorch.org/get-started/loc
├── README.md
├── requirements.txt     # basic requirements
└── src
-   ├── common           # common python modules
-   ├── pl_datamodules   # pytorch lightning datamodules
-   ├── pl_modules       # pytorch lightning modules
+   ├── common           # common Python modules
+   ├── pl_data          # PyTorch Lightning datamodules and datasets
+   ├── pl_modules       # PyTorch Lightning modules
    └── run.py           # entry point to run current conf
```

23 changes: 12 additions & 11 deletions conf/data/default_data.yaml
@@ -1,22 +1,23 @@
# @package _group_

datamodule:
-  _target_: src.pl_datamodules.datamodule.MyDataModule
-  val_percentage: 0.15
+  _target_: src.pl_data.datamodule.MyDataModule

  datasets:
    train:
-      _target_: src.pl_datamodules.datamodule.MyDataset
-      name: MNSIT
-      train: True
-      path: ${env:MNIST}
+      _target_: src.pl_data.dataset.MyDataset
+      name: YourTrainDatasetName
+      path: ${env:YOUR_TRAIN_DATASET_PATH}

+    val:
+      - _target_: src.pl_data.dataset.MyDataset
+        name: YourValDatasetName
+        path: ${env:YOUR_VAL_DATASET_PATH}

    test:
-      - _target_: src.pl_datamodules.datamodule.MyDataset
-        name: MNIST
-        train: False
-        path: ${env:MNIST}
+      - _target_: src.pl_data.dataset.MyDataset
+        name: YourTestDatasetName
+        path: ${env:YOUR_TEST_DATASET_PATH}

  num_workers:
    train: 8
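In the refactored conf, `datasets.train` is a single node while `datasets.val` and `datasets.test` are lists, which is why the datamodule below handles them differently. A minimal sketch of what `hydra.utils.instantiate` does with such nodes, assuming the project root is on `PYTHONPATH` so the `_target_` paths resolve; the inline values mirror the conf above:

```python
# Sketch (assumption, not code from this commit): instantiating dataset nodes
# shaped like conf/data/default_data.yaml via Hydra.
import hydra
from omegaconf import OmegaConf

conf = OmegaConf.create(
    {
        "train": {  # a single dataset node
            "_target_": "src.pl_data.dataset.MyDataset",
            "name": "YourTrainDatasetName",
            "path": "/your/project/root/data/blues/train",
        },
        "val": [  # a list of dataset nodes
            {
                "_target_": "src.pl_data.dataset.MyDataset",
                "name": "YourValDatasetName",
                "path": "/your/project/root/data/blues/val",
            }
        ],
    }
)

train_dataset = hydra.utils.instantiate(conf.train, cfg=None)            # one Dataset
val_datasets = [hydra.utils.instantiate(c, cfg=None) for c in conf.val]  # list of Datasets
print(train_dataset, val_datasets)
```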
File renamed without changes.
82 changes: 82 additions & 0 deletions src/pl_data/datamodule.py
@@ -0,0 +1,82 @@
from abc import abstractmethod
from typing import Optional, Sequence

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset


class MyDataModule(pl.LightningDataModule):
    def __init__(
        self,
        datasets: DictConfig,
        num_workers: DictConfig,
        batch_size: DictConfig,
        cfg: DictConfig,
    ):
        super().__init__()
        self.cfg = cfg

        self.datasets = datasets
        self.num_workers = num_workers
        self.batch_size = batch_size

        self.train_dataset: Optional[Dataset] = None
        self.val_datasets: Optional[Sequence[Dataset]] = None
        self.test_datasets: Optional[Sequence[Dataset]] = None

    def prepare_data(self) -> None:
        # download only
        pass

    def setup(self, stage: Optional[str] = None):
        # Instantiate the datasets here; you may also split the train set into train and validation if needed.
        if stage is None or stage == "fit":
            self.train_dataset = hydra.utils.instantiate(self.datasets.train, cfg=self.cfg)
            self.val_datasets = [
                hydra.utils.instantiate(dataset_cfg, cfg=self.cfg) for dataset_cfg in self.datasets.val
            ]

        if stage is None or stage == "test":
            self.test_datasets = [
                hydra.utils.instantiate(dataset_cfg, cfg=self.cfg) for dataset_cfg in self.datasets.test
            ]

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.batch_size.train,
            num_workers=self.num_workers.train,
        )

    def val_dataloader(self) -> Sequence[DataLoader]:
        return [
            DataLoader(
                dataset,
                shuffle=False,
                batch_size=self.batch_size.val,
                num_workers=self.num_workers.val,
            )
            for dataset in self.val_datasets
        ]

    def test_dataloader(self) -> Sequence[DataLoader]:
        return [
            DataLoader(
                dataset,
                shuffle=False,
                batch_size=self.batch_size.test,
                num_workers=self.num_workers.test,
            )
            for dataset in self.test_datasets
        ]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"{self.datasets=}, "
            f"{self.num_workers=}, "
            f"{self.batch_size=})"
        )
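A usage sketch for the datamodule above, assuming the root Hydra config exposes this file under the `data` group (as the `# @package _group_` directive suggests); the config name and paths are guesses, and newer Hydra versions may additionally need `_recursive_=False` so the nested dataset nodes are not instantiated eagerly:

```python
# Usage sketch (an assumption, not code from this commit): composing the config
# with Hydra and exercising MyDataModule.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="../conf", config_name="default")  # config name is an assumption
def run(cfg: DictConfig) -> None:
    # cfg.data.datamodule mirrors conf/data/default_data.yaml.
    datamodule = hydra.utils.instantiate(cfg.data.datamodule, cfg=cfg)
    datamodule.setup("fit")
    train_loader = datamodule.train_dataloader()  # DataLoader over the train dataset
    print(datamodule, train_loader.batch_size)


if __name__ == "__main__":
    run()
```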
27 changes: 27 additions & 0 deletions src/pl_data/dataset.py
@@ -0,0 +1,27 @@
from typing import Union, Dict, Tuple

import torch
from omegaconf import ValueNode, DictConfig
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(
        self, name: ValueNode, path: ValueNode, cfg: DictConfig, train: bool = True, **kwargs
    ):
        # `train` is optional: the default data conf no longer passes it.
        super().__init__()
        self.cfg = cfg
        self.path = path
        self.name = name
        self.train = train

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"MyDataset({self.name=}, {self.path=})"
106 changes: 0 additions & 106 deletions src/pl_datamodules/datamodule.py

This file was deleted.
