Merge pull request deepchem#1621 from VIGS25/split-transform-order
Swapping Split-Transform order
Bharath Ramsundar authored Jun 21, 2019
2 parents 4e382ee + 2a33962 commit 7ca3a11
Showing 13 changed files with 234 additions and 114 deletions.
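
The net effect is the same across the touched loaders: instead of fitting transformers on the full featurized dataset and then splitting, the loaders now split first and fit any transformers on the training split before applying them to train, valid, and test (when split is None the transformer is still fit on the whole dataset, as before). A minimal sketch of the new ordering, assuming a featurized DiskDataset named dataset; illustrative only, not taken verbatim from any one file:

import deepchem

# Split first...
splitter = deepchem.splits.ScaffoldSplitter()
train, valid, test = splitter.train_valid_test_split(dataset)

# ...then fit the transformer on the training split only, so valid/test
# statistics no longer leak into the normalization, and apply it to all three.
transformers = [
    deepchem.trans.NormalizationTransformer(transform_y=True, dataset=train)
]
for transformer in transformers:
    train = transformer.transform(train)
    valid = transformer.transform(valid)
    test = transformer.transform(test)
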
59 changes: 42 additions & 17 deletions deepchem/molnet/load_function/bace_datasets.py
@@ -57,17 +57,17 @@ def load_bace_regression(featurizer='ECFP',
tasks=bace_tasks, smiles_field="mol", featurizer=featurizer)

dataset = loader.featurize(dataset_file, shard_size=8192)
# Initialize transformers
transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=dataset, move_mean=move_mean)
]
if split is None:
# Initialize transformers
transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=dataset, move_mean=move_mean)
]

logger.info("About to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)
logger.info("Split is None, about to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)

if split == None:
return bace_tasks, (dataset, None, None), transformers

splitters = {
@@ -76,8 +76,20 @@ def load_bace_regression(featurizer='ECFP',
'scaffold': deepchem.splits.ScaffoldSplitter()
}
splitter = splitters[split]
logger.info("About to split data using {} splitter".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)

transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=train, move_mean=move_mean)
]

logger.info("About to transform data.")
for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)

if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
transformers)
@@ -122,26 +134,39 @@ def load_bace_classification(featurizer='ECFP', split='random', reload=True):
tasks=bace_tasks, smiles_field="mol", featurizer=featurizer)

dataset = loader.featurize(dataset_file, shard_size=8192)
# Initialize transformers
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
]

logger.info("About to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)
if split is None:
# Initialize transformers
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
]

logger.info("Split is None, about to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)

if split == None:
return bace_tasks, (dataset, None, None), transformers

splitters = {
'index': deepchem.splits.IndexSplitter(),
'random': deepchem.splits.RandomSplitter(),
'scaffold': deepchem.splits.ScaffoldSplitter()
}

splitter = splitters[split]
logger.info("About to split data using {} splitter".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)

transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
]

logger.info("About to transform data.")
for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)

if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
transformers)
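
For orientation, a hedged usage sketch of one updated loader; it assumes a standard deepchem install and that the split branch returns the usual MolNet triple of tasks, (train, valid, test) datasets, and fitted transformers, as the unsplit branch above does:

import deepchem as dc

# Illustrative call; the featurizer and split values are examples, not
# defaults mandated by this commit.
tasks, (train, valid, test), transformers = dc.molnet.load_bace_regression(
    featurizer='ECFP', split='scaffold', reload=False)

# After this commit, the NormalizationTransformer in transformers was fitted
# on the training split only.
print(len(train), len(valid), len(test))
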
12 changes: 10 additions & 2 deletions deepchem/molnet/load_function/bbbc_datasets.py
@@ -17,7 +17,7 @@

def load_bbbc001(split='index', reload=True):
"""Load BBBC001 dataset
This dataset contains 6 images of human HT29 colon cancer cells. The task is to learn to predict the cell counts in these images. This dataset is too small to train algorithms on, but might serve as a good test dataset. https://data.broadinstitute.org/bbbc/BBBC001/
"""
# Featurize BBBC001 dataset
@@ -57,6 +57,8 @@ def load_bbbc001(split='index', reload=True):
dataset = deepchem.data.DiskDataset.from_numpy(dataset.X, y)

if split == None:
transformers = []
logger.info("Split is None, no transformers used for the dataset.")
return bbbc001_tasks, (dataset, None, None), transformers

splitters = {
@@ -67,7 +69,9 @@ def load_bbbc001(split='index', reload=True):
raise ValueError("Only index and random splits supported.")
splitter = splitters[split]

logger.info("About to split dataset with {} splitter.".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)
transformers = []
all_dataset = (train, valid, test)
if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
@@ -77,7 +81,7 @@

def load_bbbc002(split='index', reload=True):
"""Load BBBC002 dataset
This dataset contains data corresponding to 5 samples of Drosophila Kc167
cells. There are 10 fields of view for each sample, each an image of size
512x512. Ground truth labels contain cell counts for this dataset. Full
@@ -121,6 +125,8 @@ def load_bbbc002(split='index', reload=True):
dataset = deepchem.data.DiskDataset.from_numpy(dataset.X, y, ids=ids)

if split == None:
transformers = []
logger.info("Split is None, no transformers used for the dataset.")
return bbbc002_tasks, (dataset, None, None), transformers

splitters = {
@@ -131,8 +137,10 @@ def load_bbbc002(split='index', reload=True):
raise ValueError("Only index and random splits supported.")
splitter = splitters[split]

logger.info("About to split dataset with {} splitter.".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)
all_dataset = (train, valid, test)
transformers = []
if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
transformers)
28 changes: 20 additions & 8 deletions deepchem/molnet/load_function/bbbp_datasets.py
@@ -45,16 +45,17 @@ def load_bbbp(featurizer='ECFP', split='random', reload=True):
loader = deepchem.data.CSVLoader(
tasks=bbbp_tasks, smiles_field="smiles", featurizer=featurizer)
dataset = loader.featurize(dataset_file, shard_size=8192)
# Initialize transformers
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
]

logger.info("About to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)
if split is None:
# Initialize transformers
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
]

logger.info("Split is None, about to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)

if split == None:
return bbbp_tasks, (dataset, None, None), transformers

splitters = {
@@ -63,8 +64,19 @@ def load_bbbp(featurizer='ECFP', split='random', reload=True):
'scaffold': deepchem.splits.ScaffoldSplitter()
}
splitter = splitters[split]
logger.info("About to split data with {} splitter.".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)

# Initialize transformers
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
]

for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)

if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
transformers)
5 changes: 4 additions & 1 deletion deepchem/molnet/load_function/cell_counting_datasets.py
@@ -17,7 +17,7 @@

def load_cell_counting(split=None, reload=True):
"""Load Cell Counting dataset.
Loads the cell counting dataset from http://www.robots.ox.ac.uk/~vgg/research/counting/index_org.html.
"""
data_dir = deepchem.utils.get_data_dir()
@@ -43,6 +43,7 @@ def load_cell_counting(split=None, reload=True):
transformers = []

if split == None:
logger.info("Split is None, no transformers used.")
return cell_counting_tasks, (dataset, None, None), transformers

splitters = {
@@ -53,7 +54,9 @@ def load_cell_counting(split=None, reload=True):
raise ValueError("Only index and random splits supported.")
splitter = splitters[split]

logger.info("About to split dataset with {} splitter.".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)
transformers = []
all_dataset = (train, valid, test)
if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
50 changes: 26 additions & 24 deletions deepchem/molnet/load_function/chembl_datasets.py
@@ -80,44 +80,46 @@ def load_chembl(shard_size=2000,

if split == "year":
logger.info("Featurizing train datasets")
train_dataset = loader.featurize(train_files, shard_size=shard_size)
train = loader.featurize(train_files, shard_size=shard_size)
logger.info("Featurizing valid datasets")
valid_dataset = loader.featurize(valid_files, shard_size=shard_size)
valid = loader.featurize(valid_files, shard_size=shard_size)
logger.info("Featurizing test datasets")
test_dataset = loader.featurize(test_files, shard_size=shard_size)
test = loader.featurize(test_files, shard_size=shard_size)
else:
dataset = loader.featurize(dataset_path, shard_size=shard_size)
# Initialize transformers
logger.info("About to transform data")
if split == "year":
transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=train_dataset)
]
for transformer in transformers:
train = transformer.transform(train_dataset)
valid = transformer.transform(valid_dataset)
test = transformer.transform(test_dataset)
else:

if split is None:
transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=dataset)
]

logger.info("Split is None, about to transform data.")
for transformer in transformers:
dataset = transformer.transform(dataset)

if split == None:
return chembl_tasks, (dataset, None, None), transformers

splitters = {
'index': deepchem.splits.IndexSplitter(),
'random': deepchem.splits.RandomSplitter(),
'scaffold': deepchem.splits.ScaffoldSplitter()
}
if split != "year":
splitters = {
'index': deepchem.splits.IndexSplitter(),
'random': deepchem.splits.RandomSplitter(),
'scaffold': deepchem.splits.ScaffoldSplitter()
}

splitter = splitters[split]
logger.info("Performing new split.")
train, valid, test = splitter.train_valid_test_split(dataset)

transformers = [
deepchem.trans.NormalizationTransformer(transform_y=True, dataset=train)
]

splitter = splitters[split]
logger.info("Performing new split.")
train, valid, test = splitter.train_valid_test_split(dataset)
logger.info("About to transform data.")
for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)

if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
30 changes: 21 additions & 9 deletions deepchem/molnet/load_function/clearance_datasets.py
@@ -54,17 +54,17 @@ def load_clearance(featurizer='ECFP',
tasks=clearance_tasks, smiles_field="smiles", featurizer=featurizer)
dataset = loader.featurize(dataset_file, shard_size=8192)

# Initialize transformers
transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=dataset, move_mean=move_mean)
]
if split is None:
# Initialize transformers
transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=dataset, move_mean=move_mean)
]

logger.info("About to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)
logger.info("Split is None, about to transform data")
for transformer in transformers:
dataset = transformer.transform(dataset)

if split == None:
return clearance_tasks, (dataset, None, None), transformers

splitters = {
@@ -73,8 +73,20 @@ def load_clearance(featurizer='ECFP',
'scaffold': deepchem.splits.ScaffoldSplitter()
}
splitter = splitters[split]
logger.info("About to split data with {} splitter.".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)

transformers = [
deepchem.trans.NormalizationTransformer(
transform_y=True, dataset=train, move_mean=move_mean)
]

logger.info("About to transform data")
for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)

if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
transformers)
27 changes: 18 additions & 9 deletions deepchem/molnet/load_function/clintox_datasets.py
@@ -54,17 +54,15 @@ def load_clintox(featurizer='ECFP', split='index', reload=True):
dataset = loader.featurize(dataset_file, shard_size=8192)

# Transform clintox dataset
logger.info("About to transform clintox dataset.")
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
]
for transformer in transformers:
dataset = transformer.transform(dataset)
if split is None:
transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
]

# Split clintox dataset
logger.info("About to split clintox dataset.")
logger.info("Split is None, about to transform data.")
for transformer in transformers:
dataset = transformer.transform(dataset)

if split == None:
return clintox_tasks, (dataset, None, None), transformers

splitters = {
@@ -73,8 +71,19 @@ def load_clintox(featurizer='ECFP', split='index', reload=True):
'scaffold': deepchem.splits.ScaffoldSplitter()
}
splitter = splitters[split]
logger.info("About to split data with {} splitter.".format(split))
train, valid, test = splitter.train_valid_test_split(dataset)

transformers = [
deepchem.trans.BalancingTransformer(transform_w=True, dataset=train)
]

logger.info("About to transform data.")
for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)

if reload:
deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
transformers)