forked from molecularsets/moses
added the latentgan model (molecularsets#59)
* Added LatentGAN
* attempt to remove submodule in LatentGAN
* attempt to remove submodule from LatentGAN
* removed ddc_pub submodule
* Added refactored LatentGAN
* attempt to remove submodule
* removed data directory from git during development
* untracking pointer files from git-lfs
* re-added data files from origin
* merged with changes from main MOSES repo
* added the latentgan model
* refactored latentgan code to comply with travis ci
* removed temporary files and too heavy pretrained models
Showing 9 changed files with 627 additions and 1 deletion.
@@ -4,3 +4,4 @@
build/
dist/
moses.egg-info/
@@ -0,0 +1,26 @@
LatentGAN
=========
<p align="center">
  <img src="../../images/LatentGAN.png">
</p>

LatentGAN [1] uses a heteroencoder trained on ChEMBL 25 [2] to encode SMILES strings into latent vector representations of size 512. A Wasserstein Generative Adversarial Network with Gradient Penalty [3] is then trained to generate latent vectors resembling those of the training set, which are decoded back into SMILES by the heteroencoder. This model uses the Deep-Drug-Coder heteroencoder implementation [4].

Important!
==========
Currently, the Deep-Drug-Coder [4] and its dependency package molvecgen [5] are not available on PyPI; they have to be installed from their respective repositories (links provided below).

The pretrained models of the LatentGAN are currently not shared in this repository due to file size constraints. They will be added in the near future.

## References

[1] [A De Novo Molecular Generation Method Using Latent Vector Based Generative Adversarial Network](https://chemrxiv.org/articles/A_De_Novo_Molecular_Generation_Method_Using_Latent_Vector_Based_Generative_Adversarial_Network/8299544)

[2] [ChEMBL](https://www.ebi.ac.uk/chembl/)

[3] [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028)

[4] [Deep-Drug-Coder](https://github.com/pcko1/Deep-Drug-Coder)

[5] [molvecgen](https://github.com/EBjerrum/molvecgen)
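The WGAN-GP objective described in the README can be illustrated with a minimal, self-contained PyTorch sketch. This is a toy example on random 512-dimensional vectors, not code from this commit; in the actual model the vectors are heteroencoder encodings, and the critic, penalty, and `--gp` weight of 10 appear in the model and config files shown further down.

```python
import torch
import torch.autograd as autograd

# Stand-ins for heteroencoder latent vectors of the training set (real)
# and generator outputs (fake); both are 512-dimensional.
real = torch.randn(64, 512)
fake = torch.randn(64, 512)
critic = torch.nn.Sequential(
    torch.nn.Linear(512, 256),
    torch.nn.LeakyReLU(0.2),
    torch.nn.Linear(256, 1),
)

# WGAN-GP: penalize the critic's gradient norm on random interpolates
# between real and fake samples for deviating from 1.
alpha = torch.rand(real.size(0), 1)
interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
grad = autograd.grad(critic(interpolates).sum(), interpolates,
                     create_graph=True)[0]
gradient_penalty = ((grad.norm(2, dim=1) - 1) ** 2).mean()

# Critic loss: fake scores minus real scores plus the weighted penalty
# (the config below exposes the weight as --gp, default 10).
gp_coef = 10
d_loss = critic(fake).mean() - critic(real).mean() + gp_coef * gradient_penalty
print(float(d_loss))
```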
@@ -0,0 +1,5 @@
from .config import get_parser as latentGAN_parser
from .model import LatentGAN
from .trainer import LatentGANTrainer

__all__ = ['latentGAN_parser', 'LatentGAN', 'LatentGANTrainer']
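A hypothetical usage sketch of the names re-exported above. It assumes the package is importable as `moses.latentgan` and that the README dependencies (Deep-Drug-Coder, molvecgen, RDKit) are installed, since the model module imports `ddc_pub` at load time; the `vocabulary=None` placeholder is only for illustration (in the full MOSES pipeline the trainer supplies a vocabulary built from the training data).

```python
# Hypothetical usage; the import path and the vocabulary placeholder are
# assumptions, not part of this commit.
from moses.latentgan import latentGAN_parser, LatentGAN

config = latentGAN_parser().parse_known_args([])[0]  # default hyperparameters
print(config.heteroencoder_version, config.latent_vector_dim)  # new 512

model = LatentGAN(vocabulary=None, config=config)  # vocabulary comes from the
                                                   # trainer in the real pipeline
print(sum(p.numel() for p in model.parameters()))  # generator + discriminator size
```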
@@ -0,0 +1,92 @@
import argparse


def get_parser(parser=None):
    if parser is None:
        parser = argparse.ArgumentParser()

    # Model
    model_arg = parser.add_argument_group('Model')
    model_arg.add_argument("--heteroencoder_version", type=str, default='new',
                           help="Which heteroencoder model version to use")
    # Train
    train_arg = parser.add_argument_group('Training')

    train_arg.add_argument('--gp', type=int, default=10,
                           help='Gradient Penalty Coefficient')
    train_arg.add_argument('--n_critic', type=int, default=5,
                           help='Ratio of discriminator to'
                                ' generator training frequency')
    train_arg.add_argument('--train_epochs', type=int, default=2000,
                           help='Number of epochs for model training')
    train_arg.add_argument('--n_batch', type=int, default=64,
                           help='Size of batch')
    train_arg.add_argument('--lr', type=float, default=0.0002,
                           help='Learning rate')
    train_arg.add_argument('--b1', type=float, default=0.5,
                           help='Adam optimizer parameter beta 1')
    train_arg.add_argument('--b2', type=float, default=0.999,
                           help='Adam optimizer parameter beta 2')
    train_arg.add_argument('--step_size', type=int, default=10,
                           help='Period of learning rate decay')
    train_arg.add_argument('--latent_vector_dim', type=int, default=512,
                           help='Size of latentgan vector')
    train_arg.add_argument('--gamma', type=float, default=1,
                           help='Multiplicative factor of'
                                ' learning rate decay')
    train_arg.add_argument('--n_jobs', type=int, default=1,
                           help='Number of threads')
    train_arg.add_argument('--n_workers', type=int, default=1,
                           help='Number of workers for DataLoaders')

    # Arguments used if training a new heteroencoder
    heteroencoder_arg = parser.add_argument_group('heteroencoder')

    heteroencoder_arg.add_argument('--heteroencoder_layer_dim', type=int,
                                   default=512,
                                   help='Layer size for heteroencoder '
                                        '(if training new heteroencoder)')
    heteroencoder_arg.add_argument('--heteroencoder_noise_std', type=float,
                                   default=0.1,
                                   help='Noise amplitude for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_dec_layers', type=int,
                                   default=4,
                                   help='Number of decoding layers'
                                        ' for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_batch_size',
                                   type=int, default=128,
                                   help='Batch size for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_epochs', type=int,
                                   default=100,
                                   help='Number of epochs for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_lr', type=float,
                                   default=1e-3,
                                   help='learning rate for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_mini_epochs', type=int,
                                   default=10,
                                   help='How many sub-epochs to '
                                        'split each epoch for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_lr_decay',
                                   default=True, action='store_false',
                                   help='Use learning rate decay '
                                        'for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_patience', type=int,
                                   default=100,
                                   help='Patience for adaptive learning '
                                        'rate for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_lr_decay_start', type=int,
                                   default=500,
                                   help='Which sub-epoch to start decaying '
                                        'learning rate for heteroencoder')
    heteroencoder_arg.add_argument('--heteroencoder_save_period', type=int,
                                   default=100,
                                   help='How often in sub-epochs to '
                                        'save model checkpoints for'
                                        ' heteroencoder')

    return parser


def get_config():
    parser = get_parser()
    return parser.parse_known_args()[0]
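One detail of the parser above worth illustrating is `--heteroencoder_lr_decay`, which is declared with `default=True` and `action='store_false'`, so passing the flag turns learning-rate decay off rather than on. A small sketch of the parsing behaviour follows; the import path is an assumption (it matches the MOSES package layout), while the argparse behaviour itself is standard.

```python
# Sketch of how the defaults and the store_false flag behave.
from moses.latentgan.config import get_parser  # import path is an assumption

parser = get_parser()
defaults = parser.parse_known_args([])[0]
toggled = parser.parse_known_args(['--heteroencoder_lr_decay'])[0]

print(defaults.lr, defaults.b1, defaults.b2)   # 0.0002 0.5 0.999
print(defaults.heteroencoder_lr_decay)         # True  (decay on by default)
print(toggled.heteroencoder_lr_decay)          # False (flag switches decay off)
```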
@@ -0,0 +1,226 @@
import torch.nn as nn
import numpy as np
import torch
from ddc_pub import ddc_v3 as ddc
import os
from rdkit import Chem
import sys
from torch.utils import data
import torch.autograd as autograd


class LatentGAN(nn.Module):
    def __init__(self, vocabulary, config):
        super(LatentGAN, self).__init__()
        self.vocabulary = vocabulary
        self.generator = Generator()
        self.model_version = config.heteroencoder_version
        self.discriminator = Discriminator()
        self.sample_decoder = None
        self.model_loaded = False
        self.new_batch_size = 256
        # init params
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.discriminator.cuda()
            self.generator.cuda()
        self.Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def forward(self, n_batch):
        out = self.sample(n_batch)
        return out

    def encode_smiles(self, smiles_in, encoder=None):

        model = load_model(model_version=encoder)

        # MUST convert SMILES to binary mols for the model to accept them
        # (it re-converts them to SMILES internally)
        mols_in = [Chem.rdchem.Mol.ToBinary(Chem.MolFromSmiles(smiles))
                   for smiles in smiles_in]
        latent = model.transform(model.vectorize(mols_in))

        return latent.tolist()

    def compute_gradient_penalty(self, real_samples,
                                 fake_samples, discriminator):
        """Calculates the gradient penalty loss for WGAN GP"""
        # Random weight term for interpolation between real and fake samples
        alpha = self.Tensor(np.random.random((real_samples.size(0), 1)))

        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples +
                        ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = discriminator(interpolates)
        fake = self.Tensor(real_samples.shape[0], 1).fill_(1.0)

        # Get gradient w.r.t. interpolates
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

        return gradient_penalty

    @property
    def device(self):
        return next(self.parameters()).device

    def sample(self, n_batch, max_length=100):
        if not self.model_loaded:
            # Checking for first batch of model to only load model once
            print('Heteroencoder for Sampling Loaded')
            self.sample_decoder = load_model(model_version=self.model_version)
            # load generator

            self.Gen = self.generator
            self.Gen.eval()

            self.D = self.discriminator
            torch.no_grad()
            cuda = True if torch.cuda.is_available() else False
            if cuda:
                self.Gen.cuda()
                self.D.cuda()
            self.S = Sampler(generator=self.Gen)
            self.model_loaded = True

        if n_batch <= 256:
            print('Batch size of {} detected. Decoding '
                  'performs poorly when Batch size != 256. \
                  Setting batch size to 256'.format(n_batch))
        # Sampling performs very poorly on default sampling batch parameters.
        # This takes care of the default scenario.
        if n_batch == 32:
            n_batch = 256

        latent = self.S.sample(n_batch)
        sanitycheck = self.D(latent)
        print('mean latent values')
        print(torch.mean(latent))
        print('var latent values')
        print(torch.var(latent))
        print('generator loss of sample')
        print(-torch.mean(sanitycheck))
        latent = latent.detach().cpu().numpy()

        if self.new_batch_size != n_batch:
            # The batch decoder creates a new instance of the decoder
            # every time a new batch size is given, e.g. for the
            # final batch of the generation.
            self.new_batch_size = n_batch
            self.sample_decoder.batch_input_length = self.new_batch_size
        lat = latent

        sys.stdout.flush()

        smi, _ = self.sample_decoder.predict_batch(lat, temp=0)
        return smi


def load_model(model_version=None):
    # Import model
    currentDirectory = os.getcwd()

    if model_version == 'chembl':
        model_name = 'chembl_pretrained'
    elif model_version == 'moses':
        model_name = 'moses_pretrained'
    elif model_version == 'new':
        model_name = 'new_model'
    else:
        print('No predefined model of that name found. '
              'using the default pre-trained MOSES heteroencoder')
        model_name = 'moses_pretrained'

    path = '{}/moses/latentgan/heteroencoder_models/{}' \
        .format(currentDirectory, model_name)
    print("Loading heteroencoder model titled {}".format(model_version))
    print("Path to model file: {}".format(path))
    model = ddc.DDC(model_name=path)
    sys.stdout.flush()

    return model


class LatentMolsDataset(data.Dataset):
    def __init__(self, latent_space_mols):
        self.data = latent_space_mols

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class Discriminator(nn.Module):
    def __init__(self, data_shape=(1, 512)):
        super(Discriminator, self).__init__()
        self.data_shape = data_shape

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.data_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, mol):
        validity = self.model(mol)
        return validity


class Generator(nn.Module):
    def __init__(self, data_shape=(1, 512), latent_dim=None):
        super(Generator, self).__init__()
        self.data_shape = data_shape

        # latent dim of the generator is one of the hyperparams.
        # by default it is set to the prod of data_shapes
        self.latent_dim = int(np.prod(self.data_shape)) \
            if latent_dim is None else latent_dim

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.data_shape))),
            # nn.Tanh() # expecting latent vectors to be not normalized
        )

    def forward(self, z):
        out = self.model(z)
        return out


class Sampler(object):
    """
    Samples mols from the generator.
    All scripts should use this class for sampling.
    """

    def __init__(self, generator: Generator):
        self.G = generator

    def sample(self, n):
        # Sample noise as generator input
        z = torch.cuda.FloatTensor(np.random.uniform(-1, 1,
                                                      (n, self.G.latent_dim)))
        # Generate a batch of mols
        return self.G(z)
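A short shape sanity check for the Generator and Discriminator defined above, assuming it runs in the same module (or with those classes imported). Noise is drawn manually here because the Sampler class allocates `torch.cuda.FloatTensor` and therefore needs a CUDA device; this sketch is CPU-only and involves no heteroencoder.

```python
# Pure-PyTorch sanity check of the tensor shapes flowing through the GAN.
import numpy as np
import torch

G = Generator()        # latent_dim defaults to prod((1, 512)) = 512
D = Discriminator()
G.eval()               # use running BatchNorm statistics for a small batch

z = torch.from_numpy(np.random.uniform(-1, 1, (4, G.latent_dim))).float()
fake_latents = G(z)       # (4, 512) vectors meant for the heteroencoder's decoder
scores = D(fake_latents)  # (4, 1) unbounded Wasserstein critic scores

print(fake_latents.shape, scores.shape)
```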