-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_data.py
44 lines (36 loc) · 1.29 KB
/
test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import dgl.data as data
import unittest
import backend as F
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_minigc():
    """Smoke-test ``data.MiniGCDataset`` and check it yields the requested count.

    The constructor arguments ``(16, 10, 20)`` are passed positionally. Per
    DGL's documented signature they are ``num_graphs``, ``min_num_v`` and
    ``max_num_v``. Iterating the dataset is unpacked into parallel
    sequences of graphs and labels.
    """
    num_graphs = 16
    ds = data.MiniGCDataset(num_graphs, 10, 20)
    # zip(*ds) transposes the (graph, label) pairs into two tuples.
    graphs, labels = zip(*ds)
    # The original test only printed the samples and so could never fail;
    # pin the number of generated samples explicitly.
    assert len(ds) == num_graphs, len(ds)
    assert len(graphs) == num_graphs, len(graphs)
    assert len(labels) == num_graphs, len(labels)
    print(graphs, labels)
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_gin():
    """Load each GIN benchmark dataset and compare its size to the known graph count."""
    expected_sizes = {
        'MUTAG': 188,
        'IMDBBINARY': 1000,
        'IMDBMULTI': 1500,
        'PROTEINS': 1113,
        'PTC': 344,
    }
    for dataset_name in expected_sizes:
        expected = expected_sizes[dataset_name]
        dataset = data.GINDataset(dataset_name, self_loop=False, degree_as_nlabel=False)
        actual = len(dataset)
        # On failure, report the observed size alongside the dataset name.
        assert actual == expected, (actual, dataset_name)
@unittest.skipIf(F._default_context_str == 'gpu', reason="Datasets don't need to be tested on GPU.")
def test_data_hash():
    """Check that ``DGLDataset.hash`` follows the ``hash_key`` tuple.

    Two datasets built with equal keys must share a hash. A dataset whose
    key differs in one nested element must hash differently.
    """
    class HashTestDataset(data.DGLDataset):
        # Minimal dataset named 'hashtest' whose only varying input is hash_key.
        def __init__(self, hash_key=()):
            super().__init__('hashtest', hash_key=hash_key)

        def _load(self):
            # No-op override. Presumably DGLDataset.__init__ invokes _load,
            # so this keeps the test free of any loading work (verify).
            pass

    first = HashTestDataset((True, 0, '1', (1, 2, 3)))
    twin = HashTestDataset((True, 0, '1', (1, 2, 3)))
    other = HashTestDataset((True, 0, '1', (1, 2, 4)))
    assert first.hash == twin.hash
    assert first.hash != other.hash
if __name__ == '__main__':
    # Allow running the dataset tests directly as a script, in file order.
    for run_test in (test_minigc, test_gin, test_data_hash):
        run_test()