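"""Main.py: train an image classifier with PyTorch Lightning.

Loads train/val images from an S3 bucket via CustomDataset, tracks the run
as a ClearML task, and checkpoints the best model by validation loss.
"""
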
from datetime import datetime

import albumentations as A
import pytorch_lightning as pl
import torch
from albumentations.pytorch import ToTensorV2
from clearml import Task
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from simple_parsing import ArgumentParser
from torch.utils.data import DataLoader, Subset

from config import config
from config.args import Args
from config.config import BASE_DIR, CONFIG_DIR, DATA_DIR, LOGS_DIR, PROJECT_DIR, logger
from model import Classifier
from new_dataloader import CustomDataset  # S3-backed dataset


def printing_paths():
    """Log the resolved project paths for quick debugging."""
    logger.info(f"BASE_DIR: {BASE_DIR}")
    logger.info(f"PROJECT_DIR: {PROJECT_DIR}")
    logger.info(f"CONFIG_DIR: {CONFIG_DIR}")
    logger.info(f"DATA_DIR: {DATA_DIR}")
    logger.info(f"LOGS_DIR: {LOGS_DIR}")


# Preprocessing function
def get_transform(dataset):
    resize = A.Resize(224, 224)
    normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    to_tensor = ToTensorV2()
    # Train and val currently share the same pipeline; train-time
    # augmentations would go in the "train" branch.
    if dataset == "train":
        return A.Compose([resize, normalize, to_tensor])
    elif dataset == "val":
        return A.Compose([resize, normalize, to_tensor])
    raise ValueError(f"Unknown dataset split: {dataset!r}")


# Dataset function
def prepare_dataset(bucket_name, train_prefix, val_prefix):
    try:
        trainset = CustomDataset(bucket_name=bucket_name, prefix=train_prefix, transforms=get_transform("train"))
        valset = CustomDataset(bucket_name=bucket_name, prefix=val_prefix, transforms=get_transform("val"))
        return trainset, valset
    except Exception as e:
        logger.error(f"Got an exception while preparing datasets: {e}")
        raise  # Re-raise instead of silently returning None, which would break unpacking in main()


# Sanity-check subsets (not wired into main(); useful for quick debug runs)
def create_subset(trainset, valset):
    train_subset = Subset(trainset, range(100))
    val_subset = Subset(valset, range(5))
    return train_subset, val_subset
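
# To smoke-test the pipeline, swap the full datasets for the subsets in main():
#   train_subset, val_subset = create_subset(trainset, valset)
#   train_dataloader, val_dataloader = create_dataloaders(args, train_subset, val_subset)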


# Dataloaders function
def create_dataloaders(args, train_dataset, val_dataset):
    # Batch size and worker count are hardcoded for now; `args` is kept in
    # the signature in case they are moved into the CLI options later.
    train_dataloader = DataLoader(
        train_dataset, batch_size=16, num_workers=4, shuffle=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=16, num_workers=4, shuffle=False
    )
    return train_dataloader, val_dataloader


def set_device():
    # Lightning accelerator string: "gpu" if CUDA is available, else "cpu"
    return "gpu" if torch.cuda.is_available() else "cpu"


def main():
    printing_paths()
    device = set_device()
    logger.info(f"Currently using {device} device...")

    # Read args
    logger.info("Reading arguments...")
    parser = ArgumentParser()
    parser.add_arguments(Args, dest="options")
    args_namespace = parser.parse_args()
    args = args_namespace.options

    # Prepare dataset using S3 bucket and prefixes
    logger.info("Preparing datasets...")
    bucket_name = config.S3_DATA_BUCKET
    train_prefix = config.S3_DATA_PREFIX + "train/"
    val_prefix = config.S3_DATA_PREFIX + "val/"
    trainset, valset = prepare_dataset(bucket_name=bucket_name, train_prefix=train_prefix, val_prefix=val_prefix)
    logger.info(
        f"Total training images: {len(trainset)}\n"
        f"Total validation images: {len(valset)}"
    )

    # Create dataloaders over the full datasets
    train_dataloader, val_dataloader = create_dataloaders(args, trainset, valset)

    # Initialize clearml task (timestamp formatted to avoid spaces/colons in the task name)
    logger.info("Initializing clearml task...")
    task = Task.init(
        project_name="streamlit/image-classification",
        task_name=f"streamlit-image-classification-{datetime.now():%Y-%m-%d_%H-%M-%S}",
    )
    task.connect(args)

    # Save the top-K checkpoints based on the monitored metric. The filename
    # placeholder must name a logged metric, so it reuses the monitored
    # "loss/val_loss" key (the "/" nests the checkpoint in a subdirectory).
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="loss/val_loss",
        mode="min",
        filename="streamlit-image-classification-{epoch:02d}-{loss/val_loss:.2f}",
    )

    # Progress bar
    progress_bar = RichProgressBar()

    # Configure trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[checkpoint_callback, progress_bar],
        default_root_dir=config.LOGS_DIR,
        accelerator=device,
        devices=1,
        log_every_n_steps=1,
    )

    # Define classifier
    classifier = Classifier()

    # Fit the model
    logger.info("Starting training...")
    trainer.fit(classifier, train_dataloader, val_dataloader)


if __name__ == "__main__":
    main()
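

# Example invocation (a sketch: simple_parsing exposes the fields of the Args
# dataclass as CLI flags, so the exact flag names depend on config/args.py):
#   python Main.py --max_epochs 10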