[Example] Add HGP-SL example for pytorch backend (dmlc#2515)
* add sagpool example for pytorch backend

* polish sagpool example for pytorch backend

* [Example] SAGPool: use std variance

* [Example] SAGPool: change to std

* add sagpool example to index page

* add graph property prediction tag to sagpool

* [Example] add graph classification example HGP-SL

* [Example] fix sagpool

* fix bug

* [Example] change tab to space in README of hgp-sl

* remove redundant files

* remove redundant network

* [Example]: change link from code to doc in HGP-SL

* [Example] in HGP-SL, change to meaningful name

* [Example] Fix path mistake for 'hardgat'

Co-authored-by: zhangtianqi <[email protected]>
lygztq and zhangtianqi authored Jan 18, 2021
1 parent 1caf01d commit b36b6c2
Showing 8 changed files with 1,075 additions and 2 deletions.
7 changes: 6 additions & 1 deletion examples/README.md
@@ -43,6 +43,7 @@ The folder contains example implementations of selected research papers related
| [Self-Attention Graph Pooling](#sagpool) | | | :heavy_check_mark: | | |
| [Convolutional Networks on Graphs for Learning Molecular Fingerprints](#nf) | | | :heavy_check_mark: | | |
| [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](#gnnfilm) | :heavy_check_mark: | | | | |
| [Hierarchical Graph Pooling with Structure Learning](#hgp-sl) | | | :heavy_check_mark: | | |
| [Graph Representation Learning via Hard and Channel-Wise Attention Networks](#hardgat) |:heavy_check_mark: | | | | |

## 2020
@@ -144,8 +145,12 @@ The folder contains example implementations of selected research papers related
- Example code: [PyTorch](../examples/pytorch/sagpool)
- Tags: graph classification, pooling

- <a name="hgp-sl"></a> Zhang, Zhen, et al. Hierarchical Graph Pooling with Structure Learning. [Paper link](https://arxiv.org/abs/1911.05954).
- Example code: [PyTorch](../examples/pytorch/hgp_sl)
- Tags: graph classification, pooling

- <a name='hardgat'></a> Gao, Hongyang, et al. Graph Representation Learning via Hard and Channel-Wise Attention Networks. [Paper link](https://arxiv.org/abs/1907.04652).
- Example code: [Pytorch](../examples/pytorch/hgat)
- Example code: [Pytorch](../examples/pytorch/hardgat)
- Tags: node classification, graph attention

## 2018
115 changes: 115 additions & 0 deletions examples/pytorch/hgp_sl/README.md
@@ -0,0 +1,115 @@
# DGL Implementation of the HGP-SL Paper

This DGL example implements the GNN model proposed in the paper [Hierarchical Graph Pooling with Structure Learning](https://arxiv.org/pdf/1911.05954.pdf).
The author's original implementation is available [here](https://github.com/cszhangzhen/HGP-SL).


Example implementor
----------------------
This example was implemented by [Tianqi Zhang](https://github.com/lygztq) during his Applied Scientist Intern work at the AWS Shanghai AI Lab.


The graph dataset used in this example
---------------------------------------
This example uses DGL's built-in [LegacyTUDataset](https://docs.dgl.ai/api/python/dgl.data.html?highlight=tudataset#dgl.data.LegacyTUDataset), a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109', 'Mutagenicity' and 'ENZYMES' in this HGP-SL implementation. Each dataset is randomly split into training, validation, and test sets with ratios 0.8, 0.1, and 0.1, as sketched below.
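
A minimal sketch of the loading-and-splitting step (assuming DGL's `LegacyTUDataset` and `torch.utils.data.random_split`; the exact code in `main.py` may differ):

```python
from dgl.data import LegacyTUDataset
from torch.utils.data import random_split

# Load a graph-kernel dataset by name, e.g. 'DD'.
dataset = LegacyTUDataset("DD")

# Random 0.8 / 0.1 / 0.1 split, as described above.
num_train = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_train - num_val
train_set, val_set, test_set = random_split(dataset, [num_train, num_val, num_test])
```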

NOTE: Since some of these datasets have no node attributes, we use the node id (as a one-hot vector whose length is the maximum number of nodes across all graphs) as the node feature; a sketch follows. Also note that node ids in some datasets are not unique (e.g. a graph may have two nodes with the same id).
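
A sketch of building such one-hot features (the `node_id` field and the helper function are hypothetical illustrations, not names from this example):

```python
import torch

def add_one_hot_node_feat(graph, max_num_nodes):
    """Hypothetical helper: use the one-hot encoded node id as the node feature."""
    node_id = graph.ndata["node_id"].long()  # assumed attribute name
    one_hot = torch.zeros(graph.num_nodes(), max_num_nodes)
    one_hot[torch.arange(graph.num_nodes()), node_id] = 1.0
    graph.ndata["feat"] = one_hot
    return graph
```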

DD
- NumGraphs: 1178
- AvgNodesPerGraph: 284.32
- AvgEdgesPerGraph: 715.66
- NumFeats: 89
- NumClasses: 2

PROTEINS
- NumGraphs: 1113
- AvgNodesPerGraph: 39.06
- AvgEdgesPerGraph: 72.82
- NumFeats: 1
- NumClasses: 2

NCI1
- NumGraphs: 4110
- AvgNodesPerGraph: 29.87
- AvgEdgesPerGraph: 32.30
- NumFeats: 37
- NumClasses: 2

NCI109
- NumGraphs: 4127
- AvgNodesPerGraph: 29.68
- AvgEdgesPerGraph: 32.13
- NumFeats: 38
- NumClasses: 2

Mutagenicity
- NumGraphs: 4337
- AvgNodesPerGraph: 30.32
- AvgEdgesPerGraph: 30.77
- NumFeats: 14
- NumClasses: 2

ENZYMES
- NumGraphs: 600
- AvgNodesPerGraph: 32.63
- AvgEdgesPerGraph: 62.14
- NumFeats: 18
- NumClasses: 6

How to run example files
--------------------------------
In the `examples/pytorch/hgp_sl` folder, run

```bash
python main.py --dataset ${your_dataset_name_here}
```

If you want to use a GPU, run

```bash
python main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here}
```

Performance
-------------------------

**Hyper-parameters**

These hyper-parameters are taken directly from the [author's implementation](https://github.com/cszhangzhen/HGP-SL); an example command using them follows the table.

| Datasets | lr | weight_decay | batch_size | pool_ratio | dropout | net_layers |
| ------------- | --------- | -------------- | --------------- | -------------- | -------- | ---------- |
| PROTEINS | 0.001 | 0.001 | 512 | 0.5 | 0.0 | 3 |
| Mutagenicity | 0.001 | 0.001 | 512 | 0.8 | 0.0 | 3 |
| NCI109 | 0.001 | 0.001 | 512 | 0.8 | 0.0 | 3 |
| NCI1 | 0.001 | 0.001 | 512 | 0.8 | 0.0 | 3 |
| DD | 0.0001 | 0.001 | 64 | 0.3 | 0.5 | 2 |
| ENZYMES | 0.001 | 0.001 | 128 | 0.8 | 0.0 | 2 |
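
For example, to run the PROTEINS setting above (a sketch; it assumes `main.py` exposes each hyper-parameter as a flag with the same name, which this README only confirms for `--dataset` and `--device`):

```bash
python main.py --dataset PROTEINS --lr 0.001 --weight_decay 0.001 \
    --batch_size 512 --pool_ratio 0.5 --dropout 0.0 --net_layers 3
```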


**Accuracy**

**NOTE**: We find that there is a gap between the accuracy obtained with the author's code and the one reported in the [paper](https://arxiv.org/pdf/1911.05954.pdf). An issue has been opened in the author's repo (see [here](https://github.com/cszhangzhen/HGP-SL/issues/8)).

| | Mutagenicity | NCI109 | NCI1 | DD |
| -------------------------- | ------------ | ----------- | ----------- | ----------- |
| Reported in Paper | 82.15(0.58) | 80.67(1.16) | 78.45(0.77) | 80.96(1.26) |
| Author's Code (full graph) | 78.44(2.10) | 74.44(2.05) | 77.37(2.09) | OOM |
| Author's Code (sample) | 79.68(1.68) | 73.86(1.72) | 76.29(2.14) | 75.46(3.86) |
| DGL (full graph) | 79.52(2.21) | 74.86(1.99) | 74.62(2.22) | OOM |
| DGL (sample) | 79.15(1.62) | 75.39(1.86) | 73.77(2.04) | 76.47(2.14) |


**Speed**

Device: Tesla V100-SXM2 16GB

All times are in seconds.

| | DD (batch_size=64), large graphs | Mutagenicity (batch_size=512), small graphs |
| ----------------------------- | -------------------------------- | -------------------------------------------- |
| Author's code (sample) | 9.96 | 12.91 |
| Author's code (full graph) | OOM | 13.03 |
| DGL (sample) | 9.50 | 3.59 |
| DGL (full graph) | OOM | 3.56 |
220 changes: 220 additions & 0 deletions examples/pytorch/hgp_sl/functions.py
@@ -0,0 +1,220 @@
"""
An original implementation of sparsemax (Martins & Astudillo, 2016) is available at
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.
See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`
for detailed description.
Here we implement a graph-edge version of sparsemax that performs sparsemax over all
edges sharing the same end node in a graph.
"""
import dgl
import torch
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function


def _neighbor_sort(scores:Tensor, end_n_ids:Tensor, in_degrees:Tensor, cum_in_degrees:Tensor):
"""Sort edge scores for each node"""
num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())

# Compute the index for dense score matrix with size (N x D_{max})
# Note that the end_n_ids here is the end_node tensor in dgl graph,
# which is not grouped by its node id (i.e. in this form: 0,0,1,1,1,...,N,N).
# Thus we first sort the end_node tensor to make it easier to compute
# indices in the dense edge score matrix. Since we will need the original order
# for the following gspmm and gsddmm operations, we also keep the reverse mapping
# (the reverse_perm) here.
end_n_ids, perm = torch.sort(end_n_ids)
scores = scores[perm]
_, reverse_perm = torch.sort(perm)

index = torch.arange(end_n_ids.size(0), dtype=torch.long, device=scores.device)
index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
index = index.long()

dense_scores = scores.new_full((num_nodes * max_in_degree, ), torch.finfo(scores.dtype).min)
dense_scores[index] = scores
dense_scores = dense_scores.view(num_nodes, max_in_degree)

sorted_dense_scores, dense_reverse_perm = dense_scores.sort(dim=-1, descending=True)
_, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
dense_reverse_perm = dense_reverse_perm.view(-1)
cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
sorted_dense_scores = sorted_dense_scores.view(-1)
arange_vec = torch.arange(1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device)
arange_vec = torch.repeat_interleave(arange_vec.view(1, -1), num_nodes, dim=0).view(-1)

valid_mask = (sorted_dense_scores != torch.finfo(scores.dtype).min)
sorted_scores = sorted_dense_scores[valid_mask]
cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
arange_vec = arange_vec[valid_mask]
dense_reverse_perm = dense_reverse_perm[valid_mask].long()

return sorted_scores, cumsum_sorted_scores, arange_vec, reverse_perm, dense_reverse_perm


def _threshold_and_support_graph(gidx:HeteroGraphIndex, scores:Tensor, end_n_ids:Tensor):
"""Find the threshold for each node and its edges"""
in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[0]
cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0)

# perform sort on edges for each node
sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(scores, end_n_ids,
in_degrees, cum_in_degrees)
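# sparsemax support condition (Martins & Astudillo, 2016): an edge stays in
# the support iff 1 + k * z_(k) > sum_{j <= k} z_(j); below this is checked
# as rhos * sorted_scores > cumsum_scores after subtracting 1 from the cumsum.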
cumsum_scores = cumsum_scores - 1.
support = rhos * sorted_scores > cumsum_scores
support = support[dense_reverse_perm] # from sorted order to unsorted order
support = support[reverse_perm] # from src-dst order to eid order

support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
support_size = support_size.long()
idx = support_size + cum_in_degrees - 1

# mask invalid indices; for example, if the batch does not start from 0 or is not contiguous, negative indices may result
mask = idx < 0
idx[mask] = 0
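# tau(z) = (sum of scores in the support - 1) / support_size;
# the "- 1" is already folded into cumsum_scores above.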
tau = cumsum_scores.gather(0, idx.long())
tau /= support_size.to(scores.dtype)

return tau, support_size


class EdgeSparsemaxFunction(Function):
r"""
Description
-----------
Pytorch Auto-Grad Function for edge sparsemax.
We define this auto-grad function here since
sparsemax involves sort and select operations,
which are not differentiable.
"""
@staticmethod
def forward(ctx, gidx:HeteroGraphIndex, scores:Tensor,
eids:Tensor, end_n_ids:Tensor, norm_by:str):
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == "src":
gidx = gidx.reverse()

# use feat - max(feat) for numerical stability.
scores = scores.float()
scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")

# find threshold for each node and perform ReLU(u-t(u)) operation.
tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)
ctx.backward_cache = gidx
ctx.save_for_backward(supp_size, out)
torch.cuda.empty_cache()
return out

@staticmethod
def backward(ctx, grad_out):
gidx = ctx.backward_cache
supp_size, out = ctx.saved_tensors
grad_in = grad_out.clone()

# grad for ReLU
grad_in[out == 0] = 0

# dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[0] / supp_size.to(out.dtype)
grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
grad_in = torch.where(out != 0, grad_in_modify, grad_in)
del gidx
torch.cuda.empty_cache()

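# gradients are returned only for `scores`; gidx, eids, end_n_ids and
# norm_by receive None.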
return None, grad_in, None, None, None


def edge_sparsemax(graph:dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
r"""
Description
-----------
Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
.. math::
a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of sparsemax. :math:`\tau` is a function
that can be found at the `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`_
paper.
NOTE: currently only homogeneous graphs are supported.
Parameters
----------
graph : DGLGraph
The graph to perform edge sparsemax on.
logits : torch.Tensor
The input edge feature.
eids : torch.Tensor or ALL, optional
A tensor of edge index on which to apply edge sparsemax. If ALL, apply edge
sparsemax on all edges in the graph. Default: ALL.
norm_by : str, could be 'src' or 'dst'
Normalized by source nodes or destination nodes. Default: `dst`.
Returns
-------
Tensor
Sparsemax value.
"""
# we get edge index tensors here since it is
# hard to get edge indices from a HeteroGraphIndex
# object without extra information such as edge_type.
row, col = graph.all_edges(order="eid")
assert norm_by in ["dst", "src"]
end_n_ids = col if norm_by == "dst" else row
if not is_all(eids):
eids = astype(eids, graph.idtype)
end_n_ids = end_n_ids[eids]
return EdgeSparsemaxFunction.apply(graph._graph, logits,
eids, end_n_ids, norm_by)


class EdgeSparsemax(torch.nn.Module):
r"""
Description
-----------
Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes
.. math::
a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of sparsemax. :math:`\tau` is a function
that can be found at the `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`_
paper.
Parameters
----------
graph : DGLGraph
The graph to perform edge sparsemax on.
logits : torch.Tensor
The input edge feature.
eids : torch.Tensor or ALL, optional
A tensor of edge index on which to apply edge sparsemax. If ALL, apply edge
sparsemax on all edges in the graph. Default: ALL.
norm_by : str, could be 'src' or 'dst'
Normalized by source nodes or destination nodes. Default: `dst`.
NOTE: currently only homogeneous graphs are supported.
Returns
-------
Tensor
Sparsemax value.
"""
def __init__(self):
super(EdgeSparsemax, self).__init__()

def forward(self, graph, logits, eids=ALL, norm_by="dst"):
return edge_sparsemax(graph, logits, eids, norm_by)
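

# Minimal usage sketch (illustrative addition, not part of the author's API):
# apply edge sparsemax over the incoming edges of a small homogeneous graph.
if __name__ == "__main__":
    g = dgl.graph(([0, 1, 2, 0], [1, 2, 0, 2]))
    logits = torch.randn(g.num_edges())
    att = edge_sparsemax(g, logits, norm_by="dst")
    # for each destination node, att sums to 1, with possibly exact zeros
    print(att)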