From 128854e1eee7582234dbb14d2fb382a8e58b327d Mon Sep 17 00:00:00 2001 From: simidjievskin Date: Mon, 3 Jul 2023 13:10:48 +0100 Subject: [PATCH] 'data_docs_update' --- aitlas/base/datasets.py | 17 +++++++++++++++-- aitlas/base/models.py | 7 ++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/aitlas/base/datasets.py b/aitlas/base/datasets.py index 241fc0f..5711b74 100644 --- a/aitlas/base/datasets.py +++ b/aitlas/base/datasets.py @@ -1,3 +1,8 @@ +"""Dataset base class. + +This is the base class for all datasets. All datasets should subclass it. +""" + import torch from torch.utils.data import Dataset @@ -7,6 +12,15 @@ class BaseDataset(Dataset, Configurable): + """This class represents a basic dataset for machine learning tasks. It is a + subclass of both :class:Dataset and :class:Configurable. + You can use it as a base class to define your own custom datasets. + + :param Dataset: _description_ + :type Dataset: _type_ + :param Configurable: _description_ + :type Configurable: _type_ + """ schema = BaseDatasetSchema labels = None # need to put the labels here @@ -38,7 +52,7 @@ def __init__(self, config): self.joint_transform = self.load_transforms(self.config.joint_transforms) def __getitem__(self, index): - """ Implement here what you want to return""" + """Implement here what you want to return""" raise NotImplementedError( "Please implement the `__getittem__` method for your dataset" ) @@ -108,4 +122,3 @@ def data_distribution_barchart(self): def load_transforms(self, class_names): """Loads transformation classes and make a composition of them""" return load_transforms(class_names, self.config) - diff --git a/aitlas/base/models.py b/aitlas/base/models.py index 4660817..1feefee 100644 --- a/aitlas/base/models.py +++ b/aitlas/base/models.py @@ -1,3 +1,8 @@ +"""Models base class. + +This is the base class for all models. All models should subclass it. + +""" import collections import copy import logging @@ -458,7 +463,7 @@ def predict_image( # Convert results to dataframe for plotting result = pd.DataFrame({"p": y_pred_probs[0]}, index=labels) # Show the image - plt.rcParams.update({'font.size': 16}) + plt.rcParams.update({"font.size": 16}) fig = plt.figure(figsize=(16, 7)) ax = plt.subplot(1, 2, 1) ax.axis("off")