
Commit

Merge pull request #9 from dtrizna/main-candidate
Updates on lightning wrappers, pretraining, multiclass logic, etc.
dtrizna authored Feb 28, 2024
2 parents 5476b4f + 765c044 commit b9a0d38
Showing 49 changed files with 1,537 additions and 188 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -9,7 +9,6 @@ build/*
.markdownlint.json

# research files
evaluation/*
z_out*

# Data
14 changes: 10 additions & 4 deletions nebula/__init__.py
@@ -114,11 +114,11 @@ def __init__(
"dHidden": 256, # dimension of the feedforward network model in nn.TransformerEncoder
"nLayers": 2, # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
"numClasses": 1, # binary classification
"hiddenNeurons": [64], # classifier head depth
"classifier_head": [64], # classifier head depth
"layerNorm": False,
"dropout": 0.3,
"norm_first": True,
"pooling": None
"pooling": "flatten"
}
self.model = TransformerEncoderChunks(**torch_model_config)

@@ -181,7 +181,13 @@ def __init__(self,
if n_output_classes is not None:
self.n_output_classes = n_output_classes
else:
self.n_output_classes = [x for x in self.model.children() if isinstance(x, Linear)][-1].out_features
layers = [x for x in self.model.children() if isinstance(x, nn.Linear) or isinstance(x, nn.Sequential)]
if isinstance(layers[-1], nn.Sequential):
self.n_output_classes = layers[-1][-1].out_features
elif isinstance(layers[-1], nn.Linear):
self.n_output_classes = layers[-1].out_features
else:
raise ValueError("An error occurred during identification of the number of classes.")

# lr scheduling setup, for visualizations see:
# https://towardsdatascience.com/a-visual-guide-to-learning-rate-schedulers-in-pytorch-24bbb262c863
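
For illustration only (not part of this diff), a minimal sketch of the output-class inference added in the hunk above; ToyClassifier is a hypothetical model used purely to exercise the Linear and Sequential branches:

import torch.nn as nn

def infer_n_output_classes(model: nn.Module) -> int:
    # inspect top-level children for a Linear head, or a Sequential head ending in Linear
    layers = [m for m in model.children() if isinstance(m, (nn.Linear, nn.Sequential))]
    if layers and isinstance(layers[-1], nn.Sequential):
        return layers[-1][-1].out_features
    if layers and isinstance(layers[-1], nn.Linear):
        return layers[-1].out_features
    raise ValueError("An error occurred during identification of the number of classes.")

class ToyClassifier(nn.Module):  # hypothetical stand-in, not from nebula
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 64)
        self.classifier_head = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 3))
    def forward(self, x):
        return self.classifier_head(self.encoder(x))

print(infer_n_output_classes(ToyClassifier()))  # 3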
@@ -609,7 +615,7 @@ def __init__(self,
)
self.dynamicModel = Cnn1DLinearLM(
vocabSize=len(self.tokenizer.vocab),
hiddenNeurons=[512, representationSize],
classifier_head=[512, representationSize],
dropout=dropout,
)

51 changes: 36 additions & 15 deletions nebula/lit_pretraining.py
@@ -4,12 +4,14 @@
import numpy as np
from tqdm import tqdm
from time import time
from typing import List, Optional
from typing import List, Optional, Union
from collections import OrderedDict
from copy import deepcopy
from torch import load
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .lit_utils import LitTrainerWrapper, PyTorchLightningModelLM
from .misc import clear_cuda_cache


class LanguageModelTrainer(LitTrainerWrapper):
@@ -74,7 +76,7 @@ def pretrain(self, x_unlabeled: np.ndarray, epochs: int = None):
# need to stop training in between to perform extra logic: saving or remasking
# TODO: this doesn't work with 'scheduler', since L.Trainer
# calls configure_optimizers() with every .fit()
# need to write custom configure_optimizers() for LM model ???
# need to rewrite on_epoch_end() in Trainer or something
# for now, avoid using remask_every_n_epochs and dump_model_every_epoch with a scheduler
loop_length = self.rebuild_dataloader_every_n_epochs
total_loops = self.pretrain_epochs // loop_length
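
A quick illustration of the loop arithmetic above, with hypothetical values:

# hypothetical numbers, only to show how pretraining is split into loops
pretrain_epochs = 20
rebuild_dataloader_every_n_epochs = 5

loop_length = rebuild_dataloader_every_n_epochs
total_loops = pretrain_epochs // loop_length  # 4 loops of 5 epochs; the dataloader is rebuilt (and the model optionally saved/remasked) between loops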
@@ -245,8 +247,7 @@ def create_lm_dataloader(self, x_unlabeled: np.ndarray, shuffle: bool = True) ->
class SelfSupervisedLearningEvalFramework:
def __init__(
self,
# pretrainer: Union[MaskedLanguageModelTrainer, AutoRegressiveModelTrainer],
pretrainer: MaskedLanguageModelTrainer,
pretrainer: Union[MaskedLanguageModelTrainer, AutoRegressiveModelTrainer],
downstream_trainer: LitTrainerWrapper,
training_types: List[str] = ['pretrained', 'non_pretrained', 'full_data'],
# eval details
@@ -301,7 +302,25 @@ def _dump_data_splits(self):
print(f"[!] Saved dataset splits to {split_data_file}")


def _train_downstream_model(self, training_type, pretrained_weights=None):
@staticmethod
def _transfer_pretrained_weights(
pretrained_state_dict: OrderedDict,
downstream_state_dict: OrderedDict
) -> OrderedDict:
"""
Transfer pretrained weights from a pretrained state dict to a downstream state dict.
"""

new_state_dict = deepcopy(downstream_state_dict)
for name in downstream_state_dict:
if name in pretrained_state_dict:
new_state_dict[name] = deepcopy(pretrained_state_dict[name])

return new_state_dict


def _train_downstream_model(self, training_type: str) -> None:

self.downstream_trainer.log_folder = self.init_downstream_log_folder + "_" + training_type + "_" + str(self.timestamp)
final_model_file = os.path.join(self.downstream_trainer.log_folder, f"{training_type}_final.torch")
if os.path.exists(final_model_file):
@@ -314,7 +333,7 @@ def _train_downstream_model(self, training_type, pretrained_weights=None):
self.downstream_trainer.name = training_type

if training_type == "pretrained":
self.downstream_trainer.pytorch_model.load_state_dict(pretrained_weights)
self.downstream_trainer.pytorch_model.load_state_dict(self.pretrained_weights)
else:
self.downstream_trainer.pytorch_model.load_state_dict(self.init_downstream_model_weights)

@@ -329,6 +348,7 @@ def _train_downstream_model(self, training_type, pretrained_weights=None):
self.downstream_trainer.setup_lit_model()
self.downstream_trainer.train_lit_model(self.train_loader, self.val_loader)
self.downstream_trainer.save_torch_model(final_model_file)
clear_cuda_cache()


def run_one_split(
@@ -370,18 +390,19 @@ def run_one_split(
if os.path.exists(self.pretrained_model_path):
print(f"[!] Loading pretrained model from: '{self.pretrained_model_path}'")
pretrained_model = load(self.pretrained_model_path)
pretrained_weights = pretrained_model.state_dict()
pretrained_model_state_dict = pretrained_model.state_dict()
else:
print("[!] Pre-training model...")
print(f"[!] Pre-training '{self.pretrainer.name}' model...")
# reset model weights -- needed for multiple splits
self.pretrainer.pytorch_model.load_state_dict(self.init_pretrain_model_weights)
self.pretrainer.pretrain(self.unlabeled_data)
pretrained_weights = deepcopy(self.pretrainer.pytorch_model.state_dict())
pretrained_model_state_dict = deepcopy(self.pretrainer.pytorch_model.state_dict())
clear_cuda_cache()

# remove pre-train head
to_remove = [k for k in pretrained_weights.keys() if k.startswith('pretrain_layers')]
for k in to_remove:
del pretrained_weights[k]
self.pretrained_weights = self._transfer_pretrained_weights(
pretrained_model_state_dict,
self.init_downstream_model_weights
)

self.train_loader = self.downstream_trainer.create_dataloader(self.labeled_x, self.labeled_y, shuffle=True)
if "full_data" in self.training_types:
@@ -394,7 +415,7 @@

for training_type in self.training_types:
print(f"[!] Fine-tuning of '{training_type}' model on downstream task...")
self._train_downstream_model(training_type, pretrained_weights)
self._train_downstream_model(training_type)


def run_splits(self, x_train, y_train, x_val, y_val, previous_run_idxs: Optional[List] = None):
@@ -407,5 +428,5 @@ def run_splits(self, x_train, y_train, x_val, y_val, previous_run_idxs: Optional
self.random_state += i # to get different splits
self.pretrainer.random_state = self.random_state
self.downstream_trainer.random_state = self.random_state
print(f'[!] Running pre-training split {i+1}/{self.n_splits}')
print(f"[!] Running '{self.pretrainer.name}' pre-training split {i+1}/{self.n_splits}")
self.run_one_split(x_train, y_train, x_val, y_val,)
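
The hand-off from pretraining to fine-tuning now copies every parameter whose name exists in both state dicts, instead of deleting the 'pretrain_layers' keys. Below is an illustrative sketch of that name-matched transfer, using two hypothetical toy modules (not from nebula):

from collections import OrderedDict
from copy import deepcopy
import torch.nn as nn

# hypothetical stand-ins: a pretrained LM and a downstream classifier sharing an encoder
pretrained = nn.Sequential(OrderedDict([
    ("encoder", nn.Linear(16, 64)),
    ("pretrain_layers", nn.Linear(64, 100)),  # LM head, absent downstream
]))
downstream = nn.Sequential(OrderedDict([
    ("encoder", nn.Linear(16, 64)),
    ("classifier_head", nn.Linear(64, 1)),    # task head, absent in the LM
]))

def transfer_pretrained_weights(pretrained_sd, downstream_sd):
    # same idea as SelfSupervisedLearningEvalFramework._transfer_pretrained_weights
    new_sd = deepcopy(downstream_sd)
    for name in downstream_sd:
        if name in pretrained_sd:
            new_sd[name] = deepcopy(pretrained_sd[name])
    return new_sd

downstream.load_state_dict(
    transfer_pretrained_weights(pretrained.state_dict(), downstream.state_dict())
)  # shared 'encoder.*' weights copied; both heads left untouched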
97 changes: 64 additions & 33 deletions nebula/lit_utils.py
@@ -5,6 +5,9 @@
from shutil import copyfile
from typing import Union, Any, Callable, Optional

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from sklearn.metrics import roc_curve

import lightning as L
@@ -13,7 +16,7 @@
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

import torch
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, Linear, Sequential
from torch.utils.data import DataLoader
import torchmetrics

@@ -37,6 +40,7 @@ def __init__(
scheduler_step_budget: Union[None, int] = None,
# NOTE: scheduler_step_budget = epochs * len(train_loader)
loss: Callable = BCEWithLogitsLoss(),
out_classes = 1
):
super().__init__()

@@ -52,28 +56,34 @@ def __init__(
print(f"[!] Scheduler: {scheduler} | Scheduler step budget: {scheduler_step_budget}")
self.scheduler = scheduler
self.scheduler_step_budget = scheduler_step_budget

self.train_acc = torchmetrics.Accuracy(task='binary')
self.train_f1 = torchmetrics.F1Score(task='binary', average='macro')
self.train_auc = torchmetrics.AUROC(task='binary')

task = 'multiclass' if isinstance(loss, CrossEntropyLoss) else 'binary'
if task == 'multiclass' and out_classes == 1:
layers = [x for x in self.model.children() if isinstance(x, Linear) or isinstance(x, Sequential)]
out_classes = layers[-1].out_features if isinstance(layers[-1], Linear) else layers[-1][-1].out_features
self.out_classes = out_classes

self.train_acc = torchmetrics.Accuracy(task=task, num_classes=out_classes)
self.train_f1 = torchmetrics.F1Score(task=task, num_classes=out_classes, average='macro')
self.train_auc = torchmetrics.AUROC(task=task, num_classes=out_classes)
self.train_tpr = self.get_tpr_at_fpr
self.train_recall = torchmetrics.Recall(task='binary')
self.train_precision = torchmetrics.Precision(task='binary')
self.train_recall = torchmetrics.Recall(task=task, num_classes=out_classes)
self.train_precision = torchmetrics.Precision(task=task, num_classes=out_classes)

self.val_acc = torchmetrics.Accuracy(task='binary')
self.val_f1 = torchmetrics.F1Score(task='binary', average='macro')
self.val_auc = torchmetrics.AUROC(task='binary')
self.val_acc = torchmetrics.Accuracy(task=task, num_classes=out_classes)
self.val_f1 = torchmetrics.F1Score(task=task, num_classes=out_classes, average='macro')
self.val_auc = torchmetrics.AUROC(task=task, num_classes=out_classes)
self.val_tpr = self.get_tpr_at_fpr
self.val_recall = torchmetrics.Recall(task='binary')
self.val_precision = torchmetrics.Precision(task='binary')
self.val_recall = torchmetrics.Recall(task=task, num_classes=out_classes)
self.val_precision = torchmetrics.Precision(task=task, num_classes=out_classes)

# NOTE: not using .test(); don't want these rudimentary metrics dragging over
# self.test_acc = torchmetrics.Accuracy(task='binary')
# self.test_f1 = torchmetrics.F1Score(task='binary', average='macro')
# self.test_auc = torchmetrics.AUROC(task='binary')
# self.test_acc = torchmetrics.Accuracy(task=task, num_classes=out_classes)
# self.test_f1 = torchmetrics.F1Score(task=task, num_classes=out_classes, average='macro')
# self.test_auc = torchmetrics.AUROC(task=task, num_classes=out_classes)
# self.test_tpr = self.get_tpr_at_fpr
# self.test_recall = torchmetrics.Recall(task='binary')
# self.test_precision = torchmetrics.Precision(task='binary')
# self.test_recall = torchmetrics.Recall(task=task, num_classes=out_classes)
# self.test_precision = torchmetrics.Precision(task=task, num_classes=out_classes)

# self.save_hyperparameters(ignore=["model"])
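
For reference, a small standalone sketch of the task / num_classes selection used above, with hypothetical values:

import torch
import torchmetrics
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

loss = CrossEntropyLoss()   # swap in BCEWithLogitsLoss() for the binary case
out_classes = 4             # hypothetical multiclass head size

task = 'multiclass' if isinstance(loss, CrossEntropyLoss) else 'binary'
acc = torchmetrics.Accuracy(task=task, num_classes=out_classes)
f1 = torchmetrics.F1Score(task=task, num_classes=out_classes, average='macro')

logits = torch.randn(8, out_classes)             # (batch, num_classes) model outputs
targets = torch.randint(0, out_classes, (8,))    # integer class labels
print(acc(logits, targets), f1(logits, targets))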

@@ -123,9 +133,13 @@ def get_tpr_at_fpr(
fpr, tpr, thresholds = roc_curve(true_labels, predicted_probs)
except ValueError:
# when multi-label 'ValueError: multilabel-indicator format is not supported'
return (torch.nan, torch.nan) if return_thresholds else torch.nan
# return (torch.nan, torch.nan) if return_thresholds else torch.nan
# avoid using nan since it throws 'WARNING NaN or Inf found in input tensor.'
return (0, 0) if return_thresholds else 0
if all(np.isnan(fpr)):
return (torch.nan, torch.nan) if return_thresholds else torch.nan
# return (torch.nan, torch.nan) if return_thresholds else torch.nan
# avoid using nan since it throws 'WARNING NaN or Inf found in input tensor.'
return (0, 0) if return_thresholds else 0
else:
tpr_at_fpr = tpr[fpr <= fprNeeded][-1]
threshold_at_fpr = thresholds[fpr <= fprNeeded][-1]
@@ -141,13 +155,15 @@ def _shared_step(self, batch: torch.Tensor):

if y.ndim == 2 and logits.ndim == 3: # e.g. autoregressive pre-training
logits, y = logits.view(-1, logits.size(-1)), y.view(-1)
elif y.ndim == 1: # binary classification: (batch_size,) => (batch_size, 1)
elif logits.ndim == 2 and self.out_classes != 1: # multiclass, logits.shape: (batch_size, num_classes)
y = y.squeeze().to(torch.int64)
elif y.ndim == 1: # binary classification: (batch_size, ) => (batch_size, 1)
y = y.unsqueeze(-1)

loss = self.loss(logits, y)
# NOTE: by returning only loss here we avoid memory leaks
self.logits = logits.detach().cpu()
self.y = y.detach().cpu()
self.logits = logits.detach()#.cpu()
self.y = y.detach()#.cpu()
return loss
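
The reshaping above covers three target/logit layouts; a small sketch with hypothetical tensors:

import torch

# autoregressive pretraining: (batch, seq, vocab) logits vs (batch, seq) targets
logits = torch.randn(2, 5, 50)
y = torch.randint(0, 50, (2, 5))
logits, y = logits.view(-1, logits.size(-1)), y.view(-1)   # -> (10, 50) and (10,)

# multiclass: (batch, num_classes) logits, targets squeezed to int64 class ids
y_mc = torch.tensor([[0.], [2.], [1.]]).squeeze().to(torch.int64)   # -> shape (3,)

# binary: (batch,) targets unsqueezed to (batch, 1) to match the logits
y_bin = torch.tensor([0., 1., 1.]).unsqueeze(-1)                    # -> shape (3, 1)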

def training_step(self, batch: torch.Tensor, batch_idx):
@@ -198,7 +214,7 @@ class PyTorchLightningModel(PyTorchLightningModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def validation_step(self, batch: (torch.Tensor, torch.Tensor), batch_idx):
def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx):
# NOTE: keep batch_idx -- lightning needs it
# loss, y, logits = self._shared_step(batch)
loss = self._shared_step(batch)
Expand Down Expand Up @@ -249,7 +265,7 @@ def __init__(
lit_sanity_steps: int = 1,
monitor_metric: str = "val_tpr",
monitor_mode: str = "max",
early_stop_patience: Union[None, int] = 5,
early_stop_patience: Union[None, int] = None,
early_stop_min_delta: float = 0.0001,
# efficient training strategies
scheduler: Union[None, str] = None,
@@ -261,9 +277,11 @@ def __init__(
# data config
batch_size: int = 1024,
dataloader_workers: int = 4,
loss: Callable = BCEWithLogitsLoss(),
out_classes: int = 1,
random_state: int = 42,
verbose: bool = False,
skip_trainer_init: Optional[bool] = False
skip_trainer_init: Optional[bool] = False,
):
self.pytorch_model = pytorch_model
self.lit_model = None
Expand Down Expand Up @@ -293,6 +311,8 @@ def __init__(

self.batch_size = batch_size
self.dataloader_workers = dataloader_workers
self.loss = loss
self.out_classes = out_classes

self.verbose = verbose
self.random_state = random_state
@@ -344,7 +364,7 @@ def setup_trainer(self):
callbacks = self.setup_callbacks()

if self.log_folder is None:
self.log_folder = f"./out_{self.name}_{int(time())}"
self.log_folder = f"./out_{self.name}"
try:
os.makedirs(self.log_folder, exist_ok=True)
except ValueError as ex:
@@ -388,7 +408,7 @@ def load_torch_model(self, model_file: str = None):
assert model_file is not None, "Please provide a model file"
self.pytorch_model = torch.load(model_file)
# NOTE: you have to reset self.lit_model after this
# if lit_model is already initialized, then load state dict directly:
# if lit_model is already initialized, then load state dict directly:
# self.lit_model.model.load_state_dict(state_dict)


@@ -400,6 +420,8 @@ def setup_lit_model(self):
learning_rate=self.learning_rate,
scheduler=self.scheduler,
scheduler_step_budget=self.scheduler_budget,
loss=self.loss,
out_classes=self.out_classes
)


@@ -427,19 +449,27 @@ def train_lit_model(
def predict_lit_model(
self,
loader: DataLoader,
decision_threshold: int = 0.5,
decision_threshold: int = 0.5,
return_logits: bool = False,
dump_logits: Union[bool, str] = False
) -> np.ndarray:
assert self.lit_model is not None,\
"[-] lightning_model isn't instantiated: either .train_lit_model() or .load_lit_model()"
"""Get scores out of a loader."""
y_pred_logits = self.trainer.predict(model=self.lit_model, dataloaders=loader)
y_pred = torch.sigmoid(torch.cat(y_pred_logits, dim=0)).numpy()
y_pred = np.array([1 if x > decision_threshold else 0 for x in y_pred])
y_pred_logits = torch.cat(y_pred_logits, dim=0)
if dump_logits:
assert isinstance(dump_logits, str), "Please provide a path to dump logits: dump_logits='path/to/logits.pkl'"
pickle.dump(y_pred_logits, open(dump_logits, "wb"))
return y_pred
if return_logits:
return y_pred_logits

y_pred = torch.sigmoid(y_pred_logits).numpy()
try:
y_pred = np.array([1 if x > decision_threshold else 0 for x in y_pred])
except ValueError: # multiclass
y_pred = np.argmax(y_pred, axis=1)
return y_pred
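
A short illustration of the logits-to-label conversion above, using hypothetical logits:

import numpy as np
import torch

decision_threshold = 0.5

# binary head: (N, 1) logits -> sigmoid -> threshold
binary_probs = torch.sigmoid(torch.tensor([[-1.2], [0.3], [2.0]])).numpy().ravel()
binary_pred = np.array([1 if x > decision_threshold else 0 for x in binary_probs])  # [0, 1, 1]

# multiclass head: (N, C) logits -> argmax over the class dimension
multiclass_scores = torch.sigmoid(torch.tensor([[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]])).numpy()
multiclass_pred = np.argmax(multiclass_scores, axis=1)  # [1, 0]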


def save_lit_model(self, model_file: str = None, how="best"):
@@ -514,7 +544,8 @@ def calculate_scheduler_step_budget(
accumulate_grad_batches = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches
total_batches = 0
if max_epochs is not None:
total_batches = int(np.ceil(max_epochs * len(self.train_loader) / accumulate_grad_batches))
steps_per_epoch = np.ceil(len(self.train_loader) / accumulate_grad_batches)
total_batches = int(max_epochs * steps_per_epoch)
if max_time is not None:
# TODO: Implement logic for max_time if needed
raise NotImplementedError("calculate_scheduler_step_budget for max_time is not implemented yet")
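
A worked example of the corrected step-budget arithmetic above, with hypothetical numbers:

import numpy as np

max_epochs = 10
batches_per_epoch = 157          # len(self.train_loader), hypothetical
accumulate_grad_batches = 4

steps_per_epoch = np.ceil(batches_per_epoch / accumulate_grad_batches)  # 40 optimizer steps per epoch
total_batches = int(max_epochs * steps_per_epoch)                       # 400
# the previous formula, int(np.ceil(max_epochs * batches_per_epoch / accumulate_grad_batches)),
# yields 393 here and under-counts, because the last (partial) accumulation window
# of each epoch still triggers an optimizer step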
