Skip to content

Commit

Permalink
sample large graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
k-amara committed May 30, 2022
1 parent 9a07d1c commit 8f59c08
Show file tree
Hide file tree
Showing 15 changed files with 1,936 additions and 1,342 deletions.
32 changes: 18 additions & 14 deletions code/explainer/node_explainer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import os
import os
import random

import networkx as nx
import numpy as np
import random
import torch
from captum.attr import IntegratedGradients, LayerGradCam, Saliency
from gnn.model import GraphConv, GraphConvolution
from torch.autograd import Variable
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx
from utils.gen_utils import get_subgraph, sample_large_graph

from explainer.gnnexplainer import GNNExplainer, TargetedGNNExplainer
from explainer.gnnlrp import GNN_LRP
Expand Down Expand Up @@ -220,17 +222,19 @@ def explain_subgraphx_node(model, data, node_idx, x, edge_index, edge_weight, ta
def explain_zorro_node(model, data, node_idx, x, edge_index, edge_weight, target, device, args, include_edges=None):
zorro = Zorro(model, device, num_hops = args.num_gc_layers)
print('explain node', zorro.explain_node(node_idx, x, edge_index))
selected_nodes, selected_features, executed_selection = zorro.explain_node(node_idx, x, edge_index)[0]
selected_nodes = torch.Tensor(selected_nodes.squeeze())
selected_features = torch.Tensor(selected_features.squeeze())
print("node_attrs", selected_nodes)
print("node_feature_mask", selected_features)
print("executed_selection", executed_selection)
node_attr = np.array(selected_nodes)
edge_mask = node_attr_to_edge(edge_index, node_attr)
edge_mask = edge_mask.cpu().detach().numpy()
node_feature_mask = node_feature_mask.cpu().detach().numpy()
return edge_mask, node_feature_mask
explanation = zorro.explain_node(node_idx, x, edge_index, tau=0.85, recursion_depth=3)
print('explanation', explanation)
#selected_nodes, selected_features, executed_selection = zorro.explain_node(node_idx, x, edge_index, tau=0.85, recursion_depth=4)
#selected_nodes = torch.Tensor(selected_nodes.squeeze())
#selected_features = torch.Tensor(selected_features.squeeze())
#print("node_attrs", selected_nodes)
#print("node_feature_mask", selected_features)
#print("executed_selection", executed_selection)
#node_attr = np.array(selected_nodes)
#edge_mask = node_attr_to_edge(edge_index, node_attr)
#edge_mask = edge_mask.cpu().detach().numpy()
#node_feature_mask = node_feature_mask.cpu().detach().numpy()
return #edge_mask, node_feature_mask



Expand All @@ -247,13 +251,13 @@ def explain_pgexplainer_node(model, data, node_idx, x, edge_index, edge_weight,
state_dict = torch.load(pgexplainer_saving_path)
pgexplainer.load_state_dict(state_dict)
else:
data = sample_large_graph(data)
pgexplainer.train_explanation_network(data)
print("Save PGExplainer model...")
torch.save(pgexplainer.state_dict(), pgexplainer_saving_path)
state_dict = torch.load(pgexplainer_saving_path)
pgexplainer.load_state_dict(state_dict)


edge_mask = pgexplainer.explain_node(model, node_idx, x, edge_index)
edge_mask = edge_mask.cpu().detach().numpy()
return edge_mask, None
Expand Down
4 changes: 2 additions & 2 deletions code/explainer/zorro.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class Zorro(torch.nn.Module):

def __init__(self, model, device, num_hops=None, log=True, greedy=True, record_process_time=False, add_noise=False, samples=10):
def __init__(self, model, device, num_hops=None, log=True, greedy=True, record_process_time=False, add_noise=False, samples=100):
super(Zorro, self).__init__()
self.model = model
self.log = log
Expand Down Expand Up @@ -427,7 +427,7 @@ def recursively_get_minimal_sets(self, initial_distortion, tau, possible_nodes,

return minimal_nodes_and_features_sets

def explain_node(self, node_idx, full_feature_matrix, edge_index, tau=0.15, recursion_depth=np.inf,
def explain_node(self, node_idx, full_feature_matrix, edge_index, tau=0.85, recursion_depth=np.inf,
save_initial_improve=False):
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for node
Expand Down
3 changes: 1 addition & 2 deletions code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from gnn.eval import gnn_scores_gc, gnn_scores_nc, gnn_accuracy
from gnn.model import GCN, GcnEncoderGraph, GcnEncoderNode
from gnn.train import train_graph_classification, train_node_classification, train_real
from utils.gen_utils import gen_dataloader, get_labels, get_test_graphs, get_test_nodes
from utils.gen_utils import gen_dataloader, get_labels, get_test_graphs, get_test_nodes, sample_large_graph
from utils.graph_utils import get_edge_index_batch, split_batch
from utils.io_utils import check_dir, create_data_filename, create_mask_filename, create_model_filename, load_ckpt, save_checkpoint
from utils.parser_utils import arg_parse, get_data_args, get_graph_size_args
Expand Down Expand Up @@ -116,7 +116,6 @@ def main_real(args):
### Explainer ###
list_test_nodes = get_test_nodes(data, model, args)
mask_filename = create_mask_filename(args)

edge_masks, node_feat_masks, Time = compute_edge_masks_nc(list_test_nodes, model, data, device, args)
"""
if args.dataset.startswith("ebay"):
Expand Down
10 changes: 10 additions & 0 deletions code/utils/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ def list_to_dict(preds):
preds_dict[key] = np.array(preds_dict[key])
return preds_dict

def sample_large_graph(data):
if data.num_edges > 50000:
print("Too many edges, sampling large graph...")
node_idx = random.randint(0, data.num_nodes - 1)
x, edge_index, mapping, edge_mask, subset, kwargs = get_subgraph(node_idx, data.x, data.edge_index, num_hops=3)
data = data.subgraph(subset)
print(f'Sample size: {data.num_nodes} nodes and {data.num_edges} edges')
return data



def get_subgraph(node_idx, x, edge_index, num_hops, **kwargs):
num_nodes, num_edges = x.size(0), edge_index.size(1)
Expand Down
15 changes: 1 addition & 14 deletions configs/real/config_real_5expe_testpred.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,6 @@
"dataset": ["cora", "facebook", "cornell", "actor", "chameleon"],
"strategy": "topk",
"params_list": "1,2,5,10,15,20,100",
"explainer_name": [
"random",
"distance",
"pagerank",
"sa",
"ig",
"gradcam",
"occlusion",
"basic_gnnexplainer",
"gnnexplainer",
"pgmexplainer",
"pgexplainer",
"subgraphx"
]
"explainer_name": ["random", "pgexplainer"]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
},
"meta": {
"group": "checkpoints/node_classification/real/topk",
"name": "test",
"name": "topk_real_5",
"dest-arg": "yes",
"dest-name": "dest"
},
"params": {
"data_save_dir": "data",
"seed": 0,
"num_test": 1,
"seed": [0, 1, 2, 3, 4],
"num_test": 100,
"explain_graph": "False",
"hard_mask": "False",
"true_label_as_target": "True",
"hard_mask": ["True", "False"],
"true_label_as_target": ["True", "False"],
"strategy": "topk",
"directed": ["True", "False"],
"params_list": "1,5,10,15,20,25,50,100",
"dataset": [
"cora",
Expand All @@ -36,19 +35,6 @@
"chameleon",
"squirrel"
],
"explainer_name": [
"random",
"distance",
"pagerank",
"sa",
"ig",
"gradcam",
"occlusion",
"basic_gnnexplainer",
"gnnexplainer",
"pgmexplainer",
"pgexplainer",
"subgraphx"
]
"explainer_name": ["random", "pgexplainer"]
}
}
28 changes: 0 additions & 28 deletions configs/real/topk/config_syn_topk_5expe_pg.json

This file was deleted.

13 changes: 1 addition & 12 deletions configs/real/topk/test.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,7 @@
"explain_graph": "False",
"hard_mask": "True",
"true_label_as_target": "True",
"dataset": [
"cora",
"pubmed",
"citeseer",
"facebook",
"cornell",
"texas",
"wisconsin",
"actor",
"chameleon",
"squirrel"
],
"dataset": ["facebook", "squirrel"],
"params_list": "1,5,10,15,20,25,50,100",
"explainer_name": "pgexplainer"
}
Expand Down
14 changes: 13 additions & 1 deletion configs/syn/topk/config_top_edges.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@
"seed": [0, 1, 2, 3, 4],
"top_acc": "True",
"dataset": ["syn1", "syn3", "syn4", "syn5", "syn6"],
"explainer_name": ["random", "pgexplainer"]
"explainer_name": [
"random",
"distance",
"pagerank",
"sa",
"ig",
"occlusion",
"basic_gnnexplainer",
"gnnexplainer",
"pgmexplainer",
"pgexplainer",
"subgraphx"
]
}
}
Loading

0 comments on commit 8f59c08

Please sign in to comment.