-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.py
87 lines (71 loc) · 2.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Training script for generative models
Author(s): Tristan Stevens
"""
import argparse
from pathlib import Path
import matplotlib
import wandb
from keras.callbacks import ReduceLROnPlateau
from wandb.keras import WandbCallback
from datasets import get_dataset
from generators.models import get_model
from utils.callbacks import EvalDataset, Monitor
from utils.checkpoints import ModelCheckpoint
from utils.git_info import get_git_summary
from utils.gpu_config import set_gpu_usage
from utils.utils import check_model_library, random_augmentation, set_random_seed
# Use a non-interactive backend so figures are rendered to files (headless-safe).
matplotlib.use("Agg")


def _str2bool(value):
    """Parse a CLI string into a bool.

    argparse's ``type=bool`` is a known pitfall: ``bool(s)`` is True for ANY
    non-empty string, so ``-e False`` would silently enable eager mode.
    This converter accepts the usual spellings of true/false and rejects
    anything else with a clear error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "-c",
    "--config",
    default="configs/training/score_celeba.yaml",
    type=str,
    help="relative path to config file",
)
parser.add_argument(
    "-e",
    "--run_eagerly",
    default=False,
    type=_str2bool,
    help="run eagerly",
)
args = parser.parse_args()
# Resolve the config argument: if the given path does not exist, treat it as a
# bare config name and look it up in the default training-config directory.
cfg = Path(args.config)
if not cfg.exists():
    cfg = cfg.with_suffix(".yaml")
    args.config = f"./configs/training/{cfg}"
print(f"Using config file: {args.config}")
# Start a Weights & Biases run; `config` is populated from the YAML file path.
run = wandb.init(
    project="deep_generative",
    group="generative",
    config=args.config,
    job_type="train",
    allow_val_change=True,
)
print(f"wandb: {run.job_type} run {run.name}\n")
config = wandb.config
# Record the wandb run directory so downstream utilities log into it.
config.update({"log_dir": run.dir})
# Configure GPU usage before any tensors/models are created; `device` may be
# absent from the config, hence `.get` — presumably the helper then picks a
# default device (TODO confirm in utils.gpu_config).
set_gpu_usage(config.get("device"))
set_random_seed(config.seed)
# Store a git summary (commit/state) in the run config for reproducibility.
config.update({"git": get_git_summary()})
# Build the train/test datasets described by the config.
dataset, test_dataset = get_dataset(config)
if config.get("augmentation"):
    print("Training with augmented dataset")
    # Map random augmentations over the training stream only.
    dataset = dataset.map(random_augmentation(config))
# Instantiate the generative model; `model_library` tells us whether it is a
# tensorflow or pytorch model so the right callbacks can be attached below.
model = get_model(config, run_eagerly=args.run_eagerly, plot_summary=True)
model_library = check_model_library(model)
print(f"Monitoring loss: {model.monitor_loss}")
# Callbacks shared by every backend: periodic evaluation on the dataset,
# visual monitoring, and model checkpointing.
callbacks = [
    EvalDataset(model=model, dataset=dataset, config=config),
    Monitor(model=model, config=config),
    ModelCheckpoint(model=model, config=config),
]
# Backend-specific callbacks.
if model_library == "tensorflow":
    # Log metrics to wandb and reduce LR when the monitored loss plateaus.
    callbacks.append(WandbCallback())
    callbacks.append(
        ReduceLROnPlateau(monitor=model.monitor_loss, factor=0.3, verbose=1)
    )
elif model_library == "pytorch":
    # No extra callbacks for the torch backend (placeholder).
    pass
# Train the model; `steps_per_epoch` may be None when the config omits it —
# presumably that means one full pass over the dataset per epoch (TODO confirm
# against the model's fit implementation).
model.fit(
    dataset,
    epochs=config.epochs,
    callbacks=callbacks,
    steps_per_epoch=config.get("steps_per_epoch"),
)
# Mark the wandb run as finished and flush all pending logs.
run.finish()