Skip to content

Commit

Permalink
'data_docs_update'
Browse files Browse the repository at this point in the history
  • Loading branch information
simidjievskin committed Jul 3, 2023
1 parent 74aa348 commit 128854e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
17 changes: 15 additions & 2 deletions aitlas/base/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Dataset base class.
This is the base class for all datasets. All datasets should subclass it.
"""

import torch
from torch.utils.data import Dataset

Expand All @@ -7,6 +12,15 @@


class BaseDataset(Dataset, Configurable):
"""This class represents a basic dataset for machine learning tasks. It is a
subclass of both :class:Dataset and :class:Configurable.
You can use it as a base class to define your own custom datasets.
:param Dataset: _description_
:type Dataset: _type_
:param Configurable: _description_
:type Configurable: _type_
"""

schema = BaseDatasetSchema
labels = None # need to put the labels here
Expand Down Expand Up @@ -38,7 +52,7 @@ def __init__(self, config):
self.joint_transform = self.load_transforms(self.config.joint_transforms)

def __getitem__(self, index):
""" Implement here what you want to return"""
"""Implement here what you want to return"""
raise NotImplementedError(
"Please implement the `__getittem__` method for your dataset"
)
Expand Down Expand Up @@ -108,4 +122,3 @@ def data_distribution_barchart(self):
def load_transforms(self, class_names):
"""Loads transformation classes and make a composition of them"""
return load_transforms(class_names, self.config)

7 changes: 6 additions & 1 deletion aitlas/base/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Models base class.
This is the base class for all models. All models should subclass it.
"""
import collections
import copy
import logging
Expand Down Expand Up @@ -458,7 +463,7 @@ def predict_image(
# Convert results to dataframe for plotting
result = pd.DataFrame({"p": y_pred_probs[0]}, index=labels)
# Show the image
plt.rcParams.update({'font.size': 16})
plt.rcParams.update({"font.size": 16})
fig = plt.figure(figsize=(16, 7))
ax = plt.subplot(1, 2, 1)
ax.axis("off")
Expand Down

0 comments on commit 128854e

Please sign in to comment.