Skip to content

Commit

Permalink
[Sampling] Cluster-GCN and ShaDow-GNN DataLoader (dmlc#3487)
Browse files Browse the repository at this point in the history
* first commit

* next commit

* third commit

* add ShaDow-GNN sampler and unit tests

* fixes

* lint

* cr*p

* lint

* fix lint

* fixes and more unit tests

* more tests

* fix docs

* lint

* fix

* fix

* fix

* fixes

* fix doc
  • Loading branch information
BarclayII authored Nov 16, 2021
1 parent bba88cd commit b8ce0f4
Show file tree
Hide file tree
Showing 17 changed files with 719 additions and 267 deletions.
27 changes: 25 additions & 2 deletions docs/source/api/python/dgl.dataloading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ and an ``EdgeDataLoader`` for edge/link prediction task.
.. _api-dataloading-neighbor-sampling:

Neighbor Sampler
-----------------------------
----------------
.. currentmodule:: dgl.dataloading.neighbor

Neighbor samplers are classes that control the behavior of ``DataLoader`` s
Expand All @@ -30,7 +30,7 @@ different neighbor sampling strategies by overriding the ``sample_frontier`` or
the ``sample_blocks`` methods.

.. autoclass:: BlockSampler
:members: sample_frontier, sample_blocks
:members: sample_frontier, sample_blocks, sample

.. autoclass:: MultiLayerNeighborSampler
:members: sample_frontier
Expand All @@ -39,6 +39,29 @@ the ``sample_blocks`` methods.
.. autoclass:: MultiLayerFullNeighborSampler
:show-inheritance:

Subgraph Iterators
------------------
Subgraph iterators iterate over the original graph in subgraphs. One should use subgraph
iterators with ``GraphDataLoader`` like follows:

.. code:: python
sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(
g, num_partitions=100, cache_directory='.', refresh=True)
dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=0)
for subgraph_batch in dataloader:
train_on(subgraph_batch)
.. autoclass:: dgl.dataloading.dataloader.SubgraphIterator

.. autoclass:: dgl.dataloading.cluster_gcn.ClusterGCNSubgraphIterator

ShaDow-GNN Subgraph Sampler
---------------------------
.. currentmodule:: dgl.dataloading.shadow

.. autoclass:: ShaDowKHopSampler

.. _api-dataloading-collators:

Collators
Expand Down
18 changes: 15 additions & 3 deletions examples/pytorch/cluster_gcn/cluster_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.data import register_data_args

from modules import GraphSAGE
Expand Down Expand Up @@ -72,8 +73,17 @@ def main(args):
# metis only support int64 graph
g = g.long()

cluster_iterator = ClusterIter(
args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)
if args.use_pp:
g.update_all(fn.copy_u('feat', 'm'), fn.sum('m', 'feat_agg'))
g.ndata['feat'] = torch.cat([g.ndata['feat'], g.ndata['feat_agg']], 1)
del g.ndata['feat_agg']

cluster_iterator = dgl.dataloading.GraphDataLoader(
dgl.dataloading.ClusterGCNSubgraphIterator(
dgl.node_subgraph(g, train_nid), args.psize, './cache'),
batch_size=args.batch_size, num_workers=4)
#cluster_iterator = ClusterIter(
# args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)

# set device for dataset tensors
if args.gpu < 0:
Expand Down Expand Up @@ -132,9 +142,11 @@ def main(args):
cluster = cluster.to(torch.cuda.current_device())
model.train()
# forward
pred = model(cluster)
batch_labels = cluster.ndata['label']
batch_train_mask = cluster.ndata['train_mask']
if batch_train_mask.sum().item() == 0:
continue
pred = model(cluster)
loss = loss_f(pred[batch_train_mask],
batch_labels[batch_train_mask])

Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/cluster_gcn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def reset_parameters(self):

def forward(self, g, h):
g = g.local_var()
if not self.use_pp or not self.training:
if not self.use_pp:
norm = self.get_norm(g)
g.ndata['h'] = h
g.update_all(fn.copy_src(src='h', out='m'),
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/lda/lda_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def perplexity(self, G, doc_data=None):

def doc_subgraph(G, doc_ids):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
block, *_ = sampler.sample_blocks(G.reverse(), {'doc': torch.as_tensor(doc_ids)})
_, _, (block,) = sampler.sample(G.reverse(), {'doc': torch.as_tensor(doc_ids)})
B = dgl.DGLHeteroGraph(
block._graph, ['_', 'word', 'doc', '_'], block.etypes
).reverse()
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/tgn/dataloading.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import dgl

from dgl.dataloading.dataloader import EdgeCollator, assign_block_eids
from dgl.dataloading.dataloader import EdgeCollator
from dgl.dataloading import BlockSampler
from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_blocks_storage
from dgl.base import DGLError
Expand Down Expand Up @@ -91,7 +91,7 @@ def sample_blocks(self,
#block = transform.to_block(frontier,seed_nodes)
block = frontier
if self.return_eids:
assign_block_eids(block, frontier)
self.assign_block_eids(block, frontier)
blocks.append(block)
return blocks

Expand Down
2 changes: 2 additions & 0 deletions python/dgl/dataloading/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"""
from .neighbor import *
from .dataloader import *
from .cluster_gcn import *
from .shadow import *

from . import negative_sampler
from .async_transferer import AsyncTransferer
Expand Down
83 changes: 83 additions & 0 deletions python/dgl/dataloading/cluster_gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Cluster-GCN subgraph iterators."""
import os
import pickle
import numpy as np

from ..transform import metis_partition_assignment
from .. import backend as F
from .dataloader import SubgraphIterator

class ClusterGCNSubgraphIterator(SubgraphIterator):
"""Subgraph sampler following that of ClusterGCN.
This sampler first partitions the graph with METIS partitioning, then it caches the nodes of
each partition to a file within the given cache directory.
This is used in conjunction with :class:`dgl.dataloading.pytorch.GraphDataLoader`.
Notes
-----
The graph must be homogeneous and on CPU.
Parameters
----------
g : DGLGraph
The original graph.
num_partitions : int
The number of partitions.
cache_directory : str
The path to the cache directory for storing the partition result.
refresh : bool
If True, recompute the partition.
Examples
--------
Assuming that you have a graph ``g``:
>>> sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(
... g, num_partitions=100, cache_directory='.', refresh=True)
>>> dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=0)
>>> for subgraph_batch in dataloader:
... train_on(subgraph_batch)
"""
def __init__(self, g, num_partitions, cache_directory, refresh=False):
if os.name == 'nt':
raise NotImplementedError("METIS partitioning is not supported on Windows yet.")
super().__init__(g)

# First see if the cache is already there. If so, directly read from cache.
if not refresh and self._load_parts(cache_directory):
return

# Otherwise, build the cache.
assignment = F.asnumpy(metis_partition_assignment(g, num_partitions))
self._save_parts(assignment, cache_directory)

def _cache_file_path(self, cache_directory):
return os.path.join(cache_directory, 'cluster_gcn_cache')

def _load_parts(self, cache_directory):
path = self._cache_file_path(cache_directory)
if not os.path.exists(path):
return False

with open(path, 'rb') as file_:
self.part_indptr, self.part_indices = pickle.load(file_)
return True

def _save_parts(self, assignment, cache_directory):
os.makedirs(cache_directory, exist_ok=True)

self.part_indices = np.argsort(assignment)
num_nodes_per_part = np.bincount(assignment)
self.part_indptr = np.insert(np.cumsum(num_nodes_per_part), 0, 0)

with open(self._cache_file_path(cache_directory), 'wb') as file_:
pickle.dump((self.part_indptr, self.part_indices), file_)

def __len__(self):
return self.part_indptr.shape[0] - 1

def __getitem__(self, i):
nodes = self.part_indices[self.part_indptr[i]:self.part_indptr[i+1]]
return self.g.subgraph(nodes)
Loading

0 comments on commit b8ce0f4

Please sign in to comment.