[Sampler] [Example] SAINTSampler and Simplify GraphSAINT Example (dmlc#3879)

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Fix

Co-authored-by: Quan (Andy) Gan <[email protected]>
mufeili and BarclayII authored Mar 30, 2022
1 parent 0dd500e commit 9fee20b
Showing 5 changed files with 177 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/source/api/python/dgl.dataloading.rst
@@ -46,6 +46,14 @@ Samplers
MultiLayerFullNeighborSampler
ClusterGCNSampler
ShaDowKHopSampler
SAINTSampler

Sampler Transformations
-----------------------

.. autosummary::
:toctree: ../../generated/

as_edge_prediction_sampler
BlockSampler

6 changes: 4 additions & 2 deletions examples/pytorch/graphsaint/README.md
@@ -8,6 +8,8 @@ Author's code: https://github.com/GraphSAINT/GraphSAINT

Contributors: Jiahang Li ([@ljh1064126026](https://github.com/ljh1064126026)) and Tang Liu ([@lt610](https://github.com/lt610))

For a built-in GraphSAINT subgraph sampler with online sampling, use `dgl.dataloading.SAINTSampler`; a minimal usage sketch follows.

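A minimal sketch of the online-sampling workflow with the built-in sampler (the graph, budget, and iteration count here are illustrative; they mirror the `SAINTSampler` docstring and unit test rather than this example's `config.py`):

```python
import dgl
import torch
from dgl.dataloading import SAINTSampler, DataLoader

# Any homogeneous graph works; CoraFull is also what the unit test uses.
g = dgl.data.CoraFullDataset()[0]

num_iters = 1000                                  # one sampled subgraph per iteration
sampler = SAINTSampler(mode='node', budget=6000)  # 'edge' and 'walk' modes are also available
dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)

for subg in dataloader:
    # subg is a node-induced DGLGraph carrying the sampled nodes' features;
    # plug in your own training step here.
    pass
```
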
## Dependencies

- Python 3.7.10
@@ -69,7 +71,7 @@ python train_sampling.py --task $task $online

* Paper: results from the paper
* Running: results from experiments with the authors' code
* DGL: results from experiments with the DGL example. The experiment config comes from `config.py`; modify its parameters to see how different setups perform.

> Note that we support both offline sampling and online sampling in the training phase. Offline sampling means all subgraphs used during training come from pre-sampled subgraphs. Online sampling means we discard all pre-sampled subgraphs and sample new subgraphs during training.
@@ -132,7 +134,7 @@ python train_sampling.py --task $task $online

- We ran each experiment 10 times to measure the average and standard deviation of the sampling and normalization time. Here we only time sampling, without training the model to the end. Moreover, for efficient testing, the hardware and config used here differ from those in the experiments above, so the sampling time may differ slightly from the numbers above. The environment is kept consistent across all experiments below.

> The only config values that differ from the section above are `num_workers_sampler`, `batch_size_sampler` and `num_workers`, which only affect sampling speed. All other parameters are kept consistent across the two sections, so the model's performance is not affected.
> The value is (average, std).
1 change: 1 addition & 0 deletions python/dgl/dataloading/__init__.py
@@ -2,6 +2,7 @@
from .. import backend as F
from .neighbor_sampler import *
from .cluster_gcn import *
from .graphsaint import *
from .shadow import *
from .base import *
from . import negative_sampler
146 changes: 146 additions & 0 deletions python/dgl/dataloading/graphsaint.py
@@ -0,0 +1,146 @@
"""GraphSAINT samplers."""
from ..base import DGLError
from ..random import choice
from ..sampling import random_walk, pack_traces
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler

try:
import torch
except ImportError:
pass

class SAINTSampler(Sampler):
"""Random node/edge/walk sampler from
`GraphSAINT: Graph Sampling Based Inductive Learning Method
<https://arxiv.org/abs/1907.04931>`__

For each call, the sampler samples a node subset and then returns a node-induced subgraph.
There are three options for sampling node subsets:

- For the :attr:`'node'` sampler, the probability of sampling a node is proportional
to its out-degree.
- The :attr:`'edge'` sampler first samples an edge subset and then uses the
end nodes of the sampled edges.
- The :attr:`'walk'` sampler uses the nodes visited by random walks. It uniformly selects
a number of root nodes and then performs a fixed-length random walk from each root node.

Parameters
----------
mode : str
The sampler to use, which can be :attr:`'node'`, :attr:`'edge'`, or :attr:`'walk'`.
budget : int or tuple[int]
Sampler configuration.
- For :attr:`'node'` sampler, budget specifies the number of nodes
in each sampled subgraph.
- For :attr:`'edge'` sampler, budget specifies the number of edges
to sample for inducing a subgraph.
- For the :attr:`'walk'` sampler, budget is a tuple. budget[0] specifies
the number of root nodes from which to start random walks, and budget[1] specifies
the length of each random walk.
cache : bool, optional
If True, the sampler caches the probability arrays used for sampling. It must be
set to False if you want to use the sampler across different graphs.
prefetch_ndata : list[str], optional
The node data to prefetch for the subgraph.
See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.
prefetch_edata : list[str], optional
The edge data to prefetch for the subgraph.
See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching.
output_device : device, optional
The device of the output subgraphs.

Examples
--------
>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> sampler = SAINTSampler(mode='node', budget=6000,
...                        prefetch_ndata=['feat', 'label'])
>>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
... train_on(subg)
"""
def __init__(self, mode, budget, cache=True, prefetch_ndata=None,
prefetch_edata=None, output_device='cpu'):
super().__init__()
self.budget = budget
if mode == 'node':
self.sampler = self.node_sampler
elif mode == 'edge':
self.sampler = self.edge_sampler
elif mode == 'walk':
self.sampler = self.walk_sampler
else:
raise DGLError(f"Expect mode to be 'node', 'edge' or 'walk', got {mode}.")

self.cache = cache
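# Cached sampling probabilities; computed lazily on the first call when cache is True.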
self.prob = None
self.prefetch_ndata = prefetch_ndata or []
self.prefetch_edata = prefetch_edata or []
self.output_device = output_device

def node_sampler(self, g):
"""Node ID sampler for random node sampler"""
# Alternatively, this can be realized by uniformly sampling an edge subset,
# and then taking the src nodes of the sampled edges. However, the number of edges
# is typically much larger than the number of nodes.
if self.cache and self.prob is not None:
prob = self.prob
else:
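# Sampling probability is proportional to the node's out-degree (clamped at a minimum of 1).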
prob = g.out_degrees().float().clamp(min=1)
if self.cache:
self.prob = prob
return torch.multinomial(prob, num_samples=self.budget,
replacement=True).unique().type(g.idtype)

def edge_sampler(self, g):
"""Node ID sampler for random edge sampler"""
src, dst = g.edges()
if self.cache and self.prob is not None:
prob = self.prob
else:
in_deg = g.in_degrees().float().clamp(min=1)
out_deg = g.out_degrees().float().clamp(min=1)
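# Edge (u, v) is picked with probability proportional to 1/out_deg(u) + 1/in_deg(v),
# following the GraphSAINT edge sampler.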
# We can reduce the sample space by half if graphs are always symmetric.
prob = 1. / in_deg[dst.long()] + 1. / out_deg[src.long()]
prob /= prob.sum()
if self.cache:
self.prob = prob
sampled_edges = torch.unique(choice(len(prob), size=self.budget, prob=prob))
sampled_nodes = torch.cat([src[sampled_edges], dst[sampled_edges]])
return sampled_nodes.unique().type(g.idtype)

def walk_sampler(self, g):
"""Node ID sampler for random walk sampler"""
num_roots, walk_length = self.budget
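# Uniformly pick root nodes, then take every node visited by a fixed-length random walk from each root.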
sampled_roots = torch.randint(0, g.num_nodes(), (num_roots,))
traces, types = random_walk(g, nodes=sampled_roots, length=walk_length)
sampled_nodes, _, _, _ = pack_traces(traces, types)
return sampled_nodes.unique().type(g.idtype)

def sample(self, g, indices):
"""Sampling function
Parameters
----------
g : DGLGraph
The graph to sample from.
indices : Tensor
Placeholder not used.
Returns
-------
DGLGraph
The sampled subgraph.
"""
node_ids = self.sampler(g)
sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device)
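# Mark the requested node/edge data for lazy prefetching on the sampled subgraph.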
set_node_lazy_features(sg, self.prefetch_ndata)
set_edge_lazy_features(sg, self.prefetch_edata)
return sg
18 changes: 18 additions & 0 deletions tests/pytorch/test_dataloader.py
@@ -51,6 +51,24 @@ def test_shadow(num_workers):
if i == 5:
break

@pytest.mark.parametrize('num_workers', [0, 4])
@pytest.mark.parametrize('mode', ['node', 'edge', 'walk'])
def test_saint(num_workers, mode):
g = dgl.data.CoraFullDataset()[0]

if mode == 'node':
budget = 100
elif mode == 'edge':
budget = 200
elif mode == 'walk':
budget = (3, 2)

sampler = dgl.dataloading.SAINTSampler(mode, budget)
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(100), sampler, num_workers=num_workers)
assert len(dataloader) == 100
for sg in dataloader:
pass

@pytest.mark.parametrize('num_workers', [0, 4])
def test_neighbor_nonuniform(num_workers):
