forked from deepchem/deepchem
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchembl_tf_models.py
68 lines (56 loc) · 1.74 KB
/
chembl_tf_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Script that trains Tensorflow Multitask models on ChEMBL dataset.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals
import os
import tempfile
import shutil
import numpy as np
import deepchem as dc
from deepchem.molnet import load_chembl
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(123)

### Load data ###
shard_size = 2000
print("About to load ChEMBL data.")
# load_chembl hands back the task list, a (train, valid, test) triple of
# datasets, and the transformers that go with them.
chembl_tasks, datasets, transformers = load_chembl(
    shard_size=shard_size,
    featurizer="ECFP",
    set="5thresh",
    split="random")
(train_dataset, valid_dataset, test_dataset) = datasets

# Report the number of tasks and the size of each split, label first and
# value on the following line.
print("ChEMBL_tasks", len(chembl_tasks), sep="\n")
print("Number of compounds in train set", len(train_dataset), sep="\n")
print("Number of compounds in validation set", len(valid_dataset), sep="\n")
print("Number of compounds in test set", len(test_dataset), sep="\n")
### Create model ###
# n_layers sets the length of every per-layer hyper-parameter list below;
# nb_epoch is the number of epochs later passed to model.fit.
n_layers = 3
nb_epoch = 10

# Positional arguments: the number of tasks, then the first entry of the
# training set's data shape (presumably the feature-vector length -- confirm
# against the MultitaskRegressor signature).
model = dc.models.MultitaskRegressor(
    len(chembl_tasks),
    train_dataset.get_data_shape()[0],
    layer_sizes=n_layers * [1000],
    dropouts=n_layers * [0.25],
    weight_init_stddevs=n_layers * [0.02],
    bias_init_consts=n_layers * [1.0],
    learning_rate=3e-4,
    weight_decay_penalty=1e-4,
    batch_size=100,
    seed=123,
    verbosity="high")
# Score with the Pearson R^2 metric (this is a regression model, not a
# classifier); per-task scores are averaged with np.mean.
metric = dc.metrics.Metric(dc.metrics.pearson_r2_score, task_averager=np.mean)

print("Training model")
model.fit(train_dataset, nb_epoch=nb_epoch)

# Evaluate every split before printing anything. The transformers returned
# by the loader are passed along with the metric list.
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)
test_scores = model.evaluate(test_dataset, [metric], transformers)

# Label on one line, scores on the next.
print("Train scores", train_scores, sep="\n")
print("Validation scores", valid_scores, sep="\n")
print("Test scores", test_scores, sep="\n")