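"""
Unit tests for the graphium data modules: OGB dataset loading, caching,
filtering of molecules that fail featurization (None molecules), and
loading data from single or multiple CSV/Parquet files.
"""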
import unittest as ut
import numpy as np
import torch
import pandas as pd
import datamol as dm
import graphium
from graphium.utils.fs import rm, exists, get_size
from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule

TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000"

class Test_DataModule(ut.TestCase):
    def test_ogb_datamodule(self):
        # other datasets are too large to be tested
        dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
        dataset_name = dataset_names[3]

        # Setup the featurization
        featurization_args = {}
        featurization_args["atom_property_list_float"] = []  # ["weight", "valence"]
        featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"]
        # featurization_args["conformer_property_list"] = ["positions_3d"]
        featurization_args["edge_property_list"] = ["bond-type-onehot"]
        featurization_args["add_self_loop"] = False
        featurization_args["use_bonds_weights"] = False
        featurization_args["explicit_H"] = False

        # Config for datamodule
        task_specific_args = {}
        task_specific_args["task_1"] = {"task_level": "graph", "dataset_name": dataset_name}

        dm_args = {}
        dm_args["cache_data_path"] = None
        dm_args["featurization"] = featurization_args
        dm_args["batch_size_training"] = 16
        dm_args["batch_size_inference"] = 16
        dm_args["num_workers"] = 0
        dm_args["pin_memory"] = True
        dm_args["featurization_n_jobs"] = 2
        dm_args["featurization_progress"] = True
        dm_args["featurization_backend"] = "loky"
        dm_args["featurization_batch_size"] = 50

        ds = GraphOGBDataModule(task_specific_args, **dm_args)
        ds.prepare_data()

        # Check the keys in the dataset
        ds.setup(save_smiles_and_ids=False)
        assert set(ds.train_ds[0].keys()) == {"features", "labels"}
        ds.setup(save_smiles_and_ids=True)
        assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"}

        # test module
        assert ds.num_edge_feats == 5
        assert ds.num_node_feats == 50
        assert len(ds) == 642

        # test batch loader
        batch = next(iter(ds.train_dataloader()))
        assert len(batch["smiles"]) == 16
        assert len(batch["labels"]["graph_task_1"]) == 16
        assert len(batch["mol_ids"]) == 16

    def test_none_filtering(self):
        # Create the objects to filter
        list_of_num = [ii for ii in range(100)]
        list_of_str = [str(ii) for ii in list_of_num]
        tuple_of_num = tuple(list_of_num)
        array_of_num = np.asarray(list_of_num)
        array_of_str = np.asarray(list_of_str)
        tensor_of_num = torch.as_tensor(array_of_num)
        arrays_of_num = np.stack([list_of_num, list_of_num, list_of_num], axis=1)
        arrays_of_str = np.stack([list_of_str, list_of_str, list_of_str], axis=1)
        tensors_of_num = torch.as_tensor(arrays_of_num)
        dic = {"str": list_of_str, "num": list_of_num}
        df = pd.DataFrame(dic)
        df_shuffled = df.sample(frac=1)
        series_num = df["num"]
        series_num_shuffled = df_shuffled["num"]

        # Create different indexes to use for filtering
        all_idx_none = [[3, 17, 88], [22, 33, 44, 55, 66, 77, 88], [], np.arange(len(list_of_num))]

        # Loop all the indexes and filter the objects.
        for ii, idx_none in enumerate(all_idx_none):
            msg = f"Failed for ii={ii}"

            # Create the true filtered sequences
            filtered_num = [ii for ii in range(100) if ii not in idx_none]
            filtered_str = [str(ii) for ii in filtered_num]
            assert len(filtered_num) == len(list_of_num) - len(idx_none)
            assert len(filtered_str) == len(list_of_str) - len(idx_none)

            # Filter the sequences with the Datamodule function
            (
                list_of_num_2,
                list_of_str_2,
                tuple_of_num_2,
                array_of_num_2,
                array_of_str_2,
                tensor_of_num_2,
                df_2,
                df_shuffled_2,
                dic_2,
                arrays_of_num_2,
                arrays_of_str_2,
                tensors_of_num_2,
                series_num_2,
                series_num_shuffled_2,
            ) = graphium.data.MultitaskFromSmilesDataModule._filter_none_molecules(
                idx_none,
                list_of_num,
                list_of_str,
                tuple_of_num,
                array_of_num,
                array_of_str,
                tensor_of_num,
                df,
                df_shuffled,
                dic,
                arrays_of_num,
                arrays_of_str,
                tensors_of_num,
                series_num,
                series_num_shuffled,
            )

            df_shuffled_2 = df_shuffled_2.sort_values(by="num", axis=0)
            series_num_shuffled_2 = series_num_shuffled_2.sort_values(axis=0)

            # Assert the filtering is done correctly
            self.assertListEqual(list_of_num_2, filtered_num, msg=msg)
            self.assertListEqual(list_of_str_2, filtered_str, msg=msg)
            self.assertListEqual(list(tuple_of_num_2), filtered_num, msg=msg)
            self.assertListEqual(array_of_num_2.tolist(), filtered_num, msg=msg)
            self.assertListEqual(array_of_str_2.tolist(), filtered_str, msg=msg)
            self.assertListEqual(tensor_of_num_2.tolist(), filtered_num, msg=msg)
            for jj in range(arrays_of_num.shape[1]):
                self.assertListEqual(arrays_of_num_2[:, jj].tolist(), filtered_num, msg=msg)
                self.assertListEqual(arrays_of_str_2[:, jj].tolist(), filtered_str, msg=msg)
                self.assertListEqual(tensors_of_num_2[:, jj].tolist(), filtered_num, msg=msg)
            self.assertListEqual(dic_2["num"], filtered_num, msg=msg)
            self.assertListEqual(dic_2["str"], filtered_str, msg=msg)
            self.assertListEqual(df_2["num"].tolist(), filtered_num, msg=msg)
            self.assertListEqual(df_2["str"].tolist(), filtered_str, msg=msg)
            self.assertListEqual(series_num_2.tolist(), filtered_num, msg=msg)

            # When the dataframe is shuffled, the lists are different because the filtering
            # is done on the row indexes, not the dataframe indexes.
            bool_to_check = (len(idx_none) == 0) or (len(idx_none) == len(df_shuffled))
            self.assertIs(df_shuffled_2["num"].tolist() == filtered_num, bool_to_check, msg=msg)
            self.assertIs(df_shuffled_2["str"].tolist() == filtered_str, bool_to_check, msg=msg)
            self.assertIs(series_num_shuffled_2.tolist() == filtered_num, bool_to_check, msg=msg)

    def test_caching(self):
        # other datasets are too large to be tested
        dataset_name = "ogbg-molfreesolv"

        # Setup the featurization
        featurization_args = {}
        featurization_args["atom_property_list_float"] = []  # ["weight", "valence"]
        featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"]
        # featurization_args["conformer_property_list"] = ["positions_3d"]
        featurization_args["edge_property_list"] = ["bond-type-onehot"]
        featurization_args["add_self_loop"] = False
        featurization_args["use_bonds_weights"] = False
        featurization_args["explicit_H"] = False

        # Config for datamodule
        task_specific_args = {}
        task_specific_args["task_1"] = {"task_level": "graph", "dataset_name": dataset_name}

        dm_args = {}
        dm_args["featurization"] = featurization_args
        dm_args["batch_size_training"] = 16
        dm_args["batch_size_inference"] = 16
        dm_args["num_workers"] = 0
        dm_args["pin_memory"] = True
        dm_args["featurization_n_jobs"] = 2
        dm_args["featurization_progress"] = True
        dm_args["featurization_backend"] = "loky"
        dm_args["featurization_batch_size"] = 50

        # Delete the cache if it already exists
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        # Prepare the data. It should create the cache there
        assert not exists(TEMP_CACHE_DATA_PATH)
        ds = GraphOGBDataModule(task_specific_args, cache_data_path=TEMP_CACHE_DATA_PATH, **dm_args)
        assert not ds.load_data_from_cache(verbose=False)
        ds.prepare_data()

        # Check the keys in the dataset
        ds.setup(save_smiles_and_ids=False)
        assert set(ds.train_ds[0].keys()) == {"features", "labels"}
        ds.setup(save_smiles_and_ids=True)
        assert set(ds.train_ds[0].keys()) == {"smiles", "mol_ids", "features", "labels"}

        # Make sure that the cache is created
        full_cache_path = ds.get_data_cache_fullname(compress=False)
        assert exists(full_cache_path)
        assert get_size(full_cache_path) > 10000

        # Check that the data is loaded correctly from the cache
        assert ds.load_data_from_cache(verbose=False)

        # test module
        assert ds.num_edge_feats == 5
        assert ds.num_node_feats == 50
        assert len(ds) == 642

        # test batch loader
        batch = next(iter(ds.train_dataloader()))
        assert len(batch["smiles"]) == 16
        assert len(batch["labels"]["graph_task_1"]) == 16
        assert len(batch["mol_ids"]) == 16

    def test_datamodule_with_none_molecules(self):
        # Setup the featurization
        featurization_args = {}
        featurization_args["atom_property_list_float"] = []  # ["weight", "valence"]
        featurization_args["atom_property_list_onehot"] = ["atomic-number", "degree"]
        featurization_args["edge_property_list"] = ["bond-type-onehot"]

        # Config for datamodule
        bad_csv = "tests/data/micro_ZINC_corrupt.csv"
        task_specific_args = {}
        task_kwargs = {"df_path": bad_csv, "split_val": 0.0, "split_test": 0.0}
        task_specific_args["task_1"] = {
            "task_level": "graph",
            "label_cols": "SA",
            "smiles_col": "SMILES1",
            **task_kwargs,
        }
        task_specific_args["task_2"] = {
            "task_level": "graph",
            "label_cols": "logp",
            "smiles_col": "SMILES2",
            **task_kwargs,
        }
        task_specific_args["task_3"] = {
            "task_level": "graph",
            "label_cols": "score",
            "smiles_col": "SMILES3",
            **task_kwargs,
        }

        # Read the corrupted dataset and get stats
        df = pd.read_csv(bad_csv)
        bad_smiles = (df["SMILES1"] == "XXX") & (df["SMILES2"] == "XXX") & (df["SMILES3"] == "XXX")
        num_bad_smiles = sum(bad_smiles)

        # Test the datamodule
        datamodule = MultitaskFromSmilesDataModule(
            task_specific_args=task_specific_args,
            featurization_args=featurization_args,
            featurization_n_jobs=0,
            featurization_batch_size=1,
        )
        datamodule.prepare_data()
        datamodule.setup(save_smiles_and_ids=True)

        # Check that the number of molecules is correct
        smiles = df["SMILES1"].tolist() + df["SMILES2"].tolist() + df["SMILES3"].tolist()
        num_unique_smiles = len(set(smiles)) - 1  # -1 because of the XXX
        # self.assertEqual(len(datamodule.train_ds), num_unique_smiles - num_bad_smiles)

        # Change the index of the dataframe
        index_smiles = []
        for ii in range(len(df)):
            if df["SMILES1"][ii] != "XXX":
                smiles = df["SMILES1"][ii]
            elif df["SMILES2"][ii] != "XXX":
                smiles = df["SMILES2"][ii]
            elif df["SMILES3"][ii] != "XXX":
                smiles = df["SMILES3"][ii]
            else:
                smiles = "XXX"
            index_smiles.append(smiles)
        df["idx_smiles"] = index_smiles
        df = df.set_index("idx_smiles")

        # Convert the smiles from the train_ds to a list, and check the content
        train_smiles = [d["smiles"] for d in datamodule.train_ds]

        # Check that the sets of smiles are the same
        train_smiles_flat = list(set([item for sublist in train_smiles for item in sublist]))
        train_smiles_flat.sort()
        index_smiles_filt = list(set([smiles for smiles in index_smiles if smiles != "XXX"]))
        index_smiles_filt.sort()
        self.assertListEqual(train_smiles_flat, index_smiles_filt)

        # Check that the smiles are correct for each datapoint in the dataset
        for smiles in train_smiles:
            self.assertEqual(len(set(smiles)), 1)  # Check that all smiles are the same
            this_smiles = smiles[0]
            true_smiles = df.loc[this_smiles][["SMILES1", "SMILES2", "SMILES3"]]
            num_true_smiles = sum(true_smiles != "XXX")
            self.assertEqual(len(smiles), num_true_smiles)  # Check that the number of smiles is correct
            self.assertEqual(
                this_smiles, true_smiles[true_smiles != "XXX"].values[0]
            )  # Check that the smiles are correct

        # Convert the labels from the train_ds to a dataframe
        train_labels = [{task: val[0] for task, val in d["labels"].items()} for d in datamodule.train_ds]
        train_labels_df = pd.DataFrame(train_labels)
        train_labels_df = train_labels_df.rename(
            columns={"graph_task_1": "graph_SA", "graph_task_2": "graph_logp", "graph_task_3": "graph_score"}
        )
        train_labels_df["smiles"] = [s[0] for s in datamodule.train_ds.smiles]
        train_labels_df = train_labels_df.set_index("smiles")
        train_labels_df = train_labels_df.sort_index()

        # Check that the labels are correct
        df2 = df.reset_index()[~bad_smiles].set_index("idx_smiles").sort_index()
        labels = train_labels_df[["graph_SA", "graph_logp", "graph_score"]].values
        nans = np.isnan(labels)
        true_nans = df2[["SMILES1", "SMILES2", "SMILES3"]].values == "XXX"
        true_labels = df2[["SA", "logp", "score"]].values
        true_labels[true_nans] = np.nan
        np.testing.assert_array_equal(nans, true_nans)  # Check that the nans are correct
        np.testing.assert_array_almost_equal(
            labels, true_labels, decimal=5
        )  # Check that the label values are correct

    def test_datamodule_multiple_data_files(self):
        # Test single CSV files
        csv_file = "tests/data/micro_ZINC_shard_1.csv"
        task_kwargs = {"df_path": csv_file, "split_val": 0.0, "split_test": 0.0}
        task_specific_args = {
            "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
        }

        ds = MultitaskFromSmilesDataModule(task_specific_args)
        ds.prepare_data()
        ds.setup()

        self.assertEqual(len(ds.train_ds), 10)

        # Test multi CSV files
        csv_file = "tests/data/micro_ZINC_shard_*.csv"
        task_kwargs = {"df_path": csv_file, "split_val": 0.0, "split_test": 0.0}
        task_specific_args = {
            "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
        }

        ds = MultitaskFromSmilesDataModule(task_specific_args)
        ds.prepare_data()
        ds.setup()

        self.assertEqual(len(ds.train_ds), 20)

        # Test single Parquet files
        parquet_file = "tests/data/micro_ZINC_shard_1.parquet"
        task_kwargs = {"df_path": parquet_file, "split_val": 0.0, "split_test": 0.0}
        task_specific_args = {
            "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
        }

        ds = MultitaskFromSmilesDataModule(task_specific_args)
        ds.prepare_data()
        ds.setup()

        self.assertEqual(len(ds.train_ds), 10)

        # Test multi Parquet files
        parquet_file = "tests/data/micro_ZINC_shard_*.parquet"
        task_kwargs = {"df_path": parquet_file, "split_val": 0.0, "split_test": 0.0}
        task_specific_args = {
            "task": {"task_level": "graph", "label_cols": ["score"], "smiles_col": "SMILES", **task_kwargs}
        }

        ds = MultitaskFromSmilesDataModule(task_specific_args)
        ds.prepare_data()
        ds.setup()

        self.assertEqual(len(ds.train_ds), 20)


if __name__ == "__main__":
    try:
        ut.main()
    finally:
        # Delete the cache. ut.main() raises SystemExit, so the cleanup
        # must run in a finally block to be reached.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)