forked from molecularsets/moses
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
72 lines (57 loc) · 1.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import os
import sys
import torch
import rdkit
from moses.script_utils import add_train_args, read_smiles_csv, set_seed
from moses.models_storage import ModelsStorage
from moses.dataset import get_dataset
# Raise the RDKit logger threshold to CRITICAL so lower-severity messages are
# not emitted while the script runs.
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
# Shared lookup object: this script asks it for the model names, each model's
# train-argument parser hook, its trainer and its model class.
MODELS = ModelsStorage()
def get_parser():
    """Build the command-line parser for the training script.

    One sub-command is registered for every name returned by
    ``MODELS.get_model_names()``. Each new sub-parser is passed through the
    model's train-parser hook (``MODELS.get_model_train_parser``) and the
    hook's result is then handed to ``add_train_args``.

    Returns:
        The top-level ``argparse.ArgumentParser``.
    """
    root_parser = argparse.ArgumentParser()
    model_commands = root_parser.add_subparsers(
        title='Models trainer script', description='available models'
    )

    for name in MODELS.get_model_names():
        # Look up the hook before creating the sub-parser so the calls happen
        # in the same order as a single nested expression would evaluate them.
        configure = MODELS.get_model_train_parser(name)
        model_parser = model_commands.add_parser(name)
        add_train_args(configure(model_parser))

    return root_parser
def main(model, config):
    """Train the model registered under ``model`` and save its weights.

    Args:
        model: Name used to look up the trainer and the model class in
            ``MODELS``.
        config: Parsed arguments. The attributes read here are ``seed``,
            ``device``, ``config_save``, ``train_load``, ``val_load``,
            ``vocab_load``, ``vocab_save`` and ``model_save``; the object is
            also passed to the trainer and model constructors.

    Raises:
        AssertionError: If ``config.vocab_load`` is set but the path does not
            exist (the check is an ``assert`` statement).
    """
    set_seed(config.seed)
    target_device = torch.device(config.device)

    # Persist the run configuration when a destination was given.
    if config.config_save is not None:
        torch.save(config, config.config_save)

    # For CUDNN to work properly
    if target_device.type.startswith('cuda'):
        torch.cuda.set_device(target_device.index or 0)

    # Fall back to the packaged 'train' / 'test' splits when no CSV is given.
    train_set = (
        get_dataset('train') if config.train_load is None
        else read_smiles_csv(config.train_load)
    )
    val_set = (
        get_dataset('test') if config.val_load is None
        else read_smiles_csv(config.val_load)
    )

    model_trainer = MODELS.get_model_trainer(model)(config)

    if config.vocab_load is None:
        # No stored vocabulary: let the trainer derive one from the data.
        vocabulary = model_trainer.get_vocabulary(train_set)
    else:
        assert os.path.exists(config.vocab_load), \
            'vocab_load path does not exist!'
        vocabulary = torch.load(config.vocab_load)

    # NOTE(review): indentation was lost in the source this was restored from;
    # this check is kept at function level, so a loaded vocabulary is written
    # out as well as a freshly built one -- confirm that is intended.
    if config.vocab_save is not None:
        torch.save(vocabulary, config.vocab_save)

    model_cls = MODELS.get_model_class(model)
    network = model_cls(vocabulary, config).to(target_device)
    model_trainer.fit(network, train_set, val_set)

    # Move to CPU before saving the state dict.
    network = network.to('cpu')
    torch.save(network.state_dict(), config.model_save)
if __name__ == '__main__':
    parser = get_parser()
    config = parser.parse_args()
    # argparse treats sub-commands as optional unless told otherwise, so an
    # empty command line parses successfully. Without this guard the
    # sys.argv[1] lookup below fails with a bare IndexError; report a usage
    # error (exit status 2) instead.
    if len(sys.argv) < 2:
        parser.error('a model name is required as the first argument')
    # The model name is taken straight from argv: the sub-parsers are created
    # without a ``dest``, so the chosen sub-command is not stored in ``config``.
    model = sys.argv[1]
    main(model, config)