-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* gnn-explainer * gnn-explainer * gnn-explainer * gnn-explainer * fix * fix * fix * readme * readme Co-authored-by: zhjwy9343 <[email protected]>
- Loading branch information
1 parent
1789972
commit 9aac93f
Showing
10 changed files
with
1,316 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import torch as th | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import math | ||
|
||
|
||
class NodeExplainerModule(nn.Module): | ||
""" | ||
A Pytorch module for explaining a node's prediction based on its computational graph and node features. | ||
Use two masks: One mask on edges, and another on nodes' features. | ||
So far due to the limit of DGL on edge mask operation, this explainer need the to-be-explained models to | ||
accept an additional input argument, edge mask, and apply this mask in their inner message parse operation. | ||
This is current walk_around to use edge masks. | ||
""" | ||
|
||
# Class inner variables | ||
loss_coef = { | ||
"g_size": 0.05, | ||
"feat_size": 1.0, | ||
"g_ent": 0.1, | ||
"feat_ent": 0.1 | ||
} | ||
|
||
def __init__(self, | ||
model, | ||
num_edges, | ||
node_feat_dim, | ||
activation='sigmoid', | ||
agg_fn='sum', | ||
mask_bias=False): | ||
super(NodeExplainerModule, self).__init__() | ||
self.model = model | ||
self.model.eval() | ||
self.num_edges = num_edges | ||
self.node_feat_dim = node_feat_dim | ||
self.activation = activation | ||
self.agg_fn=agg_fn | ||
self.mask_bias = mask_bias | ||
|
||
# Initialize parameters on masks | ||
self.edge_mask, self.edge_mask_bias = self.create_edge_mask(self.num_edges) | ||
self.node_feat_mask = self.create_node_feat_mask(self.node_feat_dim) | ||
|
||
|
||
def create_edge_mask(self, num_edges, init_strategy='normal', const=1.): | ||
""" | ||
Based on the number of nodes in the computational graph, create a learnable mask of edges. | ||
To adopt to DGL, change this mask from N*N adjacency matrix to the No. of edges | ||
Parameters | ||
---------- | ||
num_edges: Integer N, specify the number of edges. | ||
init_strategy: String, specify the parameter initialization method | ||
const: Float, a value for constant initialization | ||
Returns | ||
------- | ||
mask and mask bias: Tensor, all in shape of N*1 | ||
""" | ||
mask = nn.Parameter(th.Tensor(num_edges, 1)) | ||
|
||
if init_strategy == 'normal': | ||
std = nn.init.calculate_gain("relu") * math.sqrt( | ||
1.0 / num_edges | ||
) | ||
with th.no_grad(): | ||
mask.normal_(1.0, std) | ||
elif init_strategy == "const": | ||
nn.init.constant_(mask, const) | ||
|
||
if self.mask_bias: | ||
mask_bias = nn.Parameter(th.Tensor(num_edges, 1)) | ||
nn.init.constant_(mask_bias, 0.0) | ||
else: | ||
mask_bias = None | ||
|
||
return mask, mask_bias | ||
|
||
|
||
def create_node_feat_mask(self, node_feat_dim, init_strategy="normal"): | ||
""" | ||
Based on the dimensions of node feature in the computational graph, create a learnable mask of features. | ||
Parameters | ||
---------- | ||
node_feat_dim: Integer N, dimensions of node feature | ||
init_strategy: String, specify the parameter initialization method | ||
Returns | ||
------- | ||
mask: Tensor, in shape of N | ||
""" | ||
mask = nn.Parameter(th.Tensor(node_feat_dim)) | ||
|
||
if init_strategy == "normal": | ||
std = 0.1 | ||
with th.no_grad(): | ||
mask.normal_(1.0, std) | ||
elif init_strategy == "constant": | ||
with th.no_grad(): | ||
nn.init.constant_(mask, 0.0) | ||
return mask | ||
|
||
|
||
def forward(self, graph, n_feats): | ||
""" | ||
Calculate prediction results after masking input of the given model. | ||
Parameters | ||
---------- | ||
graph: DGLGraph, Should be a sub_graph of the target node to be explained. | ||
n_idx: Tensor, an integer, index of the node to be explained. | ||
Returns | ||
------- | ||
new_logits: Tensor, in shape of N * Num_Classes | ||
""" | ||
|
||
# Step 1: Mask node feature with the inner feature mask | ||
new_n_feats = n_feats * self.node_feat_mask.sigmoid() | ||
edge_mask = self.edge_mask.sigmoid() | ||
|
||
# Step 2: Add compute logits after mask node features and edges | ||
new_logits = self.model(graph, new_n_feats, edge_mask) | ||
|
||
return new_logits | ||
|
||
|
||
def _loss(self, pred_logits, pred_label): | ||
""" | ||
Compute the losses of this explainer, which include 6 parts in author's codes: | ||
1. The prediction loss between predict logits before and after node and edge masking; | ||
2. Loss of edge mask itself, which tries to put the mask value to either 0 or 1; | ||
3. Loss of node feature mask itself, which tries to put the mask value to either 0 or 1; | ||
4. L2 loss of edge mask weights, but in sum not in mean; | ||
5. L2 loss of node feature mask weights, which is NOT used in the author's codes; | ||
6. Laplacian loss of the adj matrix. | ||
In the PyG implementation, there are 5 types of losses: | ||
1. The prediction loss between logits before and after node and edge masking; | ||
2. Sum loss of edge mask weights; | ||
3. Loss of edge mask entropy, which tries to put the mask value to either 0 or 1; | ||
4. Sum loss of node feature mask weights; | ||
5. Loss of node feature mask entropy, which tries to put the mask value to either 0 or 1; | ||
Parameters | ||
---------- | ||
pred_logits:Tensor, N-dim logits output of model | ||
pred_label: Tensor, N-dim one-hot label of the label | ||
Returns | ||
------- | ||
loss: Scalar, the overall loss of this explainer. | ||
""" | ||
# 1. prediction loss | ||
log_logit = - F.log_softmax(pred_logits, dim=-1) | ||
pred_loss = th.sum(log_logit * pred_label) | ||
|
||
# 2. edge mask loss | ||
if self.activation == 'sigmoid': | ||
edge_mask = th.sigmoid(self.edge_mask) | ||
elif self.activation == 'relu': | ||
edge_mask = F.relu(self.edge_mask) | ||
else: | ||
raise ValueError() | ||
edge_mask_loss = self.loss_coef['g_size'] * th.sum(edge_mask) | ||
|
||
# 3. edge mask entropy loss | ||
edge_ent = -edge_mask * \ | ||
th.log(edge_mask + 1e-8) - \ | ||
(1 - edge_mask) * \ | ||
th.log(1 - edge_mask + 1e-8) | ||
edge_ent_loss = self.loss_coef['g_ent'] * th.mean(edge_ent) | ||
|
||
# 4. node feature mask loss | ||
if self.activation == 'sigmoid': | ||
node_feat_mask = th.sigmoid(self.node_feat_mask) | ||
elif self.activation == 'relu': | ||
node_feat_mask = F.relu(self.node_feat_mask) | ||
else: | ||
raise ValueError() | ||
node_feat_mask_loss = self.loss_coef['feat_size'] * th.sum(node_feat_mask) | ||
|
||
# 5. node feature mask entry loss | ||
node_feat_ent = -node_feat_mask * \ | ||
th.log(node_feat_mask + 1e-8) - \ | ||
(1 - node_feat_mask) * \ | ||
th.log( 1 - node_feat_mask + 1e-8) | ||
node_feat_ent_loss = self.loss_coef['feat_ent'] * th.mean(node_feat_ent) | ||
|
||
total_loss = pred_loss + edge_mask_loss + edge_ent_loss + node_feat_mask_loss + node_feat_ent_loss | ||
|
||
return total_loss | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# DGL Implementation of the GNN Explainer | ||
|
||
This DGL example implements the GNN Explainer model proposed in the paper [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894). | ||
The author's codes of implementation is in [here](https://github.com/RexYing/gnn-model-explainer). | ||
|
||
The author's implementation is kind of experimental with experimental codes. So this implementation focuses on a subset of | ||
GNN Explainer's functions, node classification, and later on extend to edge classification. | ||
|
||
Example implementor | ||
---------------------- | ||
This example was implemented by [Jian Zhang](https://github.com/zhjwy9343) and [Kounianhua Du](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab. | ||
|
||
Dependencies | ||
---------------------- | ||
- numpy 1.19.4 | ||
- pytorch 1.7.1 | ||
- dgl 0.5.3 | ||
- networkx 2.5 | ||
- matplotlib 3.3.4 | ||
|
||
Datasets | ||
---------------------- | ||
Five synthetic datasets used in the paper are used in this example. The generation codes are referenced from the author implementation. | ||
- Syn1 (BA-SHAPES): Start with a base Barabasi-Albert (BA) graph on 300 nodes and a set of 80 five-node “house”-structured network motifs, which are attached to randomly selected nodes of the base graph. The resulting graph is further perturbed by adding 0.01N random edges. Nodes are assigned to 4 classes based on their structural roles. In a house-structured motif, there are 3 types of roles: the top, middle, and bottom node of the house. Therefore there are 4 different classes, corresponding to nodes at the top, middle, bottom of houses, and nodes that do not belong to a house. | ||
- Syn2 (BA-COMMUNITY): A union of two BA-SHAPES graphs. Nodes have normally distributed feature vectors and are assigned to one of 8 classes based on their structural roles and community memberships. | ||
- Syn3 (BA-GRID): The same as BA-SHAPES except that 3-by-3 grid motifs are attached to the base graph in place of house motifs. | ||
- Syn4 (TREE-CYCLE): Start with a base 8-level balanced binary tree and 60 six-node cycle motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.01N random edges. | ||
- Syn5 (TREE-GRID): Start with a base 8-level balanced binary tree and 80 3-by-3 grid motifs, which are attached to random nodes of the base graph. Perturbed by adding 0.1N random edges. | ||
|
||
Demo Usage | ||
---------------------- | ||
**First**, train a demo GNN model by using a synthetic dataset. | ||
``` python | ||
python train_main.py --dataset syn1 | ||
``` | ||
Replace the argument of the --dataset, available options: syn1, syn2, syn3, syn4, syn5 | ||
|
||
This command trains a GNN model and save it to the "dummy_model_syn1.pth" file. | ||
|
||
**Second**, explain the trained model with the same data | ||
``` python | ||
python explain_main.py --dataset syn1 --target_class 1 --hop 2 | ||
``` | ||
Replace the dataset argument value and the target class you want to explain. The code will pick the first node in the specified class to explain. The --hop argument corresponds to the maximum hop number of the computation sub-graph. (For syn1 and syn2, hop=2. For syn3, syn4, and syn5, hop=4.) | ||
|
||
Notice | ||
---------------------- | ||
Because DGL does not support masked adjacency matrix as an input to the forward function of a module. | ||
To use this Explainer, you need to add an edge_weight as the **edge mask** argument to your forward function just like | ||
the dummy model in the models.py file. And you need to change your forward function whenever uses `.update_all` function. | ||
Please use `dgl.function.u_mul_e` to compute the src nodes' features to the edge_weights as the mask method proposed by the | ||
GNN Explainer paper. Check the models.py for details. | ||
|
||
Results | ||
---------------------- | ||
For all the datasets, the first node of target class 1 is picked to be explained. The hop-k computation sub-graph (a compact of 0-hop, 1-hop, ..., k-hop subgraphs) is first extracted and then fed to the models. Followings are the visualization results. Instead of cutting edges that are below the threshold. We use the depth of color of the edges to represent the edge mask weights. The deeper the color of an edge is, the more important the edge is. | ||
|
||
NOTE: We do not perform grid search or finetune here, the visualization results are just for reference. | ||
|
||
|
||
**Syn1 (BA-SHAPES)** | ||
<p align="center"> | ||
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn1.png" width="600"> | ||
<br> | ||
<b>Figure</b>: Visualization of syn1 dataset (hop=2). | ||
</p> | ||
|
||
**Syn2 (BA-COMMUNITY)** | ||
<p align="center"> | ||
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn2.png" width="600"> | ||
<br> | ||
<b>Figure</b>: Visualization of syn2 dataset (hop=2). | ||
</p> | ||
|
||
**Syn3 (BA-GRID)** | ||
|
||
For a more explict view, we conduct explaination on both the hop-3 computation sub-graph and the hop-4 computation sub-graph in Syn3 task. | ||
|
||
<p align="center"> | ||
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_3hop.png" width="600"> | ||
<br> | ||
<b>Figure</b>: Visualization of syn3 dataset with hop=3. | ||
</p> | ||
|
||
<p align="center"> | ||
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn3_4hop.png" width="600"> | ||
<br> | ||
<b>Figure</b>: Visualization of syn3 dataset with hop=4. | ||
</p> | ||
|
||
**Syn4 (TREE-CYCLE)** | ||
<p align="center"> | ||
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn4.png" width="600"> | ||
<br> | ||
<b>Figure</b>: Visualization of syn4 dataset (hop=4). | ||
</p> | ||
|
||
**Syn5 (TREE-GRID)** | ||
<p align="center"> | ||
<img src="https://github.com/KounianhuaDu/gnn-explainer-dgl-pics/blob/master/imgs/syn5.png" width="600"> | ||
<br> | ||
<b>Figure</b>: Visualization of syn5 dataset (hop=4). | ||
</p> |
Oops, something went wrong.