[Example] Add implementation of mvgrl (dmlc#2739)

* [Example] add mvgrl
* [Doc] add mvgrl to readme
* add more comments
* fix typos
* replace tab with space
* [doc] replace tab with space
* [Doc] fix a typo
* fix minor typos
* fix typos
* fix typos
* fix typos
* fix typos
* fix

Co-authored-by: Mufei Li <[email protected]>
1 parent e6f6c2e · commit bcffdb8
Showing 11 changed files with 1,153 additions and 3 deletions.
@@ -0,0 +1,147 @@
# DGL Implementation of MVGRL

This DGL example implements the model proposed in the paper [Contrastive Multi-View Representation Learning on Graphs](https://arxiv.org/abs/2006.05582).

Author's code: https://github.com/kavehhassani/mvgrl

## Example Implementor

This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.
## Dependencies

- Python 3.7
- PyTorch 1.7.1
- dgl 0.6.0
- networkx
- scipy
## Datasets

##### Unsupervised Graph Classification Datasets:

'MUTAG', 'PTC_MR', 'REDDIT-BINARY', 'IMDB-BINARY', 'IMDB-MULTI'.

| Dataset         | MUTAG | PTC_MR | RDT-B  | IMDB-B | IMDB-M |
| --------------- | ----- | ------ | ------ | ------ | ------ |
| # Graphs        | 188   | 344    | 2000   | 1000   | 1500   |
| # Classes       | 2     | 2      | 2      | 2      | 3      |
| Avg. Graph Size | 17.93 | 14.29  | 429.63 | 19.77  | 13.00  |

* RDT-B, IMDB-B and IMDB-M are short for REDDIT-BINARY, IMDB-BINARY and IMDB-MULTI, respectively.
##### Unsupervised Node Classification Datasets:

'Cora', 'Citeseer' and 'Pubmed'.

| Dataset  | # Nodes | # Edges | # Classes |
| -------- | ------- | ------- | --------- |
| Cora     | 2,708   | 10,556  | 7         |
| Citeseer | 3,327   | 9,228   | 6         |
| Pubmed   | 19,717  | 88,651  | 3         |
## Arguments

##### Graph Classification:

```
--dataname    str    The graph dataset name.      Default is 'MUTAG'.
--gpu         int    GPU index; -1 means CPU.     Default is -1.
--epochs      int    Number of training epochs.   Default is 200.
--patience    int    Early-stopping patience.     Default is 20.
--lr          float  Learning rate.               Default is 0.001.
--wd          float  Weight decay.                Default is 0.0.
--batch_size  int    Size of a training batch.    Default is 64.
--n_layers    int    Number of GNN layers.        Default is 4.
--hid_dim     int    Embedding dimension.         Default is 32.
```
##### Node Classification:

```
--dataname    str    The graph dataset name.                  Default is 'cora'.
--gpu         int    GPU index; -1 means CPU.                 Default is -1.
--epochs      int    Number of training epochs.               Default is 500.
--patience    int    Early-stopping patience.                 Default is 20.
--lr1         float  Learning rate of the main model.         Default is 0.001.
--lr2         float  Learning rate of the linear classifier.  Default is 0.01.
--wd1         float  Weight decay of the main model.          Default is 0.0.
--wd2         float  Weight decay of the linear classifier.   Default is 0.0.
--epsilon     float  Edge-mask threshold.                     Default is 0.01.
--hid_dim     int    Embedding dimension.                     Default is 512.
```
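For intuition, here is a minimal sketch of what an edge-mask threshold such as `--epsilon` typically does (the helper `mask_diffusion` below is ours for illustration, not part of this example's API): entries of the dense diffusion matrix smaller than the threshold are zeroed out, so the resulting diffusion graph stays sparse.

```python
import numpy as np

def mask_diffusion(ppr: np.ndarray, epsilon: float = 0.01) -> np.ndarray:
    """Zero out diffusion entries below epsilon to keep the diffusion graph sparse."""
    masked = ppr.copy()
    masked[masked < epsilon] = 0.0
    return masked
```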
## How to run examples

###### Graph Classification

```bash
# Enter the 'graph' directory
cd graph

# MUTAG:
python main.py --dataname MUTAG --epochs 20

# PTC_MR:
python main.py --dataname PTC_MR --epochs 32 --hid_dim 128

# REDDIT-BINARY:
python main.py --dataname REDDIT-BINARY --epochs 20 --hid_dim 128

# IMDB-BINARY:
python main.py --dataname IMDB-BINARY --epochs 20 --hid_dim 512 --n_layers 2

# IMDB-MULTI:
python main.py --dataname IMDB-MULTI --epochs 20 --hid_dim 512 --n_layers 2
```
###### Node Classification

For semi-supervised node classification on 'Cora', 'Citeseer' and 'Pubmed', we provide two implementations:

1. Full-graph training, see 'main.py', where we contrast the local and global representations of the whole graph.
2. Subgraph training, see 'main_sample.py', where we contrast the local and global representations of a sampled subgraph with a fixed number of nodes (a sampling sketch follows this list).
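A minimal sketch of the sampling idea (a hypothetical helper, not the actual code in 'main_sample.py'):

```python
import torch as th
import dgl

def sample_subgraphs(graph, diff_graph, sample_size=2000):
    # Use the same random node set in both views so node pairs stay aligned.
    idx = th.randperm(graph.num_nodes())[:sample_size]
    return dgl.node_subgraph(graph, idx), dgl.node_subgraph(diff_graph, idx)
```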
For larger graphs (e.g., Pubmed), computing the exact graph diffusion matrix (i.e., the PPR matrix) is expensive, so we approximate it with [APPNP](https://arxiv.org/abs/1810.05997); see the function 'process_dataset_appnp' in 'node/dataset.py' for details, and the sketch below for the idea.
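The idea behind the approximation, as a minimal NumPy sketch (ours, for illustration; the actual implementation is in 'process_dataset_appnp'): the PPR matrix is the fixed point of the APPNP propagation rule, so iterating that rule a few times approximates it without a dense matrix inverse.

```python
import numpy as np

def appnp_approx(adj: np.ndarray, alpha: float = 0.2, k: int = 10) -> np.ndarray:
    """Approximate the PPR diffusion matrix with k APPNP propagation steps."""
    a_hat = adj + np.eye(adj.shape[0])                  # add self-loops
    d_inv_sqrt = np.diag(np.power(a_hat.sum(1), -0.5))  # D^(-1/2)
    a_norm = d_inv_sqrt @ a_hat @ d_inv_sqrt            # symmetrically normalized adjacency
    h = np.eye(adj.shape[0])
    for _ in range(k):
        # The fixed point of this update is alpha * inv(I - (1 - alpha) * a_norm),
        # i.e. exactly the closed-form PPR matrix used for small graphs.
        h = (1 - alpha) * a_norm @ h + alpha * np.eye(adj.shape[0])
    return h
```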
```bash
# Enter the 'node' directory
cd node

# Cora with full graph
python main.py --dataname cora --gpu 0

# Cora with sampled subgraphs
python main_sample.py --dataname cora --gpu 0

# Citeseer with full graph
python main.py --dataname citeseer --wd1 0.001 --wd2 0.01 --epochs 200 --gpu 0

# Citeseer with sampled subgraphs
python main_sample.py --dataname citeseer --wd2 0.01 --gpu 0

# Pubmed with sampled subgraphs
python main_sample.py --dataname pubmed --sample_size 4000 --epochs 400 --patience 999 --gpu 0
```
## Performance

We use the same hyperparameter settings as stated in the original paper.

##### Graph classification:

| Dataset           | MUTAG | PTC-MR | REDDIT-B | IMDB-B | IMDB-M |
| :---------------: | :---: | :----: | :------: | :----: | :----: |
| Accuracy Reported | 89.7  | 62.5   | 84.5     | 74.2   | 51.2   |
| DGL               | 89.4  | 62.2   | 85.0     | 73.8   | 51.1   |

* The datasets used by the authors differ slightly from the standard TUDataset (see dgl.data.GINDataset) in their node features (e.g., the node features of the 'MUTAG' dataset have dimensionality 11 rather than 7).
##### Node classification:

| Dataset           | Cora | Citeseer | Pubmed |
| :---------------: | :--: | :------: | :----: |
| Accuracy Reported | 86.8 | 73.3     | 80.1   |
| DGL-sample        | 83.2 | 72.6     | 79.8   |
| DGL-full          | 83.5 | 73.7     | OOM    |

* We failed to reproduce the reported accuracy on 'Cora', even with the authors' code.
* The accuracy reported in the original paper is based on fixed-size subgraph training.
@@ -0,0 +1,191 @@
''' Code adapted from https://github.com/kavehhassani/mvgrl '''
import os
import re

import numpy as np
import networkx as nx
import torch as th
from scipy.linalg import fractional_matrix_power, inv

import dgl
from dgl.data import DGLDataset


def compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):
    '''Compute the Personalized PageRank (PPR) diffusion matrix.'''
    a = nx.convert_matrix.to_numpy_array(graph)
    if self_loop:
        a = a + np.eye(a.shape[0])                             # A^ = A + I_n
    d = np.diag(np.sum(a, 1))                                  # D^_ii = sum_j A^_ij
    dinv = fractional_matrix_power(d, -0.5)                    # D^(-1/2)
    at = np.matmul(np.matmul(dinv, a), dinv)                   # A~ = D^(-1/2) x A^ x D^(-1/2)
    return alpha * inv(np.eye(a.shape[0]) - (1 - alpha) * at)  # a(I_n - (1-a)A~)^-1
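# Note: the matrix inverse above is dense, so this closed-form PPR costs
# O(n^3) time and O(n^2) memory; it is only practical for small graphs.
# For larger graphs, the node-classification scripts approximate the
# diffusion with APPNP instead (see the README).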


def download(dataset, datadir):
    '''Download a TU graph-kernel dataset and unpack it into datadir.

    Relies on the external `wget` and `unzip` commands being available.
    '''
    os.makedirs(datadir)
    url = 'https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{0}.zip'.format(dataset)
    zipfile = os.path.basename(url)
    os.system('wget {0}; unzip {1}'.format(url, zipfile))
    os.system('mv {0}/* {1}'.format(dataset, datadir))
    os.system('rm -r {0}'.format(dataset))
    os.system('rm {0}'.format(zipfile))


def process(dataset):
    '''Parse the raw TU files into networkx graphs and their PPR matrices.'''
    src = os.path.join(os.path.dirname(__file__), 'data')
    prefix = os.path.join(src, dataset, dataset)

    # assign each node to the corresponding graph
    graph_node_dict = {}
    with open('{0}_graph_indicator.txt'.format(prefix), 'r') as f:
        for idx, line in enumerate(f):
            graph_node_dict[idx + 1] = int(line.strip('\n'))

    node_labels = []
    if os.path.exists('{0}_node_labels.txt'.format(prefix)):
        with open('{0}_node_labels.txt'.format(prefix), 'r') as f:
            for line in f:
                node_labels += [int(line.strip('\n')) - 1]
        num_unique_node_labels = max(node_labels) + 1
    else:
        print('No node labels')

    node_attrs = []
    if os.path.exists('{0}_node_attributes.txt'.format(prefix)):
        with open('{0}_node_attributes.txt'.format(prefix), 'r') as f:
            for line in f:
                node_attrs.append(
                    np.array([float(attr) for attr in re.split(r"[,\s]+", line.strip()) if attr], dtype=float)
                )
    else:
        print('No node attributes')

    # map raw graph labels to contiguous indices
    graph_labels = []
    unique_labels = set()
    with open('{0}_graph_labels.txt'.format(prefix), 'r') as f:
        for line in f:
            val = int(line.strip('\n'))
            if val not in unique_labels:
                unique_labels.add(val)
            graph_labels.append(val)
    label_idx_dict = {val: idx for idx, val in enumerate(unique_labels)}
    graph_labels = np.array([label_idx_dict[l] for l in graph_labels])

    # collect the edge list of each graph
    adj_list = {idx: [] for idx in range(1, len(graph_labels) + 1)}
    index_graph = {idx: [] for idx in range(1, len(graph_labels) + 1)}
    with open('{0}_A.txt'.format(prefix), 'r') as f:
        for line in f:
            u, v = tuple(map(int, line.strip('\n').split(',')))
            adj_list[graph_node_dict[u]].append((u, v))
            index_graph[graph_node_dict[u]] += [u, v]

    for k in index_graph.keys():
        index_graph[k] = [u - 1 for u in set(index_graph[k])]

    graphs, pprs = [], []
    for idx in range(1, 1 + len(adj_list)):
        graph = nx.from_edgelist(adj_list[idx])

        graph.graph['label'] = graph_labels[idx - 1]
        for u in graph.nodes():
            if len(node_labels) > 0:
                node_label_one_hot = [0] * num_unique_node_labels
                node_label = node_labels[u - 1]
                node_label_one_hot[node_label] = 1
                graph.nodes[u]['label'] = node_label_one_hot
            if len(node_attrs) > 0:
                graph.nodes[u]['feat'] = node_attrs[u - 1]
        if len(node_attrs) > 0:
            graph.graph['feat_dim'] = node_attrs[0].shape[0]

        # relabel nodes with contiguous ids starting from 0
        mapping = {}
        for node_idx, node in enumerate(graph.nodes()):
            mapping[node] = node_idx

        graphs.append(nx.relabel_nodes(graph, mapping))
        pprs.append(compute_ppr(graph, alpha=0.2))

    # if the dataset has no node attributes, use one-hot degrees
    # (concatenated with the one-hot node label, when available) as features
    if 'feat_dim' not in graphs[0].graph:
        max_deg = max([max(dict(graph.degree).values()) for graph in graphs])
        for graph in graphs:
            for u in graph.nodes(data=True):
                f = np.zeros(max_deg + 1)
                f[graph.degree[u[0]]] = 1.0
                if 'label' in u[1]:
                    f = np.concatenate((np.array(u[1]['label'], dtype=float), f))
                graph.nodes[u[0]]['feat'] = f
    return graphs, pprs


def load(dataset):
    '''Load a dataset, computing and caching the two graph views on first use.'''
    basedir = os.path.dirname(os.path.abspath(__file__))
    datadir = os.path.join(basedir, 'data', dataset)

    if not os.path.exists(datadir):
        download(dataset, datadir)
        graphs, diff = process(dataset)
        feat, adj, labels = [], [], []

        for idx, graph in enumerate(graphs):
            adj.append(nx.to_numpy_array(graph))
            labels.append(graph.graph['label'])
            feat.append(np.array(list(nx.get_node_attributes(graph, 'feat').values())))

        adj, diff, feat, labels = np.array(adj), np.array(diff), np.array(feat), np.array(labels)

        np.save(f'{datadir}/adj.npy', adj)
        np.save(f'{datadir}/diff.npy', diff)
        np.save(f'{datadir}/feat.npy', feat)
        np.save(f'{datadir}/labels.npy', labels)
    else:
        adj = np.load(f'{datadir}/adj.npy', allow_pickle=True)
        diff = np.load(f'{datadir}/diff.npy', allow_pickle=True)
        feat = np.load(f'{datadir}/feat.npy', allow_pickle=True)
        labels = np.load(f'{datadir}/labels.npy', allow_pickle=True)

    n_graphs = adj.shape[0]

    graphs = []
    diff_graphs = []
    lbls = []

    for i in range(n_graphs):
        # first view: the original topology, with self-loops added
        a = adj[i]
        edge_indexes = a.nonzero()

        graph = dgl.graph(edge_indexes)
        graph = graph.add_self_loop()
        graph.ndata['feat'] = th.tensor(feat[i]).float()

        # second view: the PPR diffusion graph, with PPR values as edge weights
        diff_adj = diff[i]
        diff_indexes = diff_adj.nonzero()
        diff_weight = th.tensor(diff_adj[diff_indexes]).float()

        diff_graph = dgl.graph(diff_indexes)
        diff_graph.edata['edge_weight'] = diff_weight
        label = labels[i]
        graphs.append(graph)
        diff_graphs.append(diff_graph)
        lbls.append(label)

    labels = th.tensor(lbls)

    dataset = TUDataset(graphs, diff_graphs, labels)
    return dataset


class TUDataset(DGLDataset):
    '''A paired dataset holding, for each graph, its two views and its label.'''
    def __init__(self, graphs, diff_graphs, labels):
        super(TUDataset, self).__init__(name='tu')
        self.graphs = graphs
        self.diff_graphs = diff_graphs
        self.labels = labels

    def process(self):
        return

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.diff_graphs[idx], self.labels[idx]
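For context, a minimal usage sketch of this module (our illustration; it assumes the file is importable as `dataset` and relies on `GraphDataLoader` batching each element of the (graph, diff_graph, label) triple, as dgl's graph collator does for tuples):

```python
from dgl.dataloading import GraphDataLoader

from dataset import load  # assumes this file is saved as dataset.py

data = load('MUTAG')  # downloads, processes and caches the dataset on first use
loader = GraphDataLoader(data, batch_size=64, shuffle=True)
for graph, diff_graph, label in loader:
    # graph / diff_graph are batched DGLGraphs; label has shape (batch_size,)
    break
```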