Commit bb65b05: Fix and speed up evaluation (theislab#10)

* Fix early stopping

* Adjust the training executable to the new early stopping

* Fix NaNs in the R2 evaluation and speed up the disentanglement evaluation

MxMstrmn authored Aug 9, 2021
1 parent 33ea48a commit bb65b05
Showing 3 changed files with 50 additions and 27 deletions.
6 changes: 6 additions & 0 deletions compert/model.py
@@ -463,6 +463,12 @@ def early_stopping(self, score):
         if self.num_drugs > 0:
             self.scheduler_dosers.step()
 
+        if score > self.best_score:
+            self.best_score = score
+            self.patience_trials = 0
+        else:
+            self.patience_trials += 1
+
         return self.patience_trials > self.patience
 
     def update(self, genes, drugs, covariates):
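This hunk is the whole early-stopping fix: the counter was previously never updated, so the patience check could not trigger. A minimal self-contained sketch of the same logic (the class name, the patience default, and the -inf initialisation are illustrative assumptions; in the repository these fields live on the ComPert model):

# Sketch of the patience-based early stopping added above.
class PatienceStopper:
    def __init__(self, patience=20):
        self.patience = patience
        self.best_score = float("-inf")  # assumed initial value
        self.patience_trials = 0

    def early_stopping(self, score):
        if score > self.best_score:
            # New best evaluation score: remember it, reset the counter.
            self.best_score = score
            self.patience_trials = 0
        else:
            # No improvement: spend one patience trial.
            self.patience_trials += 1
        return self.patience_trials > self.patience

stopper = PatienceStopper(patience=2)
assert not stopper.early_stopping(0.5)  # improvement, counter reset
assert not stopper.early_stopping(0.4)  # 1 trial spent
assert not stopper.early_stopping(0.4)  # 2 trials spent
assert stopper.early_stopping(0.3)      # budget exceeded, stop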
33 changes: 15 additions & 18 deletions compert/seml_sweep_icb.py
@@ -9,6 +9,7 @@
 import torch
 import seml
 import numpy as np
+import pandas as pd
 from compert.train import custom_collate, evaluate
 from compert.data import load_dataset_splits
 from compert.model import ComPert
@@ -178,16 +179,21 @@ def train(
             stop = ellapsed_minutes > max_minutes or (epoch == num_epochs - 1)
 
             if (epoch % checkpoint_freq) == 0 or stop:
+                evaluation_stats = {}
                 if not ignore_evaluation:
                     evaluation_stats = evaluate(self.autoencoder, self.datasets)
+                    score = np.mean(evaluation_stats["test"])
+                    stop = stop or self.autoencoder.early_stopping(score)
+
+                    if stop:
+                        evaluation_stats = evaluate(
+                            self.autoencoder, self.datasets, disentangle=True
+                        )
                     for key, val in evaluation_stats.items():
                         if not (key in self.autoencoder.history.keys()):
                             self.autoencoder.history[key] = []
                         self.autoencoder.history[key].append(val)
                     self.autoencoder.history["stats_epoch"].append(epoch)
-                else:
-                    evaluation_stats = {}
 
                 pjson(
                     {
                         "epoch": epoch,
@@ -196,41 +202,32 @@ def train(
"ellapsed_minutes": ellapsed_minutes,
}
)

if save_checkpoints:
if save_dir is None or not os.path.exists(save_dir):
print(os.path.exists(save_dir))
print(not os.path.exists(save_dir))
raise ValueError(
"Please provide a valid directory path in the 'save_dir' argument."
)
file_name = "model_seed={}_epoch={}.pt".format(self.seed, epoch)
torch.save(
(
self.autoencoder.state_dict(),
self.autoencoder.hparams,
self.autoencoder.history,
),
os.path.join(
save_dir,
"model_seed={}_epoch={}.pt".format(self.seed, epoch),
),
os.path.join(save_dir, file_name),
)

pjson(
{
"model_saved": "model_seed={}_epoch={}.pt\n".format(
self.seed, epoch
)
}
)
if not ignore_evaluation:
stop = stop or self.autoencoder.early_stopping(
np.mean(evaluation_stats["test"])
)
pjson({"model_saved": file_name})

if stop:
pjson({"early_stop": epoch})
break

results = self.autoencoder.history
# results = pd.DataFrame.from_dict(results) # not same length!
results["total_epochs"] = epoch
return results

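A side note on the commented-out pd.DataFrame.from_dict(results) line above: the history dict collects some entries every epoch and others only at checkpoint epochs, so its lists have unequal lengths and the direct conversion raises, which is why it stays commented out. A toy illustration with made-up keys:

import pandas as pd

history = {
    "loss": [0.9, 0.7, 0.5, 0.4],  # appended every epoch
    "test": [0.61, 0.64],          # appended only at checkpoints
}

# pd.DataFrame.from_dict(history) raises:
#   ValueError: All arrays must be of the same length
# Going row-wise and transposing pads the shorter lists with NaN instead:
frame = pd.DataFrame.from_dict(history, orient="index").T
print(frame)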
38 changes: 29 additions & 9 deletions compert/train.py
@@ -49,7 +49,14 @@ def evaluate_disentanglement(autoencoder, dataset, nonlinear=False):
     if nonlinear:
         clf = KNeighborsClassifier(n_neighbors=int(np.sqrt(len(latent_basal))))
     else:
-        clf = LogisticRegression(solver="liblinear", multi_class="auto", max_iter=10000)
+        clf = LogisticRegression(
+            solver="saga",
+            multi_class="multinomial",
+            max_iter=3000,
+            # n_jobs=-1,
+            # verbose=2,
+            tol=1e-2,
+        )
 
     pert_scores, cov_scores = 0, []
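The solver swap above is where the disentanglement speed-up in the commit message comes from: liblinear fits one one-vs-rest binary model per class, while saga fits a single multinomial model, and the loose tol=1e-2 lets the solver stop early, since the probe only needs an approximate accuracy. A rough sketch of how such a classifier probe scores disentanglement, with random stand-ins for the basal latents and drug labels (not the repository's actual evaluation code):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
latent_basal = rng.normal(size=(1000, 32))   # stand-in for the model's basal latents
drug_labels = rng.integers(0, 5, size=1000)  # stand-in perturbation labels

clf = LogisticRegression(solver="saga", multi_class="multinomial",
                         max_iter=3000, tol=1e-2)
# If the basal latent is well disentangled, the probe should do no better
# than chance at predicting which drug was applied.
accuracy = cross_val_score(clf, latent_basal, drug_labels, cv=3).mean()
print(f"probe accuracy: {accuracy:.3f} (chance is 0.2 for 5 classes)")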

@@ -120,11 +127,22 @@ def evaluate_r2(autoencoder, dataset, genes_control):
         yp_m = mean_predict.mean(0)
         yp_v = var_predict.mean(0)
 
-        mean_score.append(r2_score(yt_m, yp_m))
-        var_score.append(r2_score(yt_v, yp_v))
+        yp_m = torch.clamp(yp_m, -3e12, 3e12)
+        yp_v = torch.clamp(yp_v, -3e12, 3e12)
+
+        r2_m = -1 if torch.isnan(yp_m).any() else r2_score(yt_m, yp_m)
+        r2_v = -1 if torch.isnan(yp_v).any() else r2_score(yt_v, yp_v)
+        r2_m_de = (
+            -1 if torch.isnan(yp_m).any() else r2_score(yt_m[de_idx], yp_m[de_idx])
+        )
+        r2_v_de = (
+            -1 if torch.isnan(yp_v).any() else r2_score(yt_v[de_idx], yp_v[de_idx])
+        )
 
-        mean_score_de.append(r2_score(yt_m[de_idx], yp_m[de_idx]))
-        var_score_de.append(r2_score(yt_v[de_idx], yp_v[de_idx]))
+        mean_score.append(r2_m)
+        var_score.append(r2_v)
+        mean_score_de.append(r2_m_de)
+        var_score_de.append(r2_v_de)
     return [
         np.mean(s) if len(s) else -1
         for s in [mean_score, mean_score_de, var_score, var_score_de]
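Context for the guard above: sklearn's r2_score raises on NaN or infinite inputs, so a single diverged prediction used to abort the whole evaluation. The clamp maps infinities into a finite range (NaNs pass through clamp unchanged, hence the explicit isnan check), and -1 serves as a sentinel for unscorable outputs. A small isolated sketch of the pattern (toy tensors, not the real pipeline):

import torch
from sklearn.metrics import r2_score

yt_m = torch.tensor([1.0, 2.0, 3.0, 4.0])           # toy ground-truth means
yp_m = torch.tensor([1.1, float("inf"), 2.9, 4.2])  # one diverged prediction

# clamp maps +/-inf into a finite range but leaves NaNs untouched,
# hence the explicit isnan check before scoring.
yp_m = torch.clamp(yp_m, -3e12, 3e12)
r2_m = -1 if torch.isnan(yp_m).any() else r2_score(yt_m.numpy(), yp_m.numpy())
print(r2_m)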
@@ -173,7 +191,7 @@ def evaluate(autoencoder, datasets, disentangle=False):
     }
     autoencoder.train()
     ellapsed_minutes = (time.time() - start_time) / 60
-    print(f"Took {ellapsed_minutes:.1f} min for evaluation.")
+    print(f"\nTook {ellapsed_minutes:.1f} min for evaluation.\n")
     return evaluation_stats


@@ -263,7 +281,7 @@ def train_compert(args, return_model=False, ignore_evaluation=True):

print(f"\nCWD: {os.getcwd()}")
print(f"Save dir: {args['save_dir']}")
print(f"Got valid path for 'save_dir'?: {os.path.exists(args['save_dir'])}\n")
print(f"Got valid path for 'save_dir': {os.path.exists(args['save_dir'])}\n")
start_time = time.time()
for epoch in tqdm(range(args["max_epochs"])):
epoch_training_stats = defaultdict(float)
@@ -294,7 +312,9 @@ def train_compert(args, return_model=False, ignore_evaluation=True):

         if (epoch % args["checkpoint_freq"]) == 0 or stop:
             if not ignore_evaluation:
-                evaluation_stats = evaluate(autoencoder, datasets)
+                evaluation_stats = evaluate(
+                    autoencoder, datasets, disentangle=True if stop else False
+                )
                 for key, val in evaluation_stats.items():
                     if not (key in autoencoder.history.keys()):
                         autoencoder.history[key] = []
@@ -395,7 +415,7 @@ def parse_arguments():
"mol_featurizer": "canonical",
"checkpoint_freq": 15, # checkoint frequencty to save intermediate results
"hparams": "", # autoencoder architecture
"max_epochs": 20, # maximum epochs for training
"max_epochs": 5, # maximum epochs for training
"max_minutes": max_minutes, # maximum computation time
"patience": 20, # patience for early stopping
"loss_ae": "gauss", # loss (currently only gaussian loss is supported)
