Gnnexplainer (dmlc#2717)
* gnn-explainer

* gnn-explainer

* gnn-explainer

* gnn-explainer

* fix

* fix

* fix

* readme

* readme

Co-authored-by: zhjwy9343 <[email protected]>
KounianhuaDu and zhjwy9343 authored Mar 22, 2021
1 parent 1789972 commit 9aac93f
Showing 10 changed files with 1,316 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/README.md
@@ -83,7 +83,7 @@ The folder contains example implementations of selected research papers related
| [Directional Message Passing for Molecular Graphs](#dimenet) | | | :heavy_check_mark: | | |
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |

## 2020

- <a name="grand"></a> Feng et al. Graph Random Neural Network for Semi-Supervised Learning on Graphs. [Paper link](https://arxiv.org/abs/2005.11079).
@@ -190,6 +190,9 @@ The folder contains example implementations of selected research papers related
- <a name='ngcf'></a> Wang, Xiang, et al. Neural Graph Collaborative Filtering. [Paper link](https://arxiv.org/abs/1905.08108).
    - Example code: [PyTorch](../examples/pytorch/NGCF)
    - Tags: Collaborative Filtering, Recommendation, Graph Neural Network
- <a name='gnnexplainer'></a> Ying, Rex, et al. GNNExplainer: Generating Explanations for Graph Neural Networks. [Paper link](https://arxiv.org/abs/1903.03894).
    - Example code: [PyTorch](../examples/pytorch/gnn_explainer)
    - Tags: Graph Neural Network, Explainability


## 2018
200 changes: 200 additions & 0 deletions examples/pytorch/gnn_explainer/NodeExplainerModule.py
@@ -0,0 +1,200 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import math


class NodeExplainerModule(nn.Module):
    """
    A PyTorch module that explains a node's prediction based on its computational graph and node features.
    It learns two masks: one over edges and another over node features.
    Due to DGL's current limitations on edge-mask operations, this explainer requires the to-be-explained
    model to accept an additional edge-mask argument and apply it in its internal message passing.
    This is the current workaround for using edge masks.
    """

    # Class inner variables
    loss_coef = {
        "g_size": 0.05,
        "feat_size": 1.0,
        "g_ent": 0.1,
        "feat_ent": 0.1
    }

    def __init__(self,
                 model,
                 num_edges,
                 node_feat_dim,
                 activation='sigmoid',
                 agg_fn='sum',
                 mask_bias=False):
        super(NodeExplainerModule, self).__init__()
        self.model = model
        self.model.eval()
        self.num_edges = num_edges
        self.node_feat_dim = node_feat_dim
        self.activation = activation
        self.agg_fn = agg_fn
        self.mask_bias = mask_bias

        # Initialize the mask parameters
        self.edge_mask, self.edge_mask_bias = self.create_edge_mask(self.num_edges)
        self.node_feat_mask = self.create_node_feat_mask(self.node_feat_dim)


    def create_edge_mask(self, num_edges, init_strategy='normal', const=1.):
        """
        Based on the number of edges in the computational graph, create a learnable mask over edges.
        To adapt to DGL, this mask is a vector over edges instead of an N*N adjacency matrix.
        Parameters
        ----------
        num_edges: Integer, the number of edges.
        init_strategy: String, specifies the parameter initialization method.
        const: Float, the value used for constant initialization.
        Returns
        -------
        mask and mask bias: Tensors, each of shape num_edges * 1
        """
        mask = nn.Parameter(th.Tensor(num_edges, 1))

        if init_strategy == 'normal':
            std = nn.init.calculate_gain("relu") * math.sqrt(
                1.0 / num_edges
            )
            with th.no_grad():
                mask.normal_(1.0, std)
        elif init_strategy == "const":
            nn.init.constant_(mask, const)

        if self.mask_bias:
            mask_bias = nn.Parameter(th.Tensor(num_edges, 1))
            nn.init.constant_(mask_bias, 0.0)
        else:
            mask_bias = None

        return mask, mask_bias


    def create_node_feat_mask(self, node_feat_dim, init_strategy="normal"):
        """
        Based on the dimension of node features in the computational graph, create a learnable mask over features.
        Parameters
        ----------
        node_feat_dim: Integer, the dimension of node features.
        init_strategy: String, specifies the parameter initialization method.
        Returns
        -------
        mask: Tensor, of shape node_feat_dim
        """
        mask = nn.Parameter(th.Tensor(node_feat_dim))

        if init_strategy == "normal":
            std = 0.1
            with th.no_grad():
                mask.normal_(1.0, std)
        elif init_strategy == "constant":
            with th.no_grad():
                nn.init.constant_(mask, 0.0)
        return mask


    def forward(self, graph, n_feats):
        """
        Calculate prediction results after masking the input of the given model.
        Parameters
        ----------
        graph: DGLGraph, should be a sub-graph of the target node to be explained.
        n_feats: Tensor, the features of the nodes in the sub-graph.
        Returns
        -------
        new_logits: Tensor, of shape N * num_classes
        """

        # Step 1: Mask node features with the inner feature mask
        new_n_feats = n_feats * self.node_feat_mask.sigmoid()
        edge_mask = self.edge_mask.sigmoid()

        # Step 2: Compute logits after masking node features and edges
        new_logits = self.model(graph, new_n_feats, edge_mask)

        return new_logits


    def _loss(self, pred_logits, pred_label):
        """
        Compute the losses of this explainer, which include 6 parts in the author's codes:
        1. The prediction loss between the logits before and after node and edge masking;
        2. Loss of the edge mask itself, which tries to push the mask values to either 0 or 1;
        3. Loss of the node feature mask itself, which tries to push the mask values to either 0 or 1;
        4. L2 loss of the edge mask weights, summed rather than averaged;
        5. L2 loss of the node feature mask weights, which is NOT used in the author's codes;
        6. Laplacian loss of the adjacency matrix.
        In the PyG implementation, there are 5 types of losses:
        1. The prediction loss between the logits before and after node and edge masking;
        2. Sum loss of the edge mask weights;
        3. Entropy loss of the edge mask, which tries to push the mask values to either 0 or 1;
        4. Sum loss of the node feature mask weights;
        5. Entropy loss of the node feature mask, which tries to push the mask values to either 0 or 1.
        Parameters
        ----------
        pred_logits: Tensor, N-dim logits output of the model
        pred_label: Tensor, N-dim one-hot encoding of the predicted label
        Returns
        -------
        loss: Scalar, the overall loss of this explainer.
        """
        # 1. prediction loss
        log_logit = - F.log_softmax(pred_logits, dim=-1)
        pred_loss = th.sum(log_logit * pred_label)

        # 2. edge mask loss
        if self.activation == 'sigmoid':
            edge_mask = th.sigmoid(self.edge_mask)
        elif self.activation == 'relu':
            edge_mask = F.relu(self.edge_mask)
        else:
            raise ValueError()
        edge_mask_loss = self.loss_coef['g_size'] * th.sum(edge_mask)

        # 3. edge mask entropy loss
        edge_ent = -edge_mask * \
                   th.log(edge_mask + 1e-8) - \
                   (1 - edge_mask) * \
                   th.log(1 - edge_mask + 1e-8)
        edge_ent_loss = self.loss_coef['g_ent'] * th.mean(edge_ent)

        # 4. node feature mask loss
        if self.activation == 'sigmoid':
            node_feat_mask = th.sigmoid(self.node_feat_mask)
        elif self.activation == 'relu':
            node_feat_mask = F.relu(self.node_feat_mask)
        else:
            raise ValueError()
        node_feat_mask_loss = self.loss_coef['feat_size'] * th.sum(node_feat_mask)

        # 5. node feature mask entropy loss
        node_feat_ent = -node_feat_mask * \
                        th.log(node_feat_mask + 1e-8) - \
                        (1 - node_feat_mask) * \
                        th.log(1 - node_feat_mask + 1e-8)
        node_feat_ent_loss = self.loss_coef['feat_ent'] * th.mean(node_feat_ent)

        total_loss = pred_loss + edge_mask_loss + edge_ent_loss + node_feat_mask_loss + node_feat_ent_loss

        return total_loss
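
For orientation, the following is a hypothetical sketch of how this module might be driven. `model` (a GNN whose forward accepts `(graph, node_feats, edge_mask)`), `sub_graph`, `sub_feats`, and `target_idx` are assumed inputs; the actual driver script in this commit (explain_main.py) may differ.

```python
import torch as th
import torch.nn.functional as F

# Build the explainer for the target node's computation sub-graph.
explainer = NodeExplainerModule(model=model,
                                num_edges=sub_graph.number_of_edges(),
                                node_feat_dim=sub_feats.shape[1])

# Optimize only the two masks; the GNN weights stay frozen.
optimizer = th.optim.Adam([explainer.edge_mask, explainer.node_feat_mask], lr=0.01)

# The label to explain is the model's own prediction on the unmasked sub-graph.
with th.no_grad():
    ones = th.ones(sub_graph.number_of_edges(), 1)
    logits = model(sub_graph, sub_feats, ones)
    pred_label = F.one_hot(logits[target_idx].argmax(), num_classes=logits.shape[-1]).float()

for epoch in range(500):
    optimizer.zero_grad()
    new_logits = explainer(sub_graph, sub_feats)              # masked forward pass
    loss = explainer._loss(new_logits[target_idx], pred_label)
    loss.backward()
    optimizer.step()

# explainer.edge_mask.sigmoid() now gives per-edge importance weights.
```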

103 changes: 103 additions & 0 deletions examples/pytorch/gnn_explainer/README.md
@@ -0,0 +1,103 @@
# DGL Implementation of the GNN Explainer

This DGL example implements the GNN Explainer model proposed in the paper [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894).
The authors' implementation is available [here](https://github.com/RexYing/gnn-model-explainer).

The authors' implementation is largely experimental. This implementation therefore focuses on a subset of
GNNExplainer's functionality, namely node classification, and may later be extended to edge classification.

Example implementor
----------------------
This example was implemented by [Jian Zhang](https://github.com/zhjwy9343) and [Kounianhua Du](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab.

Dependencies
----------------------
- numpy 1.19.4
- pytorch 1.7.1
- dgl 0.5.3
- networkx 2.5
- matplotlib 3.3.4

Datasets
----------------------
This example uses the five synthetic datasets from the paper. The generation code is adapted from the authors' implementation. (A toy sketch of the BA-SHAPES construction follows the list below.)
- Syn1 (BA-SHAPES): Start with a base Barabasi-Albert (BA) graph on 300 nodes and a set of 80 five-node “house”-structured network motifs, which are attached to randomly selected nodes of the base graph. The resulting graph is further perturbed by adding 0.01N random edges. Nodes are assigned to 4 classes based on their structural roles. In a house-structured motif, there are 3 types of roles: the top, middle, and bottom node of the house. Therefore there are 4 different classes, corresponding to nodes at the top, middle, bottom of houses, and nodes that do not belong to a house.
- Syn2 (BA-COMMUNITY): A union of two BA-SHAPES graphs. Nodes have normally distributed feature vectors and are assigned to one of 8 classes based on their structural roles and community memberships.
- Syn3 (BA-GRID): The same as BA-SHAPES except that 3-by-3 grid motifs are attached to the base graph in place of house motifs.
- Syn4 (TREE-CYCLE): Start with a base 8-level balanced binary tree and 60 six-node cycle motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.01N random edges.
- Syn5 (TREE-GRID): Start with a base 8-level balanced binary tree and 80 3-by-3 grid motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.1N random edges.
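
To make the above construction concrete, here is a toy sketch of the BA-SHAPES recipe; it is not the generation code shipped with this example, and the motif wiring and perturbation details are simplified assumptions.

```python
import random
import networkx as nx


def toy_ba_shapes(base_nodes=300, num_houses=80, m=5, seed=0):
    """Toy BA-SHAPES: a Barabasi-Albert base graph with house motifs attached to random base nodes."""
    random.seed(seed)
    g = nx.barabasi_albert_graph(base_nodes, m, seed=seed)
    labels = {n: 0 for n in g.nodes()}                 # class 0: node not in a house
    house_edges = [(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (3, 4)]
    house_roles = [3, 3, 2, 2, 1]                      # 1: top, 2: middle, 3: bottom
    for _ in range(num_houses):
        offset = g.number_of_nodes()
        g.add_edges_from([(offset + u, offset + v) for u, v in house_edges])
        for i, role in enumerate(house_roles):
            labels[offset + i] = role
        g.add_edge(offset, random.randrange(base_nodes))    # attach the house to the base graph
    for _ in range(int(0.01 * g.number_of_nodes())):         # perturb with ~0.01N random edges
        u, v = random.sample(range(g.number_of_nodes()), 2)
        g.add_edge(u, v)
    return g, labels
```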

Demo Usage
----------------------
**First**, train a demo GNN model on a synthetic dataset.
``` python
python train_main.py --dataset syn1
```
Replace the --dataset argument value as needed; available options: syn1, syn2, syn3, syn4, syn5.

This command trains a GNN model and saves it to the "dummy_model_syn1.pth" file.

**Second**, explain the trained model with the same dataset.
``` python
python explain_main.py --dataset syn1 --target_class 1 --hop 2
```
Replace the --dataset value and the target class you want to explain as needed. The code will pick the first node of the specified class to explain. The --hop argument sets the maximum number of hops of the computation sub-graph. (For syn1 and syn2, hop=2. For syn3, syn4, and syn5, hop=4.)

Notice
----------------------
DGL does not support passing a masked adjacency matrix as an input to a module's forward function.
To use this explainer, you therefore need to add an edge_weight argument, serving as the **edge mask**, to your model's
forward function, just like the dummy model in the models.py file. You also need to adjust the forward function wherever
it calls `.update_all`: use `dgl.function.u_mul_e` to multiply the source nodes' features by the edge weights, which
implements the masking method proposed in the GNNExplainer paper. Check models.py for details.
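
As a minimal sketch of such a layer (assuming a DGL version that provides `graph.local_scope()` and `dgl.function.u_mul_e`; `EdgeMaskableLayer` is an illustrative name, not the layer defined in models.py):

```python
import torch.nn as nn
import dgl.function as fn


class EdgeMaskableLayer(nn.Module):
    """A GCN-style layer whose message passing can be scaled by a per-edge weight (the edge mask)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, graph, feat, edge_weight=None):
        with graph.local_scope():
            graph.ndata['h'] = self.linear(feat)
            if edge_weight is None:
                # Plain message passing: copy source features onto the edges, then sum.
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_new'))
            else:
                # Masked message passing: scale each source feature by its edge weight before summing.
                graph.edata['w'] = edge_weight
                graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
            return graph.ndata['h_new']
```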

Results
----------------------
For all datasets, the first node of target class 1 is picked to be explained. The hop-k computation sub-graph (the union of the 0-hop, 1-hop, ..., k-hop subgraphs) is first extracted and then fed to the model. The visualization results are shown below. Instead of cutting off edges whose mask values fall below a threshold, we use the color depth of each edge to represent its edge mask weight: the deeper the color of an edge, the more important the edge.

NOTE: We do not perform grid search or fine-tuning here; the visualization results are for reference only.
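
As a rough illustration of the color-depth rendering (a hedged sketch using networkx and matplotlib, not necessarily the plotting code used for the figures below), edge mask weights in [0, 1] can be mapped to edge colors like this:

```python
import networkx as nx
import matplotlib.pyplot as plt


def draw_edge_weights(nx_subgraph, edge_weights, out_path):
    """Draw a subgraph with edge color depth proportional to the given per-edge weights."""
    pos = nx.spring_layout(nx_subgraph, seed=0)
    nx.draw_networkx_nodes(nx_subgraph, pos, node_size=80)
    # edge_weights must follow the same order as nx_subgraph.edges()
    nx.draw_networkx_edges(nx_subgraph, pos, edge_color=edge_weights,
                           edge_cmap=plt.cm.Blues, edge_vmin=0.0, edge_vmax=1.0, width=2)
    plt.axis('off')
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()
```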


**Syn1 (BA-SHAPES)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn1.png" width="600">
<br>
<b>Figure</b>: Visualization of syn1 dataset (hop=2).
</p>

**Syn2 (BA-COMMUNITY)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn2.png" width="600">
<br>
<b>Figure</b>: Visualization of syn2 dataset (hop=2).
</p>

**Syn3 (BA-GRID)**

For a more explicit view, we run the explanation on both the hop-3 and the hop-4 computation sub-graphs for the Syn3 task.

<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_3hop.png" width="600">
<br>
<b>Figure</b>: Visualization of syn3 dataset with hop=3.
</p>

<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_4hop.png" width="600">
<br>
<b>Figure</b>: Visualization of syn3 dataset with hop=4.
</p>

**Syn4 (TREE-CYCLE)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn4.png" width="600">
<br>
<b>Figure</b>: Visualization of syn4 dataset (hop=4).
</p>

**Syn5 (TREE-GRID)**
<p align="center">
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn5.png" width="600">
<br>
<b>Figure</b>: Visualization of syn5 dataset (hop=4).
</p>
