Skip to content

Commit

Permalink
tests and reproducibility fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
Artur Kadurin committed Apr 23, 2020
1 parent 78cf286 commit e2e4d3f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 26 deletions.
40 changes: 15 additions & 25 deletions moses/baselines/combinatorial.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,15 +40,11 @@ def __init__(self, n_jobs=1, mode=0):
Arguments:
n_jobs: number of processes for training
mode: sampling mode
last bit sets sampling connection point or fragment first
second bit sets sampling connection between two fragments
0: Sample connection point, sample from unique reactions
1: Sample fragment first, sample from unique reactions
2: Sample connection point, sample from all possible reactions
3: Sample fragment first, sample from all possible reactions
0: Sample fragment then connection
1: Sample connection point then fragments
"""
self.n_jobs = n_jobs
self.mode = mode
self.set_mode(mode)
self.fitted = False

def fit(self, data):
Expand Down Expand Up @@ -132,7 +128,7 @@ def load(cls, path):
return model

def set_mode(self, mode):
if mode not in [0, 1, 2, 3]:
if mode not in [0, 1]:
raise ValueError('Incorrect mode value: %s' % mode)
self.mode = mode

Expand Down Expand Up @@ -180,13 +176,11 @@ def generate_one(self, seed=None):
if mol is None:
mol = self.sample_fragment(counts_masked)
else:
if self.mode & 1: # Sample fragment first
con_filter = self.get_connection_filter(connections_mol)
else: # Choose connection atom first
if self.mode == 1: # Choose connection atom first
atom_mol = np.random.choice(connections_mol)
connections_mol = [atom_mol]
con_filter = 2**atom_mol.GetIsotope()

con_filter = self.get_connection_filter(connections_mol)
# Mask fragments with possible reactions
counts_masked = counts_masked[
counts_masked['connection_rules'] & con_filter > 0
Expand All @@ -198,13 +192,8 @@ def generate_one(self, seed=None):
connections_fragment
)

if self.mode & 2: # Sample weighted connection
c_i = np.random.choice(len(possible_connections))
a1, a2 = possible_connections[c_i]
else: # Sample from unique connections
possible_connections = list(set(possible_connections))
c_i = np.random.choice(len(possible_connections))
a1, a2 = possible_connections[c_i]
c_i = np.random.choice(len(possible_connections))
a1, a2 = possible_connections[c_i]

# Connect a new fragment to the molecule
mol = self.connect_mols(mol, fragment, a1, a2)
Expand All @@ -217,13 +206,14 @@ def generate_one(self, seed=None):
smiles = Chem.MolToSmiles(mol)
return smiles

def generate(self, n, seed, mode=0):
def generate(self, n, seed=1, mode=0, verbose=False):
self.set_mode(mode)
generator = (self.generate_one(seed) for i in range(n))
if self.verbose:
seeds = range((seed - 1) * n, seed * n)
if verbose:
print('generating...')
generator = tqdm(generator, total=n)
return list(generator)
seeds = tqdm(seeds, total=n)
samples = mapper(self.n_jobs)(self.generate_one, seeds)
return samples

def get_connection_rule(self, fragment):
"""
Expand Down Expand Up @@ -277,7 +267,7 @@ def get_connection_points(mol):
return atoms

@staticmethod
def filter_connections(atoms1, atoms2):
def filter_connections(atoms1, atoms2, unique=True):
possible_connections = []
for a1 in atoms1:
i1 = a1.GetIsotope()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_baselines.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_hmm(self):

def test_combinatorial(self):
model = CombinatorialGenerator()
model.fit(self.train[:10])
model.fit(self.train[:100])
sample_original = model.generate_one(1)
with tempfile.NamedTemporaryFile() as f:
model.save(f.name)
Expand Down

0 comments on commit e2e4d3f

Please sign in to comment.