Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jkwang93 committed May 20, 2021
1 parent b8f878b commit 4b59b09
Show file tree
Hide file tree
Showing 31 changed files with 104,828 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions .idea/MCMG.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions .idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

92 changes: 92 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

94 changes: 94 additions & 0 deletions 2_generator_Transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python
import argparse

import torch
import time

from models.model_MCMG import transformer_RL
from MCMG_utils.data_structs import Vocabulary
from MCMG_utils.utils import seq_to_smiles
import pandas as pd




def Transformer_generator(restore_prior_from='output/Prior.ckpt',
                          save_file='test.csv',
                          batch_size=128,
                          n_steps=5000,
                          d_model=128,
                          nhead=8,
                          num_decoder_layers=12,
                          dim_feedforward=512,
                          max_seq_length=140,
                          pos_dropout=0.1,
                          trans_dropout=0.1,
                          token_list=None,
                          device=None,
                          ):
    """Sample SMILES strings from a trained conditional-Transformer prior.

    Loads decoder weights from ``restore_prior_from``, generates ``n_steps``
    batches of ``batch_size`` sequences conditioned on ``token_list``, decodes
    them to SMILES via the vocabulary, and writes all of them to ``save_file``
    as a single-column, header-less CSV.

    Parameters
    ----------
    restore_prior_from : str
        Path to the checkpoint with the trained decoder state dict.
    save_file : str
        Output CSV path for the generated SMILES.
    batch_size : int
        Number of sequences sampled per generation step.
    n_steps : int
        Number of generation steps (total molecules = n_steps * batch_size).
    d_model, nhead, num_decoder_layers, dim_feedforward, max_seq_length,
    pos_dropout, trans_dropout :
        Transformer hyper-parameters; defaults match the values used for
        training in this repository. Previously these were read from
        module-level globals defined only in the ``__main__`` block, which
        made the function raise NameError when imported from another module.
    token_list : list[str] | None
        Condition tokens prepended during generation; defaults to the
        DRD2/QED/SA property tokens used by this project.
    device : torch.device | None
        Device to run generation on; auto-detects CUDA when None.
    """
    if token_list is None:
        token_list = ['is_DRD2', 'high_QED', 'good_SA']
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    voc = Vocabulary(init_from_file="data/Voc_RE1")

    Prior = transformer_RL(voc, d_model, nhead, num_decoder_layers,
                           dim_feedforward, max_seq_length,
                           pos_dropout, trans_dropout)

    # Generation only: disable dropout etc.
    Prior.decodertf.eval()

    # Saved models are partially on the GPU, but if we don't have cuda
    # enabled we remap the tensors to the CPU.
    if torch.cuda.is_available():
        Prior.decodertf.load_state_dict(torch.load(restore_prior_from, map_location={'cuda:0': 'cuda:0'}))
    else:
        Prior.decodertf.load_state_dict(
            torch.load(restore_prior_from, map_location=lambda storage, loc: storage))

    Prior.decodertf.to(device)

    smile_list = []
    for i in range(n_steps):
        # max_length was previously hard-coded to 140; it now follows
        # max_seq_length (same default value).
        seqs = Prior.generate(batch_size, max_length=max_seq_length, con_token_list=token_list)
        smile_list.extend(seq_to_smiles(seqs, voc))
        print('step: ', i)

    pd.DataFrame(smile_list).to_csv(save_file, header=False, index=False)


if __name__ == "__main__":
    # NOTE: these module-level assignments are read by Transformer_generator()
    # as implicit globals (d_model, nhead, token_list, device, ...). Keep the
    # names unchanged — renaming any of them breaks generation.

    # Transformer architecture hyper-parameters (must match the trained Prior).
    max_seq_length = 140
    # num_tokens=71
    # vocab_size=71
    d_model = 128
    # num_encoder_layers = 6
    num_decoder_layers = 12
    dim_feedforward = 512
    nhead = 8
    pos_dropout = 0.1
    trans_dropout = 0.1
    n_warmup_steps = 500  # not used in this script — presumably a training leftover

    num_epochs = 600  # not used in this script
    batch_size = 128

    n_steps = 5000

    # Condition tokens controlling generated molecule properties
    # (DRD2 activity, high QED, good synthetic accessibility).
    token_list = ['is_DRD2', 'high_QED', 'good_SA']

    # Prefer GPU when available; generation falls back to CPU otherwise.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(description="Main script for running the model")
    # NOTE(review): CLI default for n_steps (500) differs from the function
    # default (5000), and the --prior default ('./data/Prior.ckpt') differs
    # from the function default ('output/Prior.ckpt') — confirm which is
    # intended; the CLI values win because they are always passed through.
    parser.add_argument('--num-steps', action='store', dest='n_steps', type=int,
                        default=500)
    parser.add_argument('--batch-size', action='store', dest='batch_size', type=int,
                        default=128)
    parser.add_argument('--prior', action='store', dest='restore_prior_from',
                        default='./data/Prior.ckpt',
                        help='Path to an c-Transformer checkpoint file to use as a Prior')

    parser.add_argument('--save_molecules_path', action='store', dest='save_file',
                        default='test.csv')

    # argparse dest names match Transformer_generator keyword arguments,
    # so the namespace can be splatted directly.
    arg_dict = vars(parser.parse_args())

    Transformer_generator(**arg_dict)
75 changes: 75 additions & 0 deletions 3_train_middle_model_dm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python
import argparse

import torch
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit import rdBase
from tqdm import tqdm

from MCMG_utils.data_structs import MolData, Vocabulary
from models.model_rnn import RNN
from MCMG_utils.utils import decrease_learning_rate
rdBase.DisableLog('rdApp.error')

def train_middle(train_data, save_model='./DM.ckpt'):
    """Trains the Prior RNN"""

    # Vocabulary for encoding/decoding SMILES token sequences.
    vocabulary = Vocabulary(init_from_file="data/Voc_RE1")

    # Wrap the SMILES file in a dataset and batch it for training.
    dataset = MolData(train_data, vocabulary)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True,
                        collate_fn=MolData.collate_fn)

    prior = RNN(vocabulary)
    optimizer = torch.optim.Adam(prior.rnn.parameters(), lr=0.001)

    for epoch in range(1, 9):
        for step, batch in tqdm(enumerate(loader), total=len(loader)):
            seqs = batch.long()

            # Negative log-likelihood of the training sequences.
            loss = - prior.likelihood(seqs).mean()

            # Standard backprop step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 500 steps: decay the learning rate, report the loss,
            # sample molecules to gauge validity, and checkpoint the model.
            if step % 500 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("*" * 50)
                print(loss.cpu().data)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(epoch, step, loss.cpu().data))

                sampled, likelihood, _ = prior.sample(128)
                valid = 0
                for idx, seq in enumerate(sampled.cpu().numpy()):
                    smile = vocabulary.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if idx < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(sampled)))
                tqdm.write("*" * 50 + "\n")
                torch.save(prior.rnn.state_dict(), save_model)

    # Save the final Prior.
    torch.save(prior.rnn.state_dict(), save_model)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Main script for running the model")
    parser.add_argument('--train-data', action='store', dest='train_data',
                        help='Path to the SMILES training-set file.')
    # BUG FIX: dest was 'save_dir', which does not match the 'save_model'
    # parameter of train_middle(). Because argparse always populates its dest
    # (with None when the flag is omitted), train_middle(**arg_dict) raised
    # TypeError: unexpected keyword argument 'save_dir' on every invocation.
    # The default mirrors train_middle's own default so omitting the flag
    # keeps the original intended behavior.
    parser.add_argument('--save-middle-path', action='store', dest='save_model',
                        default='./DM.ckpt',
                        help='Path and name of middle model.')

    # dest names now match train_middle's keyword arguments exactly.
    arg_dict = vars(parser.parse_args())

    train_middle(**arg_dict)
Loading

0 comments on commit 4b59b09

Please sign in to comment.