[Example] Add HGP-SL example for pytorch backend (dmlc#2515)

* add sagpool example for pytorch backend
* polish sagpool example for pytorch backend
* [Example] SAGPool: use std variance
* [Example] SAGPool: change to std
* add sagpool example to index page
* add graph property prediction tag to sagpool
* [Example] add graph classification example HGP-SL
* [Example] fix sagpool
* fix bug
* [Example] change tab to space in README of hgp-sl
* remove redundant files
* remove redundant network
* [Example]: change link from code to doc in HGP-SL
* [Example] in HGP-SL, change to meaningful name
* [Example] Fix path mistake for 'hardgat'

Co-authored-by: zhangtianqi <[email protected]>

Showing 8 changed files with 1,075 additions and 2 deletions.
# DGL Implementation of the HGP-SL Paper

This DGL example implements the GNN model proposed in the paper [Hierarchical Graph Pooling with Structure Learning](https://arxiv.org/pdf/1911.05954.pdf).
The author's implementation is available [here](https://github.com/cszhangzhen/HGP-SL).

Example implementor
----------------------
This example was implemented by [Tianqi Zhang](https://github.com/lygztq) during his Applied Scientist Intern work at the AWS Shanghai AI Lab.

The graph dataset used in this example
---------------------------------------
The DGL built-in [LegacyTUDataset](https://docs.dgl.ai/api/python/dgl.data.html?highlight=tudataset#dgl.data.LegacyTUDataset) is used. This is a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109', 'Mutagenicity' and 'ENZYMES' in this HGP-SL implementation. All these datasets are randomly split into training, validation and test sets with ratios 0.8, 0.1 and 0.1.

NOTE: Since some of these datasets have no node attributes, we use the node id (as a one-hot vector whose length is the max number of nodes across all graphs) as the node feature; a minimal sketch of this construction follows. Also note that node ids in some datasets are not unique (e.g. a graph may have two nodes with the same id).
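
A possible implementation of this feature construction (a sketch assuming `graphs` is a list of `DGLGraph` objects; the function name and feature key are illustrative, not taken from the example code):

```python
import torch
import dgl

def add_one_hot_node_id_features(graphs):
    # The one-hot length is the max number of nodes across all graphs.
    max_num_nodes = max(g.num_nodes() for g in graphs)
    for g in graphs:
        node_ids = torch.arange(g.num_nodes())
        # Node i gets the i-th standard basis vector as its feature.
        g.ndata["feat"] = torch.nn.functional.one_hot(
            node_ids, num_classes=max_num_nodes
        ).float()
    return graphs
```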

| Dataset      | NumGraphs | AvgNodesPerGraph | AvgEdgesPerGraph | NumFeats | NumClasses |
| ------------ | --------- | ---------------- | ---------------- | -------- | ---------- |
| DD           | 1178      | 284.32           | 715.66           | 89       | 2          |
| PROTEINS     | 1113      | 39.06            | 72.82            | 1        | 2          |
| NCI1         | 4110      | 29.87            | 32.30            | 37       | 2          |
| NCI109       | 4127      | 29.68            | 32.13            | 38       | 2          |
| Mutagenicity | 4337      | 30.32            | 30.77            | 14       | 2          |
| ENZYMES      | 600       | 32.63            | 62.14            | 18       | 6          |

How to run example files
--------------------------------
In the HGP-SL-DGL folder, run

```bash
python main.py --dataset ${your_dataset_name_here}
```

If you want to use a GPU, run

```bash
python main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here}
```

Performance
-------------------------

**Hyper-parameters**

The hyper-parameters below are taken directly from the [author's implementation](https://github.com/cszhangzhen/HGP-SL).

| Datasets     | lr     | weight_decay | batch_size | pool_ratio | dropout | net_layers |
| ------------ | ------ | ------------ | ---------- | ---------- | ------- | ---------- |
| PROTEINS     | 0.001  | 0.001        | 512        | 0.5        | 0.0     | 3          |
| Mutagenicity | 0.001  | 0.001        | 512        | 0.8        | 0.0     | 3          |
| NCI109       | 0.001  | 0.001        | 512        | 0.8        | 0.0     | 3          |
| NCI1         | 0.001  | 0.001        | 512        | 0.8        | 0.0     | 3          |
| DD           | 0.0001 | 0.001        | 64         | 0.3        | 0.5     | 2          |
| ENZYMES      | 0.001  | 0.001        | 128        | 0.8        | 0.0     | 2          |

**Accuracy**

Numbers are mean accuracy in percent, with standard deviation in parentheses.

**NOTE**: We find that there is a gap between the accuracy obtained with the author's code and the one reported in the [paper](https://arxiv.org/pdf/1911.05954.pdf). An issue has been opened in the author's repo (see [here](https://github.com/cszhangzhen/HGP-SL/issues/8)).

|                            | Mutagenicity | NCI109      | NCI1        | DD          |
| -------------------------- | ------------ | ----------- | ----------- | ----------- |
| Reported in Paper          | 82.15(0.58)  | 80.67(1.16) | 78.45(0.77) | 80.96(1.26) |
| Author's Code (full graph) | 78.44(2.10)  | 74.44(2.05) | 77.37(2.09) | OOM         |
| Author's Code (sample)     | 79.68(1.68)  | 73.86(1.72) | 76.29(2.14) | 75.46(3.86) |
| DGL (full graph)           | 79.52(2.21)  | 74.86(1.99) | 74.62(2.22) | OOM         |
| DGL (sample)               | 79.15(1.62)  | 75.39(1.86) | 73.77(2.04) | 76.47(2.14) |

**Speed**

Device: Tesla V100-SXM2 16GB

Times are in seconds.

|                            | DD (batch_size=64), large graphs | Mutagenicity (batch_size=512), small graphs |
| -------------------------- | -------------------------------- | ------------------------------------------- |
| Author's code (sample)     | 9.96                             | 12.91                                       |
| Author's code (full graph) | OOM                              | 13.03                                       |
| DGL (sample)               | 9.50                             | 3.59                                        |
| DGL (full graph)           | OOM                              | 3.56                                        |

""" | ||
An original implementation of sparsemax (Martins & Astudillo, 2016) is available at | ||
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py. | ||
See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016` | ||
for detailed description. | ||
Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges | ||
with the same node as end-node in graphs. | ||
""" | ||
import dgl
import torch
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function


def _neighbor_sort(scores: Tensor, end_n_ids: Tensor, in_degrees: Tensor, cum_in_degrees: Tensor):
    """Sort edge scores for each node"""
    num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())

    # Compute the index for a dense score matrix with size (N x D_{max}).
    # Note that end_n_ids here is the end-node tensor of the dgl graph,
    # which is not grouped by node id (i.e. not in the form 0,0,1,1,1,...,N,N).
    # Thus we first sort the end-node tensor to make it easier to compute
    # indices in the dense edge score matrix. Since we will need the original
    # order for the following gspmm and gsddmm operations, we also keep the
    # reverse mapping (reverse_perm) here.
    end_n_ids, perm = torch.sort(end_n_ids)
    scores = scores[perm]
    _, reverse_perm = torch.sort(perm)

    index = torch.arange(end_n_ids.size(0), dtype=torch.long, device=scores.device)
    index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
    index = index.long()

    # Scatter scores into the dense (N x D_{max}) matrix; empty slots are padded
    # with the smallest representable value so that they sort last.
    dense_scores = scores.new_full((num_nodes * max_in_degree, ), torch.finfo(scores.dtype).min)
    dense_scores[index] = scores
    dense_scores = dense_scores.view(num_nodes, max_in_degree)

    # Sort each node's scores in descending order and keep the permutation that
    # maps the sorted order back to the original edge order.
    sorted_dense_scores, dense_reverse_perm = dense_scores.sort(dim=-1, descending=True)
    _, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
    dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
    dense_reverse_perm = dense_reverse_perm.view(-1)
    cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
    sorted_dense_scores = sorted_dense_scores.view(-1)
    arange_vec = torch.arange(1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device)
    arange_vec = torch.repeat_interleave(arange_vec.view(1, -1), num_nodes, dim=0).view(-1)

    # Drop the padded entries so that only real edges remain.
    valid_mask = (sorted_dense_scores != torch.finfo(scores.dtype).min)
    sorted_scores = sorted_dense_scores[valid_mask]
    cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
    arange_vec = arange_vec[valid_mask]
    dense_reverse_perm = dense_reverse_perm[valid_mask].long()

    return sorted_scores, cumsum_sorted_scores, arange_vec, reverse_perm, dense_reverse_perm


def _threshold_and_support_graph(gidx: HeteroGraphIndex, scores: Tensor, end_n_ids: Tensor):
    """Find the sparsemax threshold for each node and its edges"""
    in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[0]
    cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0)

    # Perform a sort on the edges of each node. For a node with sorted scores
    # z_(1) >= z_(2) >= ..., the support size is the largest k such that
    # 1 + k * z_(k) > sum_{j<=k} z_(j), and the threshold is
    # tau = (sum over the support - 1) / k.
    sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(
        scores, end_n_ids, in_degrees, cum_in_degrees)
    cumsum_scores = cumsum_scores - 1.
    support = rhos * sorted_scores > cumsum_scores
    support = support[dense_reverse_perm]  # from sorted order to unsorted order
    support = support[reverse_perm]        # from src-dst order to eid order

    support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
    support_size = support_size.long()
    idx = support_size + cum_in_degrees - 1

    # Mask invalid indices; for example, if node ids in a batch do not start
    # from 0 or are not contiguous, the index may be negative.
    mask = idx < 0
    idx[mask] = 0
    tau = cumsum_scores.gather(0, idx.long())
    tau /= support_size.to(scores.dtype)

    return tau, support_size


class EdgeSparsemaxFunction(Function):
    r"""
    Description
    -----------
    PyTorch autograd function for edge sparsemax.
    We define this autograd function here since
    sparsemax involves a sort and a select, which are
    not differentiable.
    """
    @staticmethod
    def forward(ctx, gidx: HeteroGraphIndex, scores: Tensor,
                eids: Tensor, end_n_ids: Tensor, norm_by: str):
        if not is_all(eids):
            gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == "src":
            # Reverse the graph so that normalizing by source nodes becomes
            # normalizing by destination nodes on the reversed graph.
            gidx = gidx.reverse()

        # Use feat - max(feat) for numerical stability.
        scores = scores.float()
        scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
        scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")

        # Find the threshold for each node and perform the ReLU(u - t(u)) operation.
        tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
        out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)
        ctx.backward_cache = gidx
        ctx.save_for_backward(supp_size, out)
        torch.cuda.empty_cache()
        return out

    @staticmethod
    def backward(ctx, grad_out):
        gidx = ctx.backward_cache
        supp_size, out = ctx.saved_tensors
        grad_in = grad_out.clone()

        # Gradient of the ReLU: zero wherever the output was clamped.
        grad_in[out == 0] = 0

        # dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
        v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[0] / supp_size.to(out.dtype)
        grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
        grad_in = torch.where(out != 0, grad_in_modify, grad_in)
        del gidx
        torch.cuda.empty_cache()

        return None, grad_in, None, None, None


def edge_sparsemax(graph: dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
    r"""
    Description
    -----------
    Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes

    .. math::
        a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of sparsemax, and :math:`\tau` is the threshold
    function described in the `From Softmax to Sparsemax
    <https://arxiv.org/pdf/1602.02068.pdf>`_ paper.

    NOTE: currently only homogeneous graphs are supported.

    Parameters
    ----------
    graph : DGLGraph
        The graph to perform edge sparsemax on.
    logits : torch.Tensor
        The input edge feature.
    eids : torch.Tensor or ALL, optional
        A tensor of edge indices on which to apply edge sparsemax. If ALL, apply edge
        sparsemax on all edges in the graph. Default: ALL.
    norm_by : str, could be 'src' or 'dst'
        Normalize by source nodes or destination nodes. Default: `dst`.

    Returns
    -------
    Tensor
        Sparsemax value.
    """
    # We get the edge index tensors here since it is
    # hard to get the edge index from a HeteroGraphIndex
    # object without other information like edge_type.
    row, col = graph.all_edges(order="eid")
    assert norm_by in ["dst", "src"]
    end_n_ids = col if norm_by == "dst" else row
    if not is_all(eids):
        eids = astype(eids, graph.idtype)
        end_n_ids = end_n_ids[eids]
    return EdgeSparsemaxFunction.apply(graph._graph, logits,
                                       eids, end_n_ids, norm_by)


class EdgeSparsemax(torch.nn.Module):
    r"""
    Description
    -----------
    Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes

    .. math::
        a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of sparsemax, and :math:`\tau` is the threshold
    function described in the `From Softmax to Sparsemax
    <https://arxiv.org/pdf/1602.02068.pdf>`_ paper.

    NOTE: currently only homogeneous graphs are supported.

    Parameters
    ----------
    graph : DGLGraph
        The graph to perform edge sparsemax on.
    logits : torch.Tensor
        The input edge feature.
    eids : torch.Tensor or ALL, optional
        A tensor of edge indices on which to apply edge sparsemax. If ALL, apply edge
        sparsemax on all edges in the graph. Default: ALL.
    norm_by : str, could be 'src' or 'dst'
        Normalize by source nodes or destination nodes. Default: `dst`.

    Returns
    -------
    Tensor
        Sparsemax value.
    """
    def __init__(self):
        super(EdgeSparsemax, self).__init__()

    def forward(self, graph, logits, eids=ALL, norm_by="dst"):
        return edge_sparsemax(graph, logits, eids, norm_by)
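

# A minimal usage sketch (illustrative, not part of the original module): edge
# sparsemax on a small homogeneous graph, normalizing over the edges that share
# a destination node (norm_by="dst"). Unlike edge softmax, some of the
# resulting attention weights can be exactly zero.
if __name__ == "__main__":
    # Node 1 receives edges from nodes 0 and 3, so their logits are normalized
    # together; nodes 2 and 3 each receive a single edge, whose weight is
    # therefore exactly 1.0.
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 1]))
    logits = torch.randn(g.num_edges())
    attn = EdgeSparsemax()(g, logits)
    print(attn)  # one weight per edge, in eid order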