[Example] Add sagpool example for pytorch backend (dmlc#2429)
* add sagpool example for pytorch backend
* polish sagpool example for pytorch backend
* [Example] SAGPool: use std variance
* [Example] SAGPool: change to std
* add sagpool example to index page
* add graph property prediction tag to sagpool

Co-authored-by: zhangtianqi <[email protected]>
Showing 9 changed files with 671 additions and 0 deletions.
@@ -0,0 +1,96 @@

# DGL Implementation of the SAGPool Paper

This DGL example implements the GNN model proposed in the paper [Self-Attention Graph Pooling](https://arxiv.org/pdf/1904.08082.pdf).
The authors' implementation is available [here](https://github.com/inyeoplee77/SAGPool).

The graph dataset used in this example
---------------------------------------
DGL's built-in LegacyTUDataset, a series of graph kernel datasets for graph classification. We use 'DD', 'PROTEINS', 'NCI1', 'NCI109' and 'Mutagenicity' in this SAGPool implementation. All of these datasets are randomly split into training, validation and test sets with ratios 0.8, 0.1 and 0.1, as sketched below.
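
A minimal sketch of that 0.8/0.1/0.1 split, assuming `dgl.data.LegacyTUDataset` and `torch.utils.data.random_split` (the example's own split logic lives in `main.py`, which is not shown in this excerpt):

```python
from dgl.data import LegacyTUDataset
from torch.utils.data import random_split

dataset = LegacyTUDataset("DD")
num_train = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_train - num_val
# random 0.8/0.1/0.1 split into train/validation/test subsets
train_set, val_set, test_set = random_split(dataset, [num_train, num_val, num_test])
```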

NOTE: Since some of these datasets have no node attributes, we use the node id (as a one-hot vector whose length is the max number of nodes across all graphs) as the node feature. Also note that node ids in some datasets are not unique (e.g. a graph may have two nodes with the same id).
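
For illustration, a hypothetical sketch of such a one-hot scheme, here taking node ids to be positional indices (this helper is not part of the example's code):

```python
import torch

def add_one_hot_features(graphs, max_nodes):
    # Hypothetical helper: node i gets the i-th row of an identity
    # matrix sized by the largest graph in the dataset.
    eye = torch.eye(max_nodes)
    for g in graphs:
        g.ndata["feat"] = eye[: g.num_nodes()]
    return graphs
```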

DD
- NumGraphs: 1178
- AvgNodesPerGraph: 284.32
- AvgEdgesPerGraph: 715.66
- NumFeats: 89
- NumClasses: 2

PROTEINS
- NumGraphs: 1113
- AvgNodesPerGraph: 39.06
- AvgEdgesPerGraph: 72.82
- NumFeats: 1
- NumClasses: 2

NCI1
- NumGraphs: 4110
- AvgNodesPerGraph: 29.87
- AvgEdgesPerGraph: 32.30
- NumFeats: 37
- NumClasses: 2

NCI109
- NumGraphs: 4127
- AvgNodesPerGraph: 29.68
- AvgEdgesPerGraph: 32.13
- NumFeats: 38
- NumClasses: 2

Mutagenicity
- NumGraphs: 4337
- AvgNodesPerGraph: 30.32
- AvgEdgesPerGraph: 30.77
- NumFeats: 14
- NumClasses: 2
How to run example files
--------------------------------
The valid dataset names (you can find a full list [here](https://chrsmrrs.github.io/datasets/docs/datasets/)):
- 'DD' for D&D
- 'PROTEINS' for PROTEINS
- 'NCI1' for NCI1
- 'NCI109' for NCI109
- 'Mutagenicity' for Mutagenicity

In the sagpool folder, run

```bash
python main.py --dataset ${your_dataset_name_here}
```

If you want to use a GPU, run

```bash
python main.py --device ${your_device_id_here} --dataset ${your_dataset_name_here}
```

If you want to perform a grid search, modify the parameter settings in `grid_search_config.json` and run

```bash
python grid_search.py --device ${your_device_id_here} --num_trials ${num_of_trials_here}
```

Performance
-------------------------

NOTE: We do not perform grid search or fine-tuning here, so there may be a gap between the results in the paper and ours. Also, we only perform 10 trials per experiment, as opposed to the 200 trials per experiment in the paper.
**The global architecture result**

| Dataset      | paper result (global) | ours (global) |
| ------------ | --------------------- | ------------- |
| D&D          | 76.19 (0.94)          | 74.79 (2.69)  |
| PROTEINS     | 70.04 (1.47)          | 70.36 (5.90)  |
| NCI1         | 74.18 (1.20)          | 72.82 (2.36)  |
| NCI109       | 74.06 (0.78)          | 71.64 (2.65)  |
| Mutagenicity | N/A                   | 76.55 (2.89)  |

**The hierarchical architecture result**

| Dataset      | paper result (hierarchical) | ours (hierarchical) |
| ------------ | --------------------------- | ------------------- |
| D&D          | 76.45 (0.97)                | 75.38 (4.17)        |
| PROTEINS     | 71.86 (0.97)                | 70.36 (5.68)        |
| NCI1         | 67.45 (1.11)                | 70.61 (2.25)        |
| NCI109       | 67.86 (1.41)                | 69.13 (3.85)        |
| Mutagenicity | N/A                         | 75.20 (1.95)        |
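
The numbers above are reported as mean (spread) over the trials. The statistics come from `get_stats` in the example's `utils.py`, which is not included in this excerpt; the following is only a plausible sketch of such a helper, assuming a standard deviation by default and a t-based confidence half-width when `conf_interval=True`:

```python
import numpy as np
from scipy import stats

def get_stats_sketch(accuracies, conf_interval=False):
    # Mean plus a spread estimate, matching the "mean (spread)" format above.
    arr = np.asarray(accuracies, dtype=float)
    mean = arr.mean()
    if conf_interval:
        # 95% confidence half-width based on the t distribution
        half = stats.sem(arr) * stats.t.ppf(0.975, len(arr) - 1)
        return mean, half
    return mean, arr.std()
```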
@@ -0,0 +1,29 @@

import torch.utils.data
from torch.utils.data.dataloader import DataLoader
import dgl
import numpy as np


def collate_fn(batch):
    """collate_fn for dataset batching:
    casts all ndata to float tensors and batches graphs and labels.
    """
    graphs, labels = map(list, zip(*batch))

    # batch graphs and cast node data to float tensors
    for graph in graphs:
        for (key, value) in graph.ndata.items():
            graph.ndata[key] = value.float()
    batched_graphs = dgl.batch(graphs)

    # cast labels to a PyTorch tensor
    batched_labels = torch.LongTensor(np.array(labels))

    return batched_graphs, batched_labels


class GraphDataLoader(DataLoader):
    """A thin DataLoader wrapper that batches DGL graphs via collate_fn."""
    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        super(GraphDataLoader, self).__init__(dataset, batch_size, shuffle,
                                              collate_fn=collate_fn, **kwargs)
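
A hypothetical usage sketch for the loader above (not part of the commit), assuming a LegacyTUDataset:

```python
from dgl.data import LegacyTUDataset

dataset = LegacyTUDataset("PROTEINS")
loader = GraphDataLoader(dataset, batch_size=32, shuffle=True)
for batched_graph, labels in loader:
    # one batched DGLGraph plus a LongTensor of graph labels
    print(batched_graph.batch_size, labels.shape)
    break
```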
@@ -0,0 +1,55 @@

import json
import os
from copy import deepcopy

from main import main, parse_args
from utils import get_stats


def load_config(path="./grid_search_config.json"):
    with open(path, "r") as f:
        return json.load(f)


def run_experiments(args):
    res = []
    for i in range(args.num_trials):
        print("Trial {}/{}".format(i + 1, args.num_trials))
        acc, _ = main(args)
        res.append(acc)

    mean, err_bd = get_stats(res, conf_interval=True)
    return mean, err_bd


def grid_search(config: dict):
    args = parse_args()
    results = {}

    for d in config["dataset"]:
        args.dataset = d
        best_acc, err_bd = 0., 0.
        best_args = vars(args)
        for arch in config["arch"]:
            args.architecture = arch
            for hidden in config["hidden"]:
                args.hid_dim = hidden
                for pool_ratio in config["pool_ratio"]:
                    args.pool_ratio = pool_ratio
                    for lr in config["lr"]:
                        args.lr = lr
                        for weight_decay in config["weight_decay"]:
                            args.weight_decay = weight_decay
                            acc, bd = run_experiments(args)
                            if acc > best_acc:
                                best_acc = acc
                                err_bd = bd
                                best_args = deepcopy(vars(args))
        args.output_path = "./output"
        if not os.path.exists(args.output_path):
            os.makedirs(args.output_path)
        args.output_path = "./output/{}.log".format(d)
        result = {
            "params": best_args,
            "result": "{:.4f}({:.4f})".format(best_acc, err_bd)
        }
        with open(args.output_path, "w") as f:
            json.dump(result, f, sort_keys=True, indent=4)


grid_search(load_config())
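
The nested loops above enumerate the Cartesian product of the parameter grid; an equivalent, flatter formulation (a sketch, not part of the commit) uses `itertools.product`:

```python
from itertools import product

def grid_search_flat(config: dict, args):
    # Enumerate the same Cartesian product as the nested loops above.
    keys = ("arch", "hidden", "pool_ratio", "lr", "weight_decay")
    for arch, hidden, pool_ratio, lr, weight_decay in product(*(config[k] for k in keys)):
        args.architecture = arch
        args.hid_dim = hidden
        args.pool_ratio = pool_ratio
        args.lr = lr
        args.weight_decay = weight_decay
        # ... call run_experiments(args) and keep the best result ...
```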
@@ -0,0 +1,8 @@

{
    "arch": ["hierarchical", "global"],
    "hidden": [16, 32, 64, 128],
    "pool_ratio": [0.25, 0.5],
    "lr": [1e-2, 5e-2, 1e-3, 5e-3, 1e-4, 5e-4],
    "weight_decay": [1e-2, 1e-3, 1e-4, 1e-5],
    "dataset": ["DD", "PROTEINS", "NCI1", "NCI109", "Mutagenicity"]
}
@@ -0,0 +1,59 @@

import torch
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, AvgPooling, MaxPooling

from utils import topk, get_batch_id


class SAGPool(torch.nn.Module):
    """The Self-Attention Pooling layer from the paper
    `Self-Attention Graph Pooling <https://arxiv.org/pdf/1904.08082.pdf>`

    Args:
        in_dim (int): The dimension of node features.
        ratio (float, optional): The pool ratio, which determines the fraction of
            nodes that remain after pooling. (default: :obj:`0.5`)
        conv_op (torch.nn.Module, optional): The graph convolution layer in dgl used
            to compute the score for each node. (default: :obj:`dgl.nn.GraphConv`)
        non_linearity (Callable, optional): The non-linearity function, a pytorch
            function. (default: :obj:`torch.tanh`)
    """
    def __init__(self, in_dim: int, ratio=0.5, conv_op=GraphConv, non_linearity=torch.tanh):
        super(SAGPool, self).__init__()
        self.in_dim = in_dim
        self.ratio = ratio
        self.score_layer = conv_op(in_dim, 1)
        self.non_linearity = non_linearity

    def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
        # compute a score for each node and keep the top-k nodes per graph
        score = self.score_layer(graph, feature).squeeze()
        perm, next_batch_num_nodes = topk(score, self.ratio,
                                          get_batch_id(graph.batch_num_nodes()),
                                          graph.batch_num_nodes())
        feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
        graph = dgl.node_subgraph(graph, perm)

        # node_subgraph currently does not support batched graphs:
        # the 'batch_num_nodes' of the resulting subgraph is None,
        # so we set 'batch_num_nodes' manually here.
        # Since global pooling has nothing to do with 'batch_num_edges',
        # we can leave it as None or unchanged.
        graph.set_batch_num_nodes(next_batch_num_nodes)

        return graph, feature, perm


class ConvPoolBlock(torch.nn.Module):
    """A combination of a GCN layer and a SAGPool layer,
    followed by a concatenated (mean||max) readout operation.
    """
    def __init__(self, in_dim: int, out_dim: int, pool_ratio=0.8):
        super(ConvPoolBlock, self).__init__()
        self.conv = GraphConv(in_dim, out_dim)
        self.pool = SAGPool(out_dim, ratio=pool_ratio)
        self.avgpool = AvgPooling()
        self.maxpool = MaxPooling()

    def forward(self, graph, feature):
        # convolve, pool, then read out a graph-level representation
        out = F.relu(self.conv(graph, feature))
        graph, out, _ = self.pool(graph, out)
        g_out = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
        return graph, out, g_out
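
A hypothetical smoke test for the layer above (not part of the commit); it requires the `topk` and `get_batch_id` helpers from this example's `utils.py`, which are not shown in this excerpt:

```python
import dgl
import torch

g = dgl.add_self_loop(dgl.rand_graph(10, 40))  # self-loops avoid zero-in-degree nodes in GraphConv
feat = torch.randn(10, 16)
pool = SAGPool(in_dim=16, ratio=0.5)
g_pooled, feat_pooled, perm = pool(g, feat)
print(g_pooled.num_nodes(), feat_pooled.shape)  # half of the nodes are kept
```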