Skip to content

Commit

Permalink
Merge remote-tracking branch 'remotes/mine/textcnn'
Browse files Browse the repository at this point in the history
  • Loading branch information
miaecle committed Jan 25, 2018
2 parents 88138d3 + 184278e commit a2be59a
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 19 deletions.
26 changes: 14 additions & 12 deletions deepchem/models/tensorgraph/models/graph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,17 @@ def default_generator(self,
def predict_on_generator(self, generator, transformers=[], outputs=None):
out = super(WeaveTensorGraph, self).predict_on_generator(
generator,
transformers=transformers,
transformers=[],
outputs=outputs)
if outputs is None:
outputs = self.outputs
if len(outputs) == 1:
return out
else:
return np.stack(out, axis=1)
if len(outputs) > 1:
out = np.stack(out, axis=1)

out = undo_transforms(out, transformers)
return out



class DTNNTensorGraph(TensorGraph):

Expand Down Expand Up @@ -345,7 +348,6 @@ def default_generator(self,

yield feed_dict

'''
def predict(self, dataset, transformers=[], outputs=None):
if outputs is None:
outputs = self.outputs
Expand All @@ -357,7 +359,6 @@ def predict(self, dataset, transformers=[], outputs=None):
return retval
retval = np.concatenate(retval, axis=-1)
return undo_transforms(retval, transformers)
'''

class DAGTensorGraph(TensorGraph):

Expand Down Expand Up @@ -511,14 +512,15 @@ def default_generator(self,
def predict_on_generator(self, generator, transformers=[], outputs=None):
out = super(DAGTensorGraph, self).predict_on_generator(
generator,
transformers=transformers,
transformers=[],
outputs=outputs)
if outputs is None:
outputs = self.outputs
if len(outputs) == 1:
return out
else:
return np.stack(out, axis=1)
if len(outputs) > 1:
out = np.stack(out, axis=1)

out = undo_transforms(out, transformers)
return out

class PetroskiSuchTensorGraph(TensorGraph):
"""
Expand Down
11 changes: 6 additions & 5 deletions deepchem/models/tensorgraph/models/text_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,12 @@ def smiles_to_seq(self, smiles):
def predict_on_generator(self, generator, transformers=[], outputs=None):
out = super(TextCNNTensorGraph, self).predict_on_generator(
generator,
transformers=transformers,
transformers=[],
outputs=outputs)
if outputs is None:
outputs = self.outputs
if len(outputs) == 1:
return out
else:
return np.stack(out, axis=1)
if len(outputs) > 1:
out = np.stack(out, axis=1)

out = undo_transforms(out, transformers)
return out
2 changes: 1 addition & 1 deletion deepchem/molnet/preset_hyper_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
hps['dtnn'] = {
'batch_size': 64,
'nb_epoch': 100,
'learning_rate': 0.0005,
'learning_rate': 0.001,
'n_embedding': 50,
'n_distance': 170,
'seed': 123
Expand Down
2 changes: 1 addition & 1 deletion deepchem/molnet/run_benchmark_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def benchmark_regression(train_dataset,
random_seed=seed,
output_activation=False,
use_queue=False,
mode='regression')
mode='regression')

elif model_name == 'dag_regression':
batch_size = hyper_parameters['batch_size']
Expand Down

0 comments on commit a2be59a

Please sign in to comment.