train.py
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
import shutil
import argparse
from utils.misc import load_config
# The helpers below are used later in this script but were never imported in
# the original; the module paths are assumptions based on the project layout.
from utils.misc import get_new_log_dir, get_logger, seed_all
from utils.train import get_optimizer, get_scheduler
from models.maskfill import MaskFillModelVN

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config/train.yml')
    parser.add_argument('--data', type=str, default='./dataset/crossdocked_vina10')
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--logdir', type=str, default='./logs')
    parser.add_argument('--outdir', type=str, default='./output')
    args = parser.parse_args()
    # Load config first: the log directory below is named after the config file.
    if not os.path.isfile(args.config):
        raise FileNotFoundError('Configuration YML file does not exist or was not specified.')
    config = load_config(args.config)
    config_name = os.path.splitext(os.path.basename(args.config))[0]
    seed_all(config.featuriser.seed)

    # Logging
    log_dir = get_new_log_dir(args.logdir, prefix=config_name)
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)
    logger = get_logger('train', log_dir)
    writer = SummaryWriter(log_dir)
    logger.info(args)
    logger.info(config)
    shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
    shutil.copytree('./models', os.path.join(log_dir, 'models'))

    # Dataset split
    split_dict = torch.load(config.dataset.split)
    logger.info(f"Found {len(split_dict['train'])} samples in the CrossDocked dataset for training.")

    if config.train.use_apex:
        from apex import amp

    # Output directory
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
        logger.info(f'Output directory {args.outdir} was created.')

    # Model
    logger.info('Building model...')
    if config.model.vn == 'singa':
        model = MaskFillModelVN(
            config.model,
            num_classes=contrastive_sampler.num_elements,
            num_bond_types=edge_sampler.num_bond_types,
            protein_atom_feature_dim=protein_featurizer.feature_dim,
            ligand_atom_feature_dim=ligand_featurizer.feature_dim,
        ).to(args.device)
    else:
        raise ValueError(f'Unknown model.vn option: {config.model.vn}')
    logger.info('Number of parameters: %d' % sum(p.numel() for p in model.parameters()))

    # Optimizer and scheduler
    optimizer = get_optimizer(config.train.optimizer, model)
    scheduler = get_scheduler(config.train.scheduler, optimizer)
    if config.train.use_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
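
    # --- Hypothetical sketch: train() / validate() --------------------------
    # The loop below calls train(it) and validate(it), which are never defined
    # in this file. A minimal single-batch version is sketched here; the
    # `train_iterator` / `val_loader` names come from the sketch above, and
    # `model.get_loss(batch)` is an assumed loss interface, not the project's
    # confirmed API.
    def train(it):
        model.train()
        optimizer.zero_grad()
        batch = next(train_iterator).to(args.device)
        loss = model.get_loss(batch)  # assumed interface
        if config.train.use_apex:
            # Apex mixed-precision backward pass.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        writer.add_scalar('train/loss', loss.item(), it)
        logger.info('[Train] Iter %d | Loss %.4f' % (it, loss.item()))

    def validate(it):
        model.eval()
        losses = []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(args.device)
                losses.append(model.get_loss(batch).item())  # assumed interface
        avg_loss = float(np.mean(losses))
        scheduler.step(avg_loss)  # assumes a plateau-style scheduler
        writer.add_scalar('val/loss', avg_loss, it)
        logger.info('[Validate] Iter %d | Loss %.4f' % (it, avg_loss))
        return avg_loss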

    try:
        model.train()
        for it in range(1, config.train.max_iters + 1):
            try:
                train(it)
            except RuntimeError as e:
                logger.error('Runtime Error: ' + str(e))
            if it % config.train.val_freq == 0 or it == config.train.max_iters:
                validate(it)
                ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
                torch.save({
                    'config': config,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'iteration': it,
                }, ckpt_path)
                model.train()
    except KeyboardInterrupt:
        logger.info('Terminating...')