sneccc committed May 24, 2024
1 parent 0280c58 commit f457c1d
Showing 2 changed files with 48 additions and 26 deletions.
2 changes: 1 addition & 1 deletion predict_score.py
@@ -42,7 +42,7 @@ def predict_score(root_folder, database_file, train_from, clip_models):
preprocessors.append(preprocess)

# Use the total dimension for the MLP model
mlp_model = MultiLayerPerceptron(total_dim)
mlp_model = MultiLayerPerceptron(input_size=total_dim)
model_name = f"{prefix}_linear_predictor_concatenated_{train_from}_mse.pth"
mlp_model.load_state_dict(torch.load(path / model_name))
mlp_model.to(device)
72 changes: 47 additions & 25 deletions train_new.py
@@ -17,18 +17,28 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR
from torchmetrics import F1Score, Precision, Recall
from sklearn.model_selection import train_test_split
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor

torch.manual_seed(42)
np.random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_weight_metrics = None


class MultiLayerPerceptron(pl.LightningModule):
def __init__(self, class_weight_metrics, input_size, hidden_units=(1024, 128, 64)):
def __init__(self, input_size, hidden_units=(1024, 256, 16)):
super().__init__()
# self.test_acc = Accuracy()

self.class_weight_metrics = class_weight_metrics.to(device)
global class_weight_metrics

if class_weight_metrics is not None:
    self.class_weight_metrics = class_weight_metrics.to(
        device)  # This will adjust the weight for each label based on the data so it's more balanced
else:
    self.class_weight_metrics = None  # no per-class re-weighting available (e.g. at inference time)

self.loss_Function = nn.CrossEntropyLoss(weight=self.class_weight_metrics)


# Train Metrics
self.train_acc = Accuracy(num_classes=3, average='macro', multiclass=True)
self.train_f1_score = F1Score(num_classes=3, average='macro', multiclass=True)
@@ -44,10 +54,9 @@ def __init__(self, class_weight_metrics, input_size, hidden_units=(1024, 128, 64)):
all_layers = [nn.Flatten()]
for index, hidden_unit in enumerate(hidden_units):
all_layers.append(nn.Linear(input_size, hidden_unit)) # Linear layer
all_layers.append(nn.ReLU())
if index < len(hidden_units) - 1:
all_layers.append(nn.ReLU())
all_layers.append(nn.Dropout(0.2))

input_size = hidden_unit

all_layers.append(nn.Linear(hidden_units[-1], 3))
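
For reference, a minimal sketch (not part of the commit) of the stack this updated loop builds for input_size=1152 (the concatenated CLIP dimension printed in start_training) and the new hidden_units=(1024, 256, 16); note that the last hidden layer gets no ReLU or Dropout because of the index < len(hidden_units) - 1 guard:

import torch.nn as nn

# Equivalent nn.Sequential for input_size=1152, hidden_units=(1024, 256, 16)
equivalent_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(1152, 1024), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(256, 16),   # last hidden layer: no activation or dropout
    nn.Linear(16, 3),     # 3-class output logits
)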
@@ -80,7 +89,7 @@ def validation_step(self, batch, batch_idx):
x = batch[0]
y = batch[1]
logits = self(x)
loss = nn.functional.cross_entropy(logits, y)
loss = self.loss_Function(logits, y)

preds = torch.argmax(logits, dim=1)

@@ -97,7 +106,7 @@ def training_step(self, batch, batch_idx):
x = batch[0]
y = batch[1]
logits = self(x)
loss = nn.functional.cross_entropy(logits, y, weight=self.class_weight_metrics)
loss = self.loss_Function(logits, y)
preds = torch.argmax(logits, dim=1)

self.log("train_loss", loss, prog_bar=False, on_step=False, on_epoch=True)
@@ -143,10 +152,10 @@ def on_validation_epoch_end(self):
# return loss

def configure_optimizers(self):
optimizer = "else"
optimizer = ("sgd")

if optimizer == "Adam":
optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, weight_decay=1e-3)
optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=1e-4)
return optimizer
elif optimizer == "warmup":
warmup_epochs = 1000
@@ -165,35 +174,45 @@ def lr_lambda(current_epoch):
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return [optimizer], [scheduler]
else:
# Initial learning rate
default_lr = 2e-3
weight_decay = 1e-6
epoch_lr_map = {1: 1e-4} # 1: 1e-4, 6750: 1e-5
current_lr_factor = 1.0

def lr_lambda(epoch):
nonlocal current_lr_factor
if epoch in epoch_lr_map:
current_lr_factor = epoch_lr_map[epoch] / default_lr
return current_lr_factor
# Learning rate adjustments for specific epochs
epoch_lr_map = {1: 1e-5, 100: 1e-5}  # Adjust learning rate at epochs 1 and 100

# Define a lambda function for learning rate scheduling
lr_lambda = lambda epoch: epoch_lr_map.get(epoch, default_lr) / default_lr

# Initialize the optimizer with weight decay
optimizer = torch.optim.SGD(self.parameters(), lr=default_lr, momentum=0.9, weight_decay=1e-8)

optimizer = torch.optim.SGD(self.parameters(), lr=default_lr, momentum=0.9, weight_decay=weight_decay)
# Set up the learning rate scheduler
scheduler = LambdaLR(optimizer, lr_lambda)

return [optimizer], [scheduler]
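
A small sketch (assuming the same default_lr and epoch_lr_map as above) of how LambdaLR composes the schedule: the effective rate is the optimizer's initial lr multiplied by lr_lambda(epoch), so the rate dips only at the mapped epochs and returns to default_lr afterwards, unlike the removed nonlocal version, which kept the reduced factor once triggered:

# Effective learning rate per epoch under the new schedule
default_lr = 2e-3
epoch_lr_map = {1: 1e-5, 100: 1e-5}
lr_lambda = lambda epoch: epoch_lr_map.get(epoch, default_lr) / default_lr

for epoch in (0, 1, 2, 100, 101):
    print(epoch, default_lr * lr_lambda(epoch))
# 0 0.002, 1 1e-05, 2 0.002, 100 1e-05, 101 0.002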


def start_training(root_folder, database_file, train_from, clip_models, val_percentage=0.25, epochs=5000,
batch_size=1000):
train_dataloader, val_dataloader, model_name, class_weight_metrics = setup_dataset(root_folder=root_folder,
database_file=database_file,
train_from=train_from)
train_dataloader, val_dataloader, model_name, class_weight = setup_dataset(root_folder=root_folder,
database_file=database_file,
train_from=train_from)
input_size = get_total_dim(clip_models)
print("input size", input_size) # 1152

net = MultiLayerPerceptron(input_size=input_size, class_weight_metrics=class_weight_metrics)
global class_weight_metrics
class_weight_metrics = class_weight
net = MultiLayerPerceptron(input_size=input_size)
callbacks = [
ModelCheckpoint(save_top_k=1, mode='max', monitor="val_acc"),
LearningRateMonitor(logging_interval='epoch')
LearningRateMonitor(logging_interval='epoch'),
EarlyStopping(
monitor='val_loss',
min_delta=0.00,
patience=20,
verbose=True,
mode='min'
)
] # save top 1 model
logger = TensorBoardLogger('tb_logs', name="my_logger", log_graph=True)
# lr_monitor = LearningRateMonitor(logging_interval='epoch')
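
The Trainer construction itself sits in the collapsed part of this hunk; as a rough sketch of how these callbacks and the logger are typically wired up in PyTorch Lightning (the exact arguments in this repository may differ):

import pytorch_lightning as pl

# Hypothetical wiring; the committed Trainer call is not shown in this diff
trainer = pl.Trainer(max_epochs=epochs, callbacks=callbacks, logger=logger)
trainer.fit(net, train_dataloader, val_dataloader)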
@@ -253,17 +272,20 @@ def setup_dataset(root_folder, database_file, train_from):

# Calculate class distribution for weights
class_counts = np.bincount(train_tensor_y.numpy())
print("🐍 Class Frequency: ", class_counts)
total_samples = len(train_tensor_y)
class_weights = total_samples / (len(class_counts) * class_counts)
class_weight_metrics = torch.tensor(class_weights, dtype=torch.float)
num_classes = len(class_counts)
class_weights = total_samples / (num_classes * class_counts)
class_weight_tensor = torch.tensor(class_weights, dtype=torch.float)
print("🐍 Class weights: ", class_weight_tensor)

train_dataset = TensorDataset(train_tensor_x, train_tensor_y)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

val_dataset = TensorDataset(val_tensor_x, val_tensor_y)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

return train_loader, val_loader, model_name, class_weight_metrics
return train_loader, val_loader, model_name, class_weight_tensor


def get_total_dim(clip_models):
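For clarity, a worked example (with made-up class counts, purely for illustration) of the inverse-frequency weighting computed in setup_dataset and how it feeds nn.CrossEntropyLoss:

import numpy as np
import torch
import torch.nn as nn

class_counts = np.array([600, 300, 100])            # hypothetical frequencies for the 3 classes
total_samples = class_counts.sum()                  # 1000
num_classes = len(class_counts)                     # 3
class_weights = total_samples / (num_classes * class_counts)
# -> [0.556, 1.111, 3.333]: the rarest class gets the largest weight

class_weight_tensor = torch.tensor(class_weights, dtype=torch.float)
loss_fn = nn.CrossEntropyLoss(weight=class_weight_tensor)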
