Skip to content

Commit

Permalink
update dataset links
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Aug 18, 2021
1 parent 665c5da commit fb0f412
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 49 deletions.
39 changes: 21 additions & 18 deletions torch_geometric/datasets/gnn_benchmark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,15 @@ class GNNBenchmarkDataset(InMemoryDataset):

names = ['PATTERN', 'CLUSTER', 'MNIST', 'CIFAR10', 'TSP', 'CSL']

url = 'https://pytorch-geometric.com/datasets/benchmarking-gnns'
csl_url = 'https://www.dropbox.com/s/rnbkp5ubgk82ocu/CSL.zip?dl=1'
root_url = 'https://pytorch-geometric.com/datasets/benchmarking-gnns'
urls = {
'PATTERN': f'{root_url}/PATTERN_v2.zip',
'CLUSTER': f'{root_url}/CLUSTER_v2.zip',
'MNIST': f'{root_url}/MNIST_v2.zip',
'CIFAR10': f'{root_url}/CIFAR10_v2.zip',
'TSP': f'{root_url}/TSP_v2.zip',
'CSL': 'https://www.dropbox.com/s/rnbkp5ubgk82ocu/CSL.zip?dl=1',
}

def __init__(self, root: str, name: str, split: str = "train",
transform: Optional[Callable] = None,
Expand All @@ -59,9 +66,9 @@ def __init__(self, root: str, name: str, split: str = "train",
if self.name == 'CSL' and split != 'train':
split = 'train'
logging.warning(
('Dataset `CSL` does not provide a standardized splitting. '
'Instead, it is recommended to perform 5-fold cross '
'validation with stratified sampling.'))
("Dataset 'CSL' does not provide a standardized splitting. "
"Instead, it is recommended to perform 5-fold cross "
"validation with stratifed sampling"))

super().__init__(root, transform, pre_transform, pre_filter)

Expand All @@ -87,14 +94,14 @@ def processed_dir(self) -> str:

@property
def raw_file_names(self) -> List[str]:
    """Names of the raw files that must exist before :meth:`process` runs.

    'CSL' ships as a pickled list of adjacency matrices plus a label
    tensor; every other dataset is a single ``<name>.pt`` file whose name
    is derived from its download URL (zip basename without ``.zip``).
    """
    if self.name == 'CSL':
        return [
            'graphs_Kary_Deterministic_Graphs.pkl',
            'y_Kary_Deterministic_Graphs.pt',
        ]
    else:
        # e.g. '.../PATTERN_v2.zip' -> 'PATTERN_v2' -> 'PATTERN_v2.pt'
        name = self.urls[self.name].split('/')[-1][:-4]
        return [f'{name}.pt']

@property
def processed_file_names(self) -> List[str]:
Expand All @@ -104,9 +111,7 @@ def processed_file_names(self) -> List[str]:
return ['train_data.pt', 'val_data.pt', 'test_data.pt']

def download(self):
    """Download the zipped dataset for ``self.name`` into ``self.raw_dir``.

    The archive is extracted in place and the temporary zip file is
    removed afterwards.
    """
    path = download_url(self.urls[self.name], self.raw_dir)
    extract_zip(path, self.raw_dir)
    os.unlink(path)

Expand All @@ -115,9 +120,9 @@ def process(self):
data_list = self.process_CSL()
torch.save(self.collate(data_list), self.processed_paths[0])
else:
for i in range(3):
self.data, self.slices = torch.load(self.raw_paths[i])
data_list = [self.get(i) for i in range(len(self))]
inputs = torch.load(self.raw_paths[0])
for i in range(len(inputs)):
data_list = [Data(**data_dict) for data_dict in inputs[i]]

if self.pre_filter is not None:
data_list = [d for d in data_list if self.pre_filter(d)]
Expand All @@ -128,12 +133,10 @@ def process(self):
torch.save(self.collate(data_list), self.processed_paths[i])

def process_CSL(self) -> List[Data]:
path = osp.join(self.raw_dir, 'graphs_Kary_Deterministic_Graphs.pkl')
with open(path, 'rb') as f:
with open(self.raw_paths[0], 'rb') as f:
adjs = pickle.load(f)

path = osp.join(self.raw_dir, 'y_Kary_Deterministic_Graphs.pt')
ys = torch.load(path).tolist()
ys = torch.load(self.raw_paths[1]).tolist()

data_list = []
for adj, y in zip(adjs, ys):
Expand Down
36 changes: 10 additions & 26 deletions torch_geometric/datasets/mnist_superpixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
extract_tar)
extract_zip)


class MNISTSuperpixels(InMemoryDataset):
Expand Down Expand Up @@ -30,8 +30,7 @@ class MNISTSuperpixels(InMemoryDataset):
final dataset. (default: :obj:`None`)
"""

url = ('https://graphics.cs.tu-dortmund.de/fileadmin/ls7-www/misc/cvpr/'
'mnist_superpixels.tar.gz')
url = 'https://pytorch-geometric.com/datasets/MNISTSuperpixels.zip'

def __init__(self, root, train=True, transform=None, pre_transform=None,
pre_filter=None):
Expand All @@ -42,41 +41,26 @@ def __init__(self, root, train=True, transform=None, pre_transform=None,

@property
def raw_file_names(self):
    # A single consolidated file holding both the train and test split.
    return ['MNISTSuperpixels.pt']

@property
def processed_file_names(self):
    # One processed file per split, written by process() in this order.
    return ['train_data.pt', 'test_data.pt']

def download(self):
    """Fetch ``MNISTSuperpixels.zip``, extract it and delete the archive."""
    path = download_url(self.url, self.raw_dir)
    extract_zip(path, self.raw_dir)
    os.unlink(path)

def process(self):
    """Convert the raw file into one collated processed file per split.

    The raw ``MNISTSuperpixels.pt`` file holds a sequence of splits, each
    a list of plain dictionaries.  Every dictionary is re-wrapped as a
    :class:`Data` object, optionally filtered and pre-transformed, then
    collated and saved to the matching processed path.
    """
    inputs = torch.load(self.raw_paths[0])
    for i in range(len(inputs)):
        data_list = [Data(**data_dict) for data_dict in inputs[i]]

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        torch.save(self.collate(data_list), self.processed_paths[i])
10 changes: 5 additions & 5 deletions torch_geometric/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class QM9(InMemoryDataset):
raw_url = ('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/'
'molnet_publish/qm9.zip')
raw_url2 = 'https://ndownloader.figshare.com/files/3195404'
processed_url = 'https://pytorch-geometric.com/datasets/qm9_v2.zip'
processed_url = 'https://pytorch-geometric.com/datasets/qm9_v3.zip'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
@property
def raw_file_names(self) -> List[str]:
    """Raw files: the SDF/CSV sources when ``rdkit`` is importable,
    otherwise the pre-processed ``qm9_v3.pt`` fallback archive."""
    try:
        import rdkit  # noqa
        return ['gdb9.sdf', 'gdb9.sdf.csv', 'uncharacterized.txt']
    except ImportError:
        return ['qm9_v3.pt']

@property
def processed_file_names(self) -> str:
    # Version bumped to v3 alongside the regenerated raw archive, so
    # stale v2 caches are re-processed.
    return 'data_v3.pt'

def download(self):
try:
Expand Down Expand Up @@ -178,8 +178,8 @@ def process(self):
"install 'rdkit' to alternatively process the raw data."),
file=sys.stderr)

self.data, self.slices = torch.load(self.raw_paths[0])
data_list = [self.get(i) for i in range(len(self))]
data_list = torch.load(self.raw_paths[0])
data_list = [Data(**data_dict) for data_dict in data_list]

if self.pre_filter is not None:
data_list = [d for d in data_list if self.pre_filter(d)]
Expand Down

0 comments on commit fb0f412

Please sign in to comment.