[Data] AsGraphPredDataset (dmlc#4073)
* Update

* CI

* Update

* Update

* Fix

* Fix
mufeili authored Jun 2, 2022
1 parent 9922f41 commit d9c2552
Showing 16 changed files with 483 additions and 125 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.data.rst
@@ -99,6 +99,7 @@ Dataset adapters

AsNodePredDataset
AsLinkPredDataset
AsGraphPredDataset

Utilities
-----------------
2 changes: 2 additions & 0 deletions docs/source/guide/data-loadcsv.rst
@@ -389,6 +389,8 @@ After loading, the dataset has multiple homographs with features and labels:
>>> print(data1)
{'feat': tensor([0.5348, 0.2864, 0.1155], dtype=torch.float64), 'label': tensor(0)}
If there is a single feature column in ``graphs.csv``, ``data0`` will directly be a tensor for the feature.
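The dict-vs-tensor behavior described above can be sketched in plain Python; note that ``parse`` and the sample rows below are hypothetical stand-ins for illustration, not DGL's actual CSV parsing code:

```python
# Hypothetical sketch: with several feature columns the per-graph data is a
# dict; with a single column, the feature is returned directly.
def parse(row):
    # split comma-separated feature strings into float lists,
    # plain values into floats
    data = {k: [float(x) for x in v.split(',')] if ',' in v else float(v)
            for k, v in row.items()}
    # a single column yields the feature itself rather than a dict
    return data if len(data) > 1 else next(iter(data.values()))

print(parse({'feat': '0.5348,0.2864,0.1155', 'label': '0'}))
print(parse({'feat': '0.5348,0.2864,0.1155'}))  # -> [0.5348, 0.2864, 0.1155]
```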


Custom Data Parser
~~~~~~~~~~~~~~~~~~
72 changes: 36 additions & 36 deletions docs/source/guide/data-process.rst
@@ -29,41 +29,41 @@ the data.

Take :class:`~dgl.data.QM7bDataset` as an example:

.. code::

   from dgl.data import DGLDataset

   class QM7bDataset(DGLDataset):
       _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
              'datasets/qm7b.mat'
       _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'

       def __init__(self, raw_dir=None, force_reload=False, verbose=False):
           super(QM7bDataset, self).__init__(name='qm7b',
                                             url=self._url,
                                             raw_dir=raw_dir,
                                             force_reload=force_reload,
                                             verbose=verbose)

       def process(self):
           mat_path = self.raw_path + '.mat'
           # process data to a list of graphs and a list of labels
           self.graphs, self.label = self._load_graph(mat_path)

       def __getitem__(self, idx):
           """ Get graph and label by index

           Parameters
           ----------
           idx : int
               Item index

           Returns
           -------
           (dgl.DGLGraph, Tensor)
           """
           return self.graphs[idx], self.label[idx]

       def __len__(self):
           """Number of graphs in the dataset"""
           return len(self.graphs)

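The ``__getitem__``/``__len__`` contract implemented above is plain Python, so it can be illustrated without DGL at all; ``ToyDataset`` and its dict "graphs" below are hypothetical stand-ins, not DGL classes:

```python
# Minimal, DGL-free sketch of the dataset protocol QM7bDataset follows.
class ToyDataset:
    def __init__(self):
        # pretend each "graph" is a dict and each label a float
        self.graphs = [{'num_nodes': n} for n in (3, 5, 7)]
        self.label = [0.1, 0.2, 0.3]

    def __getitem__(self, idx):
        # return a (graph, label) pair, just like QM7bDataset
        return self.graphs[idx], self.label[idx]

    def __len__(self):
        return len(self.graphs)

dataset = ToyDataset()
g, y = dataset[1]
print(len(dataset), g, y)  # -> 3 {'num_nodes': 5} 0.2
```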
@@ -78,33 +78,33 @@ for details of ``self._load_graph()`` and ``__getitem__``.

One can also add properties to the class to expose useful
information about the dataset. In :class:`~dgl.data.QM7bDataset`, one can add a property
``num_tasks`` to indicate the total number of prediction tasks in this
multi-task dataset:

.. code::

   @property
   def num_tasks(self):
       """Number of labels for each graph, i.e. number of prediction tasks."""
       return 14

After all this coding, one can finally use :class:`~dgl.data.QM7bDataset` as
follows:

.. code::

   import dgl
   import torch
   from dgl.dataloading import GraphDataLoader

   # load data
   dataset = QM7bDataset()
   num_tasks = dataset.num_tasks

   # create dataloaders
   dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)

   # training
   for epoch in range(100):
       for g, labels in dataloader:
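The index handling behind a loop like the one above can be sketched without DGL; this is a hedged stand-in (``index_batches`` is hypothetical), and the real ``GraphDataLoader`` additionally merges each batch of graphs into one batched graph:

```python
import random

def index_batches(n_items, batch_size, shuffle=True, seed=0):
    # Shuffle the dataset indices once per pass, then cut them into
    # consecutive chunks of batch_size (the last chunk may be smaller).
    idx = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(idx)
    return [idx[i:i + batch_size] for i in range(0, len(idx), batch_size)]

print(index_batches(5, 2, shuffle=False))  # -> [[0, 1], [2, 3], [4]]
```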
@@ -115,7 +115,7 @@ A complete guide for training graph classification models can be found
in :ref:`guide-training-graph-classification`.

For more examples of graph classification datasets, please refer to DGL's builtin graph classification
datasets:

* :ref:`gindataset`

@@ -140,18 +140,18 @@ computation and analysis conducted on the graph. DGL provides an API called
:func:`dgl.reorder_graph` for this purpose. Please refer to the ``process()``
part of the example below for more details.

.. code::

   from dgl.data import DGLBuiltinDataset
   from dgl.data.utils import _get_dgl_url

   class CitationGraphDataset(DGLBuiltinDataset):
       _urls = {
           'cora_v2' : 'dataset/cora_v2.zip',
           'citeseer' : 'dataset/citeseer.zip',
           'pubmed' : 'dataset/pubmed.zip',
       }

       def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
           assert name.lower() in ['cora', 'citeseer', 'pubmed']
           if name.lower() == 'cora':
@@ -162,11 +162,11 @@ part in below example for more details.
                                   raw_dir=raw_dir,
                                   force_reload=force_reload,
                                   verbose=verbose)

       def process(self):
           # Skip some processing code
           # === data processing skipped ===

           # build graph
           g = dgl.graph(graph)

           # splitting masks
@@ -178,15 +178,15 @@ part in below example for more details.
           # node features
           g.ndata['feat'] = torch.tensor(_preprocess_features(features),
                                          dtype=F.data_type_dict['float32'])
           self._num_tasks = onehot_labels.shape[1]
           self._labels = labels
           # reorder graph to obtain better locality.
           self._g = dgl.reorder_graph(g)

       def __getitem__(self, idx):
           assert idx == 0, "This dataset has only one graph"
           return self._g

       def __len__(self):
           return 1
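The boolean split masks stored in ``g.ndata`` above can be sketched with plain Python lists standing in for node tensors; the node counts and split sizes below are hypothetical, for illustration only:

```python
# Hedged sketch of train/val/test masks over 10 hypothetical nodes.
num_nodes = 10
train_mask = [i < 6 for i in range(num_nodes)]       # first 6 nodes train
val_mask = [6 <= i < 8 for i in range(num_nodes)]    # next 2 validate
test_mask = [i >= 8 for i in range(num_nodes)]       # last 2 test

# each node falls in exactly one split
assert all(t + v + s == 1
           for t, v, s in zip(train_mask, val_mask, test_mask))
print(sum(train_mask), sum(val_mask), sum(test_mask))  # -> 6 2 2
```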
@@ -204,20 +204,20 @@ and TensorFlow, and ``float tensors`` in MXNet.
The section uses a subclass of ``CitationGraphDataset``, :class:`dgl.data.CiteseerGraphDataset`,
to demonstrate its usage:

.. code::

   # load data
   dataset = CiteseerGraphDataset(raw_dir='')
   graph = dataset[0]

   # get split masks
   train_mask = graph.ndata['train_mask']
   val_mask = graph.ndata['val_mask']
   test_mask = graph.ndata['test_mask']

   # get node features
   feats = graph.ndata['feat']

   # get labels
   labels = graph.ndata['label']
@@ -258,7 +258,7 @@ The section uses builtin dataset
as an example, and still skips the detailed data processing code to
highlight the key part for processing link prediction datasets:

.. code::

   # Example for creating Link Prediction datasets
   class KnowledgeGraphDataset(DGLBuiltinDataset):
@@ -271,11 +271,11 @@ highlight the key part for processing link prediction datasets:
                                   raw_dir=raw_dir,
                                   force_reload=force_reload,
                                   verbose=verbose)

       def process(self):
           # Skip some processing code
           # === data processing skipped ===

           # splitting mask
           g.edata['train_mask'] = train_mask
           g.edata['val_mask'] = val_mask
@@ -285,11 +285,11 @@ highlight the key part for processing link prediction datasets:
           # node type
           g.ndata['ntype'] = ntype
           self._g = g

       def __getitem__(self, idx):
           assert idx == 0, "This dataset has only one graph"
           return self._g

       def __len__(self):
           return 1
@@ -299,14 +299,14 @@ code <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#Knowle
to see the complete code. The following code uses a subclass of ``KnowledgeGraphDataset``,
:class:`dgl.data.FB15k237Dataset`, to demonstrate its usage:

.. code::

   from dgl.data import FB15k237Dataset

   # load data
   dataset = FB15k237Dataset()
   graph = dataset[0]

   # get training mask
   train_mask = graph.edata['train_mask']
   train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
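``torch.nonzero(train_mask, as_tuple=False).squeeze()`` simply collects the indices of the ``True`` entries of the mask; a plain-Python equivalent (a sketch with a made-up mask, no torch required) is:

```python
# Pure-Python stand-in for torch.nonzero(mask).squeeze() on a boolean mask.
train_mask = [True, False, True, True, False]
train_idx = [i for i, m in enumerate(train_mask) if m]
print(train_idx)  # -> [0, 2, 3]
```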
@@ -319,7 +319,7 @@ A complete guide for training link prediction models can be found in
:ref:`guide-training-link-prediction`.

For more examples of link prediction datasets, please refer to DGL's
builtin datasets:

* :ref:`kgdata`

