Skip to content

Commit

Permalink
[Example] Refactor GNNExplainer Example (dmlc#4560)
Browse files Browse the repository at this point in the history
* debug

* debug

* readme

* fix readme

* fix readme

* Update

* Update

* update

* fix bug of syn2

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
  • Loading branch information
3 people authored Sep 21, 2022
1 parent 880b3b1 commit ec4271b
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 1,227 deletions.
200 changes: 0 additions & 200 deletions examples/pytorch/gnn_explainer/NodeExplainerModule.py

This file was deleted.

117 changes: 38 additions & 79 deletions examples/pytorch/gnn_explainer/README.md
Original file line number Diff line number Diff line change
@@ -1,103 +1,62 @@
# DGL Implementation of the GNN Explainer
# DGL Implementation of GNNExplainer

This DGL example implements the GNN Explainer model proposed in the paper [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894).
The author's codes of implementation is in [here](https://github.com/RexYing/gnn-model-explainer).
This is a DGL example for [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894). For the authors' original implementation,
see [here](https://github.com/RexYing/gnn-model-explainer).

The author's implementation is kind of experimental with experimental codes. So this implementation focuses on a subset of
GNN Explainer's functions, node classification, and later on extend to edge classification.
Contributors:
- [Jian Zhang](https://github.com/zhjwy9343)
- [Kounianhua Du](https://github.com/KounianhuaDu)
- [Yanjun Zhao](https://github.com/zyj-111)

Example implementor
Datasets
----------------------
This example was implemented by [Jian Zhang](https://github.com/zhjwy9343) and [Kounianhua Du](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab.

Dependencies
----------------------
- numpy 1.19.4
- pytorch 1.7.1
- dgl 0.5.3
- networkx 2.5
- matplotlib 3.3.4
Four built-in synthetic datasets are used in this example.

Datasets
----------------------
Five synthetic datasets used in the paper are used in this example. The generation codes are referenced from the author implementation.
- Syn1 (BA-SHAPES): Start with a base Barabasi-Albert (BA) graph on 300 nodes and a set of 80 five-node “house”-structured network motifs, which are attached to randomly selected nodes of the base graph. The resulting graph is further perturbed by adding 0.01N random edges. Nodes are assigned to 4 classes based on their structural roles. In a house-structured motif, there are 3 types of roles: the top, middle, and bottom node of the house. Therefore there are 4 different classes, corresponding to nodes at the top, middle, bottom of houses, and nodes that do not belong to a house.
- Syn2 (BA-COMMUNITY): A union of two BA-SHAPES graphs. Nodes have normally distributed feature vectors and are assigned to one of 8 classes based on their structural roles and community memberships.
- Syn3 (BA-GRID): The same as BA-SHAPES except that 3-by-3 grid motifs are attached to the base graph in place of house motifs.
- Syn4 (TREE-CYCLE): Start with a base 8-level balanced binary tree and 60 six-node cycle motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.01N random edges.
- Syn5 (TREE-GRID): Start with a base 8-level balanced binary tree and 80 3-by-3 grid motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.1N random edges.

Demo Usage
- [BA-SHAPES](https://docs.dgl.ai/generated/dgl.data.BAShapeDataset.html#dgl.data.BAShapeDataset)
- [BA-COMMUNITY](https://docs.dgl.ai/generated/dgl.data.BACommunityDataset.html#dgl.data.BACommunityDataset)
- [TREE-CYCLE](https://docs.dgl.ai/generated/dgl.data.TreeCycleDataset.html#dgl.data.TreeCycleDataset)
- [TREE-GRID](https://docs.dgl.ai/generated/dgl.data.TreeGridDataset.html#dgl.data.TreeGridDataset)

Usage
----------------------
**First**, train a demo GNN model by using a synthetic dataset.
``` python
python train_main.py --dataset syn1
```
Replace the argument of the --dataset, available options: syn1, syn2, syn3, syn4, syn5

This command trains a GNN model and save it to the "dummy_model_syn1.pth" file.
**First**, train a GNN model on a dataset.

**Second**, explain the trained model with the same data
``` python
python explain_main.py --dataset syn1 --target_class 1 --hop 2
```bash
python train_main.py --dataset $DATASET
```
Replace the dataset argument value and the target class you want to explain. The code will pick the first node in the specified class to explain. The --hop argument corresponds to the maximum hop number of the computation sub-graph. (For syn1 and syn2, hop=2. For syn3, syn4, and syn5, hop=4.)

Notice
----------------------
Because DGL does not support masked adjacency matrix as an input to the forward function of a module.
To use this Explainer, you need to add an edge_weight as the **edge mask** argument to your forward function just like
the dummy model in the models.py file. And you need to change your forward function whenever uses `.update_all` function.
Please use `dgl.function.u_mul_e` to compute the src nodes' features to the edge_weights as the mask method proposed by the
GNN Explainer paper. Check the models.py for details.
Valid options for `$DATASET`: `BAShape`, `BACommunity`, `TreeCycle`, `TreeGrid`

Results
----------------------
For all the datasets, the first node of target class 1 is picked to be explained. The hop-k computation sub-graph (a compact of 0-hop, 1-hop, ..., k-hop subgraphs) is first extracted and then fed to the models. Followings are the visualization results. Instead of cutting edges that are below the threshold. We use the depth of color of the edges to represent the edge mask weights. The deeper the color of an edge is, the more important the edge is.
The trained model weights will be saved to `model_{dataset}.pth`

NOTE: We do not perform grid search or finetune here, the visualization results are just for reference.
**Second**, install [GNNLens2](https://github.com/dmlc/GNNLens2) with

```bash
pip install -U flask-cors
pip install Flask==2.0.3
pip install gnnlens
```

**Syn1 (BA-SHAPES)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn1.png" width="600">
<br>
<b>Figure</b>: Visualization of syn1 dataset (hop=2).
</p>
**Third**, explain the trained model with the same dataset

**Syn2 (BA-COMMUNITY)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn2.png" width="600">
<br>
<b>Figure</b>: Visualization of syn2 dataset (hop=2).
</p>
```bash
python explain_main.py --dataset $DATASET
```

**Syn3 (BA-GRID)**
**Finally**, launch `GNNLens2` to visualize the explanations

For a more explict view, we conduct explaination on both the hop-3 computation sub-graph and the hop-4 computation sub-graph in Syn3 task.
```bash
gnnlens --logdir gnn_subgraph
```

<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_3hop.png" width="600">
<br>
<b>Figure</b>: Visualization of syn3 dataset with hop=3.
</p>
By entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface. `7777` is the default port GNNLens2 uses. You can specify an alternative one by adding `--port xxxx` after the command line and change the address in the web browser accordingly.

<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_4hop.png" width="600">
<br>
<b>Figure</b>: Visualization of syn3 dataset with hop=4.
</p>

**Syn4 (TREE-CYCLE)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn4.png" width="600">
<br>
<b>Figure</b>: Visualization of syn4 dataset (hop=4).
</p>
A sample visualization is available below. For more details of using `GNNLens2`, check its [tutorials](https://github.com/dmlc/GNNLens2#tutorials).

**Syn5 (TREE-GRID)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn5.png" width="600">
<img src="https://data.dgl.ai/asset/image/explain_BAShape.png" width="600">
<br>
<b>Figure</b>: Visualization of syn5 dataset (hop=4).
<b>Figure</b>: Explanation for node 41 of BAShape
</p>
Loading

0 comments on commit ec4271b

Please sign in to comment.