[Example] Add sagpool example for pytorch backend (dmlc#2429)
* add sagpool example for pytorch backend

* polish sagpool example for pytorch backend

* [Example] SAGPool: use std variance

* [Example] SAGPool: change to std

* add sagpool example to index page

* add graph property prediction tag to sagpool

Co-authored-by: zhangtianqi <[email protected]>
lygztq and zhangtianqi authored Dec 28, 2020
1 parent 5d8330c commit 72ef642
Showing 9 changed files with 671 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/README.md
@@ -35,6 +35,7 @@
| [Molecular Graph Convolutions: Moving Beyond Fingerprints](#weave) | | | :heavy_check_mark: | | |
| [LINE: Large-scale Information Network Embedding](#line) | | :heavy_check_mark: | | | :heavy_check_mark: |
| [DeepWalk: Online Learning of Social Representations](#deepwalk) | | :heavy_check_mark: | | | :heavy_check_mark: |
| [Self-Attention Graph Pooling](#sagpool) | | | :heavy_check_mark: | | |
| | | | | | |
| | | | | | |

@@ -130,6 +131,10 @@
- Example code: [PyTorch](../examples/pytorch/mixhop)
- Tags: node classification

- <a name="sagpool"></a> Lee, Junhyun, et al. Self-Attention Graph Pooling. [Paper link](https://arxiv.org/abs/1904.08082).
- Example code: [PyTorch](../examples/pytorch/sagpool)
- Tags: graph classification, pooling

## 2018

- <a name="dgmg"></a> Li et al. Learning Deep Generative Models of Graphs. [Paper link](https://arxiv.org/abs/1803.03324).
96 changes: 96 additions & 0 deletions examples/pytorch/sagpool/README.md
@@ -0,0 +1,96 @@
# DGL Implementation of the SAGPool Paper

This DGL example implements the GNN model proposed in the paper [Self-Attention Graph Pooling](https://arxiv.org/pdf/1904.08082.pdf).
The authors' implementation is available [here](https://github.com/inyeoplee77/SAGPool).


The graph dataset used in this example
---------------------------------------
This example uses DGL's built-in LegacyTUDataset, a collection of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109' and 'Mutagenicity' in this SAGPool implementation. Each dataset is randomly split into training, validation and test sets with a ratio of 0.8/0.1/0.1.
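
The 0.8/0.1/0.1 random split can be sketched as follows. This is a minimal illustration using only the standard library, not the code in `main.py`: `split_indices` is a hypothetical helper, and both the seed and the choice to give any rounding remainder to the test set are assumptions.

```python
import random


def split_indices(num_graphs, ratios=(0.8, 0.1, 0.1), seed=0):
    """Shuffle graph indices and cut them into train/validation/test lists."""
    indices = list(range(num_graphs))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible

    num_train = int(num_graphs * ratios[0])
    num_val = int(num_graphs * ratios[1])

    train = indices[:num_train]
    val = indices[num_train:num_train + num_val]
    test = indices[num_train + num_val:]  # assumption: remainder goes to test
    return train, val, test


# PROTEINS has 1113 graphs according to the statistics in this README.
train, val, test = split_indices(1113)
print(len(train), len(val), len(test))
```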

NOTE: Since some of these datasets have no node attributes, we use node_id (as a one-hot vector whose length is the maximum number of nodes across all graphs) as the node feature. Also note that the node_id in some datasets is not unique (e.g. a graph may have two nodes with the same id).
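
As a rough illustration of the one-hot node_id feature described in the note, the sketch below encodes a list of node ids in plain Python. `one_hot` is a hypothetical helper and the ids are made up; the dataset class builds these features itself.

```python
def one_hot(node_ids, length):
    """Return one row per node with a single 1.0 at the position of its id."""
    return [[1.0 if col == node_id else 0.0 for col in range(length)]
            for node_id in node_ids]


# Two nodes share id 2, so they receive identical feature rows.
features = one_hot([0, 2, 2], length=4)
for row in features:
    print(row)
```

Nodes that share an id are indistinguishable by this feature alone, which is the caveat the note points out.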

DD
- NumGraphs: 1178
- AvgNodesPerGraph: 284.32
- AvgEdgesPerGraph: 715.66
- NumFeats: 89
- NumClasses: 2

PROTEINS
- NumGraphs: 1113
- AvgNodesPerGraph: 39.06
- AvgEdgesPerGraph: 72.82
- NumFeats: 1
- NumClasses: 2

NCI1
- NumGraphs: 4110
- AvgNodesPerGraph: 29.87
- AvgEdgesPerGraph: 32.30
- NumFeats: 37
- NumClasses: 2

NCI109
- NumGraphs: 4127
- AvgNodesPerGraph: 29.68
- AvgEdgesPerGraph: 32.13
- NumFeats: 38
- NumClasses: 2

Mutagenicity
- NumGraphs: 4337
- AvgNodesPerGraph: 30.32
- AvgEdgesPerGraph: 30.77
- NumFeats: 14
- NumClasses: 2


How to run example files
--------------------------------
The valid dataset names (you can find a full list [here](https://chrsmrrs.github.io/datasets/docs/datasets/)):
- 'DD' for D&D
- 'PROTEINS' for PROTEINS
- 'NCI1' for NCI1
- 'NCI109' for NCI109
- 'Mutagenicity' for Mutagenicity

In the sagpool folder, run

```bash
python main.py --dataset ${your_dataset_name_here}
```

To use a GPU, run

```bash
python main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here}
```

To perform a grid search, modify the parameter settings in `grid_search_config.json` and run
```bash
python grid_search.py --device ${your_device_id_here} --num_trials ${num_of_trials_here}
```

Performance
-------------------------

NOTE: We do not perform a grid search or fine-tuning here, so there may be a gap between the results in the paper and ours. Also, we run only 10 trials per experiment, whereas the paper runs 200 trials per experiment.
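
Aggregating per-trial accuracies into the `mean (spread)` form used in the tables below can be sketched like this. It assumes the value in parentheses is a sample standard deviation, which the commit message suggests but this README does not state. `summarize` is a hypothetical helper, not the `get_stats` function from `utils.py`, and the accuracies are invented.

```python
import statistics


def summarize(accuracies):
    """Return the mean and sample standard deviation of per-trial accuracies."""
    mean = statistics.mean(accuracies)
    std = statistics.stdev(accuracies)  # assumption: sample (n - 1) std
    return mean, std


# Made-up accuracies from three trials, as fractions in [0, 1].
mean, std = summarize([0.70, 0.72, 0.74])
print("{:.2f} ({:.2f})".format(mean * 100, std * 100))
```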

**The global architecture result**
| Dataset | paper result (global) | ours (global) |
| ------------- | -------------------------------- | --------------------------- |
| D&D | 76.19 (0.94) | 74.79 (2.69) |
| PROTEINS | 70.04 (1.47) | 70.36 (5.90) |
| NCI1 | 74.18 (1.20) | 72.82 (2.36) |
| NCI109 | 74.06 (0.78) | 71.64 (2.65) |
| Mutagenicity | N/A | 76.55 (2.89) |

**The hierarchical architecture result**
| Dataset | paper result (hierarchical) | ours (hierarchical) |
| ------------- | -------------------------------- | --------------------------- |
| D&D | 76.45 (0.97) | 75.38 (4.17) |
| PROTEINS | 71.86 (0.97) | 70.36 (5.68) |
| NCI1 | 67.45 (1.11) | 70.61 (2.25) |
| NCI109 | 67.86 (1.41) | 69.13 (3.85) |
| Mutagenicity | N/A | 75.20 (1.95) |
29 changes: 29 additions & 0 deletions examples/pytorch/sagpool/dataloader.py
@@ -0,0 +1,29 @@
import torch.utils.data
from torch.utils.data.dataloader import DataLoader
import dgl
import numpy as np


def collate_fn(batch):
    """
    collate_fn for dataset batching:
    cast node data to float tensors, batch the graphs,
    and convert the labels to a LongTensor
    """
    graphs, labels = map(list, zip(*batch))

    # cast node data to float, then merge the graphs into one batched graph
    for graph in graphs:
        for (key, value) in graph.ndata.items():
            graph.ndata[key] = value.float()
    batched_graphs = dgl.batch(graphs)

    # cast labels to a PyTorch tensor
    batched_labels = torch.LongTensor(np.array(labels))

    return batched_graphs, batched_labels


class GraphDataLoader(DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        super(GraphDataLoader, self).__init__(dataset, batch_size, shuffle,
                                              collate_fn=collate_fn, **kwargs)
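
The first step of `collate_fn`, `map(list, zip(*batch))`, transposes a list of `(graph, label)` pairs into one list of graphs and one list of labels. A minimal sketch of that idiom, with strings standing in for `DGLGraph` objects (an assumption made so the snippet runs without DGL installed):

```python
def transpose_batch(batch):
    """Turn [(graph, label), ...] into ([graph, ...], [label, ...])."""
    graphs, labels = map(list, zip(*batch))
    return graphs, labels


# Placeholder strings stand in for DGLGraph objects.
batch = [("graph_a", 1), ("graph_b", 0), ("graph_c", 1)]
graphs, labels = transpose_batch(batch)
print(graphs)
print(labels)
```

In the real collate function the graph list then goes to `dgl.batch` and the label list becomes a `LongTensor`.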
55 changes: 55 additions & 0 deletions examples/pytorch/sagpool/grid_search.py
@@ -0,0 +1,55 @@
import json
import os
from copy import deepcopy
from main import main, parse_args
from utils import get_stats

def load_config(path="./grid_search_config.json"):
    with open(path, "r") as f:
        return json.load(f)

def run_experiments(args):
    res = []
    for i in range(args.num_trials):
        print("Trial {}/{}".format(i + 1, args.num_trials))
        acc, _ = main(args)
        res.append(acc)

    mean, err_bd = get_stats(res, conf_interval=True)
    return mean, err_bd

def grid_search(config:dict):
    args = parse_args()
    results = {}

    for d in config["dataset"]:
        args.dataset = d
        best_acc, err_bd = 0., 0.
        # copy the dict so later changes to args do not alter the stored best setting
        best_args = deepcopy(vars(args))
        for arch in config["arch"]:
            args.architecture = arch
            for hidden in config["hidden"]:
                args.hid_dim = hidden
                for pool_ratio in config["pool_ratio"]:
                    args.pool_ratio = pool_ratio
                    for lr in config["lr"]:
                        args.lr = lr
                        for weight_decay in config["weight_decay"]:
                            args.weight_decay = weight_decay
                            acc, bd = run_experiments(args)
                            if acc > best_acc:
                                best_acc = acc
                                err_bd = bd
                                best_args = deepcopy(vars(args))
        # write the best setting found for this dataset to ./output/<dataset>.log
        args.output_path = "./output"
        if not os.path.exists(args.output_path):
            os.makedirs(args.output_path)
        args.output_path = "./output/{}.log".format(d)
        result = {
            "params": best_args,
            "result": "{:.4f}({:.4f})".format(best_acc, err_bd)
        }
        with open(args.output_path, "w") as f:
            json.dump(result, f, sort_keys=True, indent=4)

if __name__ == "__main__":
    grid_search(load_config())
8 changes: 8 additions & 0 deletions examples/pytorch/sagpool/grid_search_config.json
@@ -0,0 +1,8 @@
{
"arch": ["hierarchical", "global"],
"hidden": [16, 32, 64, 128],
"pool_ratio": [0.25, 0.5],
"lr": [1e-2, 5e-2, 1e-3, 5e-3, 1e-4, 5e-4],
"weight_decay": [1e-2, 1e-3, 1e-4, 1e-5],
"dataset": ["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"]
}
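
The nested loops in `grid_search.py` visit every combination of the values in this config for each dataset. The sketch below enumerates the same search space with `itertools.product`, which makes its size easy to check. `iter_settings` is a hypothetical helper and not part of the example; the config values are copied from `grid_search_config.json`.

```python
import itertools

# Values copied from grid_search_config.json.
config = {
    "arch": ["hierarchical", "global"],
    "hidden": [16, 32, 64, 128],
    "pool_ratio": [0.25, 0.5],
    "lr": [1e-2, 5e-2, 1e-3, 5e-3, 1e-4, 5e-4],
    "weight_decay": [1e-2, 1e-3, 1e-4, 1e-5],
    "dataset": ["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"],
}

# Hyperparameters searched per dataset, in the loop order of grid_search.py.
SEARCH_KEYS = ["arch", "hidden", "pool_ratio", "lr", "weight_decay"]


def iter_settings(config, keys=SEARCH_KEYS):
    """Yield one dict per hyperparameter combination."""
    for values in itertools.product(*(config[key] for key in keys)):
        yield dict(zip(keys, values))


per_dataset = sum(1 for _ in iter_settings(config))
total = per_dataset * len(config["dataset"])
print(per_dataset, total)
```

Each setting is additionally run `--num_trials` times, so the total number of training runs is `total * num_trials`.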
59 changes: 59 additions & 0 deletions examples/pytorch/sagpool/layer.py
@@ -0,0 +1,59 @@
import torch
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling
from utils import topk, get_batch_id


class SAGPool(torch.nn.Module):
    """The Self-Attention Pooling layer in paper
    `Self Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`

    Args:
        in_dim (int): The dimension of node feature.
        ratio (float, optional): The pool ratio which determines the number of nodes
            that remain after pooling. (default: :obj:`0.5`)
        conv_op (torch.nn.Module, optional): The graph convolution layer in dgl used to
            compute the score for each node. (default: :obj:`dgl.nn.GraphConv`)
        non_linearity (Callable, optional): The non-linearity function, a pytorch function.
            (default: :obj:`torch.tanh`)
    """
    def __init__(self, in_dim:int, ratio=0.5, conv_op=GraphConv, non_linearity=torch.tanh):
        super(SAGPool, self).__init__()
        self.in_dim = in_dim
        self.ratio = ratio
        self.score_layer = conv_op(in_dim, 1)
        self.non_linearity = non_linearity

    def forward(self, graph:dgl.DGLGraph, feature:torch.Tensor):
        # one attention score per node, computed by a graph convolution
        score = self.score_layer(graph, feature).squeeze()
        # indices of the top-scoring nodes in each graph of the batch
        perm, next_batch_num_nodes = topk(score, self.ratio, get_batch_id(graph.batch_num_nodes()), graph.batch_num_nodes())
        # gate the kept node features with their (squashed) scores
        feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
        graph = dgl.node_subgraph(graph, perm)

        # node_subgraph currently does not support batch-graph,
        # the 'batch_num_nodes' of the result subgraph is None.
        # So we manually set the 'batch_num_nodes' here.
        # Since global pooling has nothing to do with 'batch_num_edges',
        # we can leave it to be None or unchanged.
        graph.set_batch_num_nodes(next_batch_num_nodes)

        return graph, feature, perm


class ConvPoolBlock(torch.nn.Module):
    """A combination of GCN layer and SAGPool layer,
    followed by a concatenated (mean||max) readout operation.
    """
    def __init__(self, in_dim:int, out_dim:int, pool_ratio=0.8):
        super(ConvPoolBlock, self).__init__()
        self.conv = GraphConv(in_dim, out_dim)
        self.pool = SAGPool(out_dim, ratio=pool_ratio)
        self.avgpool = AvgPooling()
        self.maxpool = MaxPooling()

    def forward(self, graph, feature):
        out = F.relu(self.conv(graph, feature))
        graph, out, _ = self.pool(graph, out)
        # readout: concatenate the mean and max over the remaining nodes
        g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
        return graph, out, g_out
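
The selection and gating steps in `SAGPool.forward` can be illustrated without DGL or PyTorch. The `topk` and `get_batch_id` helpers come from `utils.py`, which is not shown here, so the sketch below is an assumption about their behaviour rather than a copy: it keeps `ceil(ratio * n)` nodes per graph, ranked by score, and scales each kept feature row by `tanh(score)`. `topk_per_graph` and `gate` are hypothetical names.

```python
import math


def topk_per_graph(scores, batch_num_nodes, ratio):
    """Pick the highest-scoring nodes of each graph in a batch.

    scores lists one value per node, with the graphs laid out back to back;
    batch_num_nodes gives the node count of each graph.
    """
    perm, next_batch_num_nodes = [], []
    offset = 0
    for num_nodes in batch_num_nodes:
        k = math.ceil(ratio * num_nodes)  # assumption: round up
        node_ids = range(offset, offset + num_nodes)
        ranked = sorted(node_ids, key=lambda i: scores[i], reverse=True)
        perm.extend(ranked[:k])
        next_batch_num_nodes.append(k)
        offset += num_nodes
    return perm, next_batch_num_nodes


def gate(features, scores, perm):
    """Keep the selected rows and scale each one by tanh(score)."""
    return [[x * math.tanh(scores[i]) for x in features[i]] for i in perm]


# A batch of two graphs with 3 and 2 nodes; the scores are invented.
scores = [0.1, 0.9, 0.5, 0.3, 0.7]
features = [[1.0, 2.0]] * 5
perm, next_batch_num_nodes = topk_per_graph(scores, [3, 2], ratio=0.5)
pooled = gate(features, scores, perm)
print(perm, next_batch_num_nodes)
```

The per-graph counts returned here play the role of `next_batch_num_nodes`, which the layer passes to `set_batch_num_nodes` because the subgraph loses its batch information.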