Skip to content

Commit

Permalink
docs for configs and schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
ivicadimitrovski committed Jul 3, 2023
1 parent 554190e commit bfc3bdb
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 8 deletions.
17 changes: 12 additions & 5 deletions aitlas/base/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ class BaseDataset(Dataset, Configurable):
name = None

def __init__(self, config):
"""BaseDataset constructor
:param config: Configuration object which specifies the details of the dataset.
:type config: Config, contains information for the batch size, number of workers, list of labels,
list of transformations
"""

Dataset.__init__(self)
Configurable.__init__(self, config)

Expand Down Expand Up @@ -53,6 +59,7 @@ def prepare(self):
return True

def dataloader(self):
"""Create and return a dataloader for the dataset"""
return torch.utils.data.DataLoader(
self,
batch_size=self.batch_size,
Expand All @@ -69,31 +76,31 @@ def get_labels(self):
)

def show_batch(self, size):
"""Implement this if you want to return the complete set of labels of the dataset"""
"""Implement this if you want to return a random batch of images from the dataset"""
raise NotImplementedError(
"Please implement the `show_batch` method for your dataset"
)

def show_samples(self):
"""Implement this if you want to return the complete set of labels of the dataset"""
"""Implement this if you want to return a random samples from the dataset"""
raise NotImplementedError(
"Please implement the `show_samples` method for your dataset"
)

def show_image(self, index):
"""Implement this if you want to return the complete set of labels of the dataset"""
"""Implement this if you want to return an image with a given index from the dataset"""
raise NotImplementedError(
"Please implement the `show_image` method for your dataset"
)

def data_distribution_table(self):
"""Implement this if you want to return the complete set of labels of the dataset"""
"""Implement this if you want to return the label distribution of the dataset"""
raise NotImplementedError(
"Please implement the `data_distribution_table` method for your dataset"
)

def data_distribution_barchart(self):
"""Implement this if you want to return the complete set of labels of the dataset"""
"""Implement this if you want to return the label distribution of the dataset as a barchart"""
raise NotImplementedError(
"Please implement the `data_distribution_barchart` method for your dataset"
)
Expand Down
100 changes: 97 additions & 3 deletions aitlas/base/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,35 @@


class BaseDatasetSchema(Schema):

"""
Schema for configuring a base dataset.
:param batch_size: Batch size for the dataset. Default is 64.
:type batch_size: int, optional
:param shuffle: Flag indicating whether to shuffle the dataset. Default is True.
:type shuffle: bool, optional
:param num_workers: Number of workers to use for data loading. Default is 4.
:type num_workers: int, optional
:param pin_memory: Flag indicating whether to use page-locked memory. Default is False.
:type pin_memory: bool, optional
:param transforms: Classes to run transformations over the input data.
:type transforms: List[str], optional
:param target_transforms: Classes to run transformations over the target data.
:type target_transforms: List[str], optional
:param joint_transforms: Classes to run transformations over the input and target data.
:type joint_transforms: List[str], optional
:param labels: Labels for the dataset.
:type labels: List[str], optional
"""

batch_size = fields.Int(missing=64, description="Batch size", example=64)
shuffle = fields.Bool(
missing=True, description="Should shuffle dataset", example=False
Expand All @@ -11,20 +40,41 @@ class BaseDatasetSchema(Schema):
missing=False, description="Whether to use page-locked memory"
)
transforms = fields.List(
fields.String, missing=None, description="Classes to run transformations.",
fields.String, missing=None, description="Classes to run transformations over the input data.",
)
target_transforms = fields.List(
fields.String, missing=None, description="Classes to run transformations.",
fields.String, missing=None, description="Classes to run transformations over the target data.",
)
joint_transforms = fields.List(
fields.String, missing=None, description="Classes to run transformations.",
fields.String, missing=None, description="Classes to run transformations over the input and target data.",
)
labels = fields.List(
fields.String, missing=None, description="Labels for the dataset",
)


class BaseModelSchema(Schema):
"""
Schema for configuring a base model.
:param num_classes: Number of classes for the model. Default is 2.
:type num_classes: int, optional
:param use_cuda: Flag indicating whether to use CUDA if available. Default is True.
:type use_cuda: bool, optional
:param metrics: Metrics to calculate during training and evaluation. Default is ['f1_score'].
:type metrics: List[str], optional
:param weights: Class weights to apply for the loss function. Default is None.
:type weights: List[float], optional
:param rank: Rank value for distributed data processing. Default is 0.
:type rank: int, optional
:param use_ddp: Flag indicating whether to turn on distributed data processing. Default is False.
:type use_ddp: bool, optional
"""
num_classes = fields.Int(missing=2, description="Number of classes", example=2)
use_cuda = fields.Bool(missing=True, description="Whether to use CUDA if possible")
metrics = fields.List(
Expand All @@ -49,6 +99,28 @@ class BaseModelSchema(Schema):


class BaseClassifierSchema(BaseModelSchema):
"""
Schema for configuring a base classifier.
:param learning_rate: Learning rate used in training. Default is 0.01.
:type learning_rate: float, optional
:param weight_decay: Weight decay used in training. Default is 0.0.
:type weight_decay: float, optional
:param pretrained: Flag indicating whether to use a pretrained model. Default is True.
:type pretrained: bool, optional
:param local_model_path: Local path of the pretrained model. Default is None.
:type local_model_path: str, optional
:param threshold: Prediction threshold if needed. Default is 0.5.
:type threshold: float, optional
:param freeze: Flag indicating whether to freeze all layers except for the classifier layer(s). Default is False.
:type freeze: bool, optional
"""

learning_rate = fields.Float(
missing=0.01, description="Learning rate used in training.", example=0.01
)
Expand All @@ -71,6 +143,14 @@ class BaseClassifierSchema(BaseModelSchema):


class BaseSegmentationClassifierSchema(BaseClassifierSchema):
"""
Schema for configuring a base segmentation classifier.
:param metrics: Classes of metrics you want to calculate during training and evaluation.
Default is ['iou', 'f1_score', 'accuracy'].
:type metrics: List[str], optional
"""

metrics = fields.List(
fields.String,
missing=["iou", "f1_score", "accuracy"],
Expand All @@ -80,6 +160,20 @@ class BaseSegmentationClassifierSchema(BaseClassifierSchema):


class BaseObjectDetectionSchema(BaseClassifierSchema):
"""
Schema for configuring a base object detection model.
:param metrics: Classes of metrics you want to calculate during training and evaluation.
Default is ['map'].
:type metrics: List[str], optional
:param step_size: Step size for the learning rate scheduler. Default is 15.
:type step_size: int, optional
:param gamma: Gamma (multiplier) for the learning rate scheduler. Default is 0.1.
:type gamma: float, optional
"""

metrics = fields.List(
fields.String,
missing=["map"],
Expand Down
33 changes: 33 additions & 0 deletions aitlas/datasets/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@


class MatDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring a classification dataset given as mat file.
"""
mat_file = fields.String(
missing=None, description="mat file on disk", example="./data/dataset.mat",
)
Expand All @@ -21,6 +24,9 @@ class MatDatasetSchema(BaseDatasetSchema):


class NPZDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring a classification dataset given as npz file.
"""
npz_file = fields.String(
missing=None, description="npz file on disk", example="./data/dataset.npz",
)
Expand All @@ -38,6 +44,9 @@ class NPZDatasetSchema(BaseDatasetSchema):


class ClassificationDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring a classification dataset.
"""
data_dir = fields.String(
missing="/", description="Dataset path on disk", example="./data/BigEarthNet/"
)
Expand All @@ -47,6 +56,9 @@ class ClassificationDatasetSchema(BaseDatasetSchema):


class SegmentationDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring a segmentation dataset.
"""
data_dir = fields.String(
missing="/", description="Dataset path on disk", example="./data/BigEarthNet/"
)
Expand All @@ -56,6 +68,9 @@ class SegmentationDatasetSchema(BaseDatasetSchema):


class ObjectDetectionPascalDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring an object detection dataset given in PASCAL VOC format.
"""
imageset_file = fields.String(
missing="/",
description="File with the image ids in the set",
Expand All @@ -72,6 +87,9 @@ class ObjectDetectionPascalDatasetSchema(BaseDatasetSchema):


class ObjectDetectionCocoDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring an object detection dataset given in COCO format.
"""
data_dir = fields.String(
missing="/", description="Dataset path on disk", example="./data/DIOR/"
)
Expand All @@ -86,6 +104,9 @@ class ObjectDetectionCocoDatasetSchema(BaseDatasetSchema):


class BigEarthNetSchema(BaseDatasetSchema):
"""
Schema for configuring the BigEarthNet dataset.
"""
csv_file = fields.String(
missing=None, description="CSV file on disk", example="./data/train.csv"
)
Expand Down Expand Up @@ -119,6 +140,9 @@ class BigEarthNetSchema(BaseDatasetSchema):


class SpaceNet6DatasetSchema(BaseDatasetSchema):
"""
Schema for configuring the SpaceNet6 dataset.
"""
orients = fields.String(
required=False,
example="path/to/data/train/AOI_11_Roterdam/SummaryData/SAR_orientations.csv",
Expand Down Expand Up @@ -211,6 +235,9 @@ class SpaceNet6DatasetSchema(BaseDatasetSchema):


class BreizhCropsSchema(BaseDatasetSchema):
"""
Schema for configuring the BreizhCrops dataset for crop type prediction.
"""
regions = fields.List(
fields.String,
required=True,
Expand Down Expand Up @@ -242,6 +269,9 @@ class BreizhCropsSchema(BaseDatasetSchema):


class CropsDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring dataset for crop type prediction.
"""
csv_file_path = fields.String(
missing=None, description="CSV file on disk", example="./data/train.csv"
)
Expand All @@ -264,6 +294,9 @@ class CropsDatasetSchema(BaseDatasetSchema):


class So2SatDatasetSchema(BaseDatasetSchema):
"""
Schema for configuring the So2Sat dataset.
"""
h5_file = fields.String(
required=True, description="H5 file on disk", example="./data/train.h5"
)
21 changes: 21 additions & 0 deletions aitlas/models/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@


class TransformerModelSchema(BaseClassifierSchema):
"""
Schema for configuring a transformer model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand Down Expand Up @@ -34,6 +37,9 @@ class TransformerModelSchema(BaseClassifierSchema):


class InceptionTimeSchema(BaseClassifierSchema):
"""
Schema for configuring a InceptionTime model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand All @@ -54,6 +60,9 @@ class InceptionTimeSchema(BaseClassifierSchema):


class LSTMSchema(BaseClassifierSchema):
"""
Schema for configuring a LSTM model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand All @@ -80,6 +89,9 @@ class LSTMSchema(BaseClassifierSchema):


class MSResNetSchema(BaseClassifierSchema):
"""
Schema for configuring a MSResNet model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand All @@ -100,6 +112,9 @@ class MSResNetSchema(BaseClassifierSchema):


class TempCNNSchema(BaseClassifierSchema):
"""
Schema for configuring a TempCNN model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand All @@ -124,6 +139,9 @@ class TempCNNSchema(BaseClassifierSchema):


class StarRNNSchema(BaseClassifierSchema):
"""
Schema for configuring a StarRNN model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand All @@ -150,6 +168,9 @@ class StarRNNSchema(BaseClassifierSchema):


class OmniScaleCNNSchema(BaseClassifierSchema):
"""
Schema for configuring a OmniScaleCNN model.
"""
input_dim = fields.Int(
required=True,
description="Number of bands (13 for L1C, 10 for L2A), 11 for eopatch slovenia",
Expand Down
Loading

0 comments on commit bfc3bdb

Please sign in to comment.