[Dataset] Add TUDataset (dmlc#473)
* add graph classification dataset

* add node label

* add TUDataset

* Modify to be consistent with Qi Huang's implementation

* add docs

* Add docs

* Fix change of environment variable

* Update tu.py

* Update tu.py

* Fix error when adding nodes with np.int64
VoVAllen authored and jermainewang committed Apr 10, 2019
1 parent 039a711 commit 00fc680
Showing 4 changed files with 134 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/source/api/python/data.rst
@@ -33,6 +33,14 @@ Mini graph classification dataset
.. autoclass:: MiniGCDataset
   :members: __getitem__, __len__, num_classes

Graph kernel dataset
````````````````````

For more information about these datasets, see `Benchmark Data Sets for Graph Kernels <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__.

.. autoclass:: TUDataset
   :members: __getitem__, __len__

Protein-Protein Interaction dataset
```````````````````````````````````

5 changes: 3 additions & 2 deletions docs/source/conf.py
@@ -215,7 +215,8 @@
 }
 
 # Compatibility across backends when building the tutorials
-if os.environ['DGLBACKEND'] == 'mxnet':
+dglbackend = os.environ.get("DGLBACKEND", "")
+if dglbackend == 'mxnet':
     sphinx_gallery_conf['filename_pattern'] = "/*(?<=mx)\.py"
-if os.environ['DGLBACKEND'] == 'pytorch':
+if dglbackend == 'pytorch':
     sphinx_gallery_conf['filename_pattern'] = "/*(?<!mx)\.py"
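
For context on the fix: os.environ['DGLBACKEND'] raises a KeyError whenever the variable is unset, which broke builds that did not export a backend, while os.environ.get returns a default instead. A minimal stand-alone sketch of the pattern (illustrative only, not part of the commit):

import os

# Indexing os.environ raises KeyError when DGLBACKEND is unset:
#     backend = os.environ['DGLBACKEND']
# .get() returns the fallback '' instead, so the comparison below is always safe.
backend = os.environ.get('DGLBACKEND', '')
if backend == 'mxnet':
    print('selecting the MXNet tutorial files')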
1 change: 1 addition & 0 deletions python/dgl/data/__init__.py
@@ -9,6 +9,7 @@
from .sbm import SBMMixture
from .reddit import RedditDataset
from .ppi import PPIDataset
from .tu import TUDataset

def register_data_args(parser):
    parser.add_argument("--dataset", type=str, required=False,
122 changes: 122 additions & 0 deletions python/dgl/data/tu.py
@@ -0,0 +1,122 @@
from __future__ import absolute_import
import numpy as np
import dgl
import os

from dgl.data.utils import download, extract_archive, get_download_dir


class TUDataset(object):
    """
    TUDataset contains many graph kernel datasets used for graph classification.

    Node features provided by the dataset are used by default. If no features
    are provided, one-hot node labels are used instead; if node labels are
    missing as well, constant node features are used.

    :param name: Dataset name, such as ``ENZYMES``, ``DD`` or ``COLLAB``.
    :param use_pandas: Default: False.
        NumPy's file-reading functions can be slow on large files;
        pandas can be faster.
    :param hidden_size: Default: 10. Some datasets contain no node features;
        in that case, constant node features of size ``hidden_size`` are used.
    """

    _url = r"https://ls11-www.cs.uni-dortmund.de/people/morris/graphkerneldatasets/{}.zip"

    def __init__(self, name, use_pandas=False, hidden_size=10):

        self.name = name
        self.hidden_size = hidden_size
        self.extract_dir = self._download()

        if use_pandas:
            import pandas as pd
            DS_edge_list = self._idx_from_zero(
                pd.read_csv(self._file_path("A"), delimiter=",", dtype=int).values)
        else:
            DS_edge_list = self._idx_from_zero(
                np.genfromtxt(self._file_path("A"), delimiter=",", dtype=int))

        # "graph_indicator" maps each node to the id of the graph it belongs to,
        # and "graph_labels" stores one class label per graph.
        DS_indicator = self._idx_from_zero(
            np.genfromtxt(self._file_path("graph_indicator"), dtype=int))
        DS_graph_labels = self._idx_from_zero(
            np.genfromtxt(self._file_path("graph_labels"), dtype=int))

        # Build one graph over all nodes, adding every edge in both directions.
        g = dgl.DGLGraph()
        g.add_nodes(int(DS_edge_list.max()) + 1)
        g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
        g.add_edges(DS_edge_list[:, 1], DS_edge_list[:, 0])

        # Collect the node ids belonging to each graph and cut the graphs
        # out of the big graph as subgraphs.
        node_idx_list = []
        for idx in range(np.max(DS_indicator) + 1):
            node_idx = np.where(DS_indicator == idx)
            node_idx_list.append(node_idx[0])
        self.graph_lists = g.subgraphs(node_idx_list)
        self.graph_labels = DS_graph_labels

        # Feature priority: node attributes override one-hot node labels,
        # and constant features are the final fallback.
        try:
            DS_node_labels = self._idx_from_zero(
                np.loadtxt(self._file_path("node_labels"), dtype=int))
            g.ndata['node_label'] = DS_node_labels
            one_hot_node_labels = self._to_onehot(DS_node_labels)
            for idxs, sub_g in zip(node_idx_list, self.graph_lists):
                sub_g.ndata['feat'] = one_hot_node_labels[idxs, :]
        except IOError:
            print("No node label data found.")

        try:
            DS_node_attr = np.loadtxt(self._file_path("node_attributes"), delimiter=",")
            for idxs, sub_g in zip(node_idx_list, self.graph_lists):
                sub_g.ndata['feat'] = DS_node_attr[idxs, :]
        except IOError:
            print("No node attribute data found.")

        if 'feat' not in self.graph_lists[0].ndata:
            for sub_g in self.graph_lists:
                sub_g.ndata['feat'] = np.ones((sub_g.number_of_nodes(), hidden_size))
            print("Using constant ones as features with hidden size {}".format(hidden_size))

    def __getitem__(self, idx):
        """Get the i-th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, int)
            The graph, with node features stored in the ``feat`` field
            (and node labels in ``node_label`` when available), together
            with its graph label.
        """
        g = self.graph_lists[idx]
        return g, self.graph_labels[idx]

    def __len__(self):
        return len(self.graph_lists)

    def _download(self):
        download_dir = get_download_dir()
        zip_file_path = os.path.join(download_dir, "tu_{}.zip".format(self.name))
        download(self._url.format(self.name), path=zip_file_path)
        extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
        extract_archive(zip_file_path, extract_dir)
        return extract_dir

    def _file_path(self, category):
        return os.path.join(self.extract_dir, self.name, "{}_{}.txt".format(self.name, category))

    # TU files index nodes and labels starting from 1; shift to zero-based.
    @staticmethod
    def _idx_from_zero(idx_tensor):
        return idx_tensor - np.min(idx_tensor)

    @staticmethod
    def _to_onehot(label_tensor):
        label_num = label_tensor.shape[0]
        assert np.min(label_tensor) == 0
        one_hot_tensor = np.zeros((label_num, np.max(label_tensor) + 1))
        one_hot_tensor[np.arange(label_num), label_tensor] = 1
        return one_hot_tensor

    def statistics(self):
        # Dimensionality of the node features.
        return self.graph_lists[0].ndata['feat'].shape[1]
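
For reference, a minimal usage sketch (ENZYMES is one valid TU dataset name; exact shapes depend on the chosen dataset):

from dgl.data import TUDataset

# Downloads and extracts the dataset on first use.
dataset = TUDataset(name='ENZYMES')
print(len(dataset))            # number of graphs
g, label = dataset[0]          # a DGLGraph and its integer class label
print(g.number_of_nodes(), g.ndata['feat'].shape, int(label))
print(dataset.statistics())    # node feature dimensionality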
