# [Example] Variational Graph Auto-Encoders (dmlc#2587)

* [Example] Variational Graph Auto-Encoders
* Change DGL dataset to single-directional graph
* Clean code
* Refresh

Co-authored-by: Tianjun Xiao <[email protected]>
Showing 6 changed files with 603 additions and 5 deletions.
# Variational Graph Auto-Encoders

- Paper link: https://arxiv.org/abs/1611.07308
- Author's code repo: https://github.com/tkipf/gae

## Requirements

- PyTorch
- Python 3.x
- DGL 0.6
- scikit-learn
## Run the demo

Run with the following (available datasets: "cora", "citeseer", "pubmed"):

```
python train.py
```
## Dataset

This example uses two kinds of data sources: DGL's built-in datasets (CoraGraphDataset, CiteseerGraphDataset, and PubmedGraphDataset), and the raw Planetoid files from https://github.com/kimiyoung/planetoid.

You can specify a data source and dataset as follows:
```
python train.py --datasrc dgl --dataset cora      # from DGL
python train.py --datasrc website --dataset cora  # from website
```
**Note**: To train on the dataset from the website, download the data folder https://github.com/kimiyoung/planetoid/tree/master/data and put it under the project folder, as shown below.
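Expected layout after downloading (cora shown as an example; the Planetoid folder contains more files than the five this example's loader reads):

```
<project folder>/
└── data/
    ├── ind.cora.x
    ├── ind.cora.tx
    ├── ind.cora.allx
    ├── ind.cora.graph
    └── ind.cora.test.index
```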
## Results

We report *area under the ROC curve* (AUC) and *average precision* (AP) scores for each model on the test set. Numbers show mean results and standard error over 10 runs with random initializations on fixed dataset splits.
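As a reference for how these metrics are computed, here is a minimal sketch using scikit-learn (listed in the requirements); `pos_scores` and `neg_scores` are hypothetical arrays of predicted probabilities for the positive and negative test edges:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

def link_prediction_scores(pos_scores, neg_scores):
    # True edges get label 1, sampled non-edges get label 0.
    preds = np.concatenate([pos_scores, neg_scores])
    labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    return roc_auc_score(labels, preds), average_precision_score(labels, preds)
```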
### Dataset from DGL

| Dataset  | AUC             | AP              |
| -------- | --------------- | --------------- |
| Cora     | 91.8 $\pm$ 0.01 | 92.5 $\pm$ 0.01 |
| Citeseer | 89.2 $\pm$ 0.02 | 90.8 $\pm$ 0.01 |
| Pubmed   | 94.5 $\pm$ 0.01 | 94.6 $\pm$ 0.01 |
### Dataset from website

| Dataset  | AUC             | AP              |
| -------- | --------------- | --------------- |
| Cora     | 90.9 $\pm$ 0.01 | 92.1 $\pm$ 0.01 |
| Citeseer | 90.3 $\pm$ 0.01 | 91.8 $\pm$ 0.01 |
| Pubmed   | 94.4 $\pm$ 0.01 | 94.6 $\pm$ 0.01 |
### Reported results in the paper

| Dataset  | AUC             | AP              |
| -------- | --------------- | --------------- |
| Cora     | 91.4 $\pm$ 0.01 | 92.6 $\pm$ 0.01 |
| Citeseer | 90.8 $\pm$ 0.02 | 92.0 $\pm$ 0.02 |
| Pubmed   | 94.4 $\pm$ 0.02 | 94.7 $\pm$ 0.02 |
---
```python
'''
****************NOTE*****************
CREDITS: Thomas Kipf
Since the datasets are the same as those in Kipf's implementation,
his preprocessing source was used as-is.
*************************************
'''
import pickle as pkl
import sys

import networkx as nx
import numpy as np
import scipy.sparse as sp


def parse_index_file(filename):
    # Parse a file of integer indices, one per line.
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index


def load_data(dataset):
    # Load the raw Planetoid data: x, tx, allx, graph.
    names = ['x', 'tx', 'allx', 'graph']
    objects = []
    for i in range(len(names)):
        with open("data/ind.{}.{}".format(dataset, names[i]), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))
    x, tx, allx, graph = tuple(objects)
    test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph).
        # Find isolated nodes and add them as zero-vectors in the right position.
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended

    # Stack train and test features, then restore the original test-node order.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    return adj, features
```
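A minimal usage sketch, assuming the `data/` folder described in the README note is in place:

```python
# Hypothetical usage: load the raw cora files and inspect the outputs.
adj, features = load_data('cora')
print(adj.shape)       # (N, N) scipy.sparse adjacency matrix
print(features.shape)  # (N, D) scipy.sparse feature matrix
```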
---
```python
from dgl.nn.pytorch import GraphConv
import torch
import torch.nn as nn
import torch.nn.functional as F


class VGAEModel(nn.Module):
    def __init__(self, in_dim, hidden1_dim, hidden2_dim):
        super(VGAEModel, self).__init__()
        self.in_dim = in_dim
        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim

        # One shared GCN layer, then two parallel GCN layers that produce
        # the mean and log-std of the latent Gaussian.
        layers = [GraphConv(self.in_dim, self.hidden1_dim, activation=F.relu, allow_zero_in_degree=True),
                  GraphConv(self.hidden1_dim, self.hidden2_dim, activation=lambda x: x, allow_zero_in_degree=True),
                  GraphConv(self.hidden1_dim, self.hidden2_dim, activation=lambda x: x, allow_zero_in_degree=True)]
        self.layers = nn.ModuleList(layers)

    def encoder(self, g, features):
        h = self.layers[0](g, features)
        self.mean = self.layers[1](g, h)
        self.log_std = self.layers[2](g, h)
        # Reparameterization trick: z = mean + noise * std.
        gaussian_noise = torch.randn(features.size(0), self.hidden2_dim)
        sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std)
        return sampled_z

    def decoder(self, z):
        # Inner-product decoder: edge probability is sigmoid(z_i . z_j).
        adj_rec = torch.sigmoid(torch.matmul(z, z.t()))
        return adj_rec

    def forward(self, g, features):
        z = self.encoder(g, features)
        adj_rec = self.decoder(z)
        return adj_rec
```
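A minimal sketch of wiring the model to a DGL built-in dataset. The hidden sizes (32, 16) follow the VGAE paper, and adding self-loops is just one way to avoid zero-in-degree nodes; both choices are illustrative assumptions here:

```python
# Sketch only: the exact wiring in this example's train.py may differ.
import dgl
from dgl.data import CoraGraphDataset

graph = CoraGraphDataset()[0]
graph = dgl.add_self_loop(graph)  # ensure every node has an in-edge
features = graph.ndata['feat']

model = VGAEModel(features.shape[1], 32, 16)
adj_rec = model(graph, features)  # (N, N) reconstructed edge probabilities
```

During training, the VGAE objective combines a reconstruction loss on `adj_rec` with a KL divergence term computed from `model.mean` and `model.log_std`.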
---
```python
import numpy as np
import scipy.sparse as sp
import torch


def mask_test_edges(adj):
    # Build validation/test sets with 5%/10% of the positive links, plus
    # equally sized sets of sampled negative (non-existent) links.
    # NOTE: Splits are randomized and results might slightly deviate from
    # the numbers reported in the paper.

    # Remove diagonal elements.
    adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
    adj.eliminate_zeros()
    # Check that the diagonal is zero:
    assert np.diag(adj.todense()).sum() == 0

    adj_triu = sp.triu(adj)
    adj_tuple = sparse_to_tuple(adj_triu)
    edges = adj_tuple[0]
    edges_all = sparse_to_tuple(adj)[0]
    num_test = int(np.floor(edges.shape[0] / 10.))
    num_val = int(np.floor(edges.shape[0] / 20.))

    all_edge_idx = list(range(edges.shape[0]))
    np.random.shuffle(all_edge_idx)
    val_edge_idx = all_edge_idx[:num_val]
    test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
    test_edges = edges[test_edge_idx]
    val_edges = edges[val_edge_idx]
    train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)

    def ismember(a, b, tol=5):
        rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
        return np.any(rows_close)

    # Sample negative test edges: random node pairs that are not edges.
    test_edges_false = []
    while len(test_edges_false) < len(test_edges):
        idx_i = np.random.randint(0, adj.shape[0])
        idx_j = np.random.randint(0, adj.shape[0])
        if idx_i == idx_j:
            continue
        if ismember([idx_i, idx_j], edges_all):
            continue
        if test_edges_false:
            if ismember([idx_j, idx_i], np.array(test_edges_false)):
                continue
            if ismember([idx_i, idx_j], np.array(test_edges_false)):
                continue
        test_edges_false.append([idx_i, idx_j])

    # Sample negative validation edges the same way.
    val_edges_false = []
    while len(val_edges_false) < len(val_edges):
        idx_i = np.random.randint(0, adj.shape[0])
        idx_j = np.random.randint(0, adj.shape[0])
        if idx_i == idx_j:
            continue
        if ismember([idx_i, idx_j], train_edges):
            continue
        if ismember([idx_j, idx_i], train_edges):
            continue
        if ismember([idx_i, idx_j], val_edges):
            continue
        if ismember([idx_j, idx_i], val_edges):
            continue
        if val_edges_false:
            if ismember([idx_j, idx_i], np.array(val_edges_false)):
                continue
            if ismember([idx_i, idx_j], np.array(val_edges_false)):
                continue
        val_edges_false.append([idx_i, idx_j])

    assert not ismember(test_edges_false, edges_all)
    assert not ismember(val_edges_false, edges_all)
    assert not ismember(val_edges, train_edges)
    assert not ismember(test_edges, train_edges)
    assert not ismember(val_edges, test_edges)

    data = np.ones(train_edges.shape[0])

    # Re-build the (symmetric) training adjacency matrix.
    adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
    adj_train = adj_train + adj_train.T

    # NOTE: these edge lists only contain a single direction of each edge!
    return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false
```
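The `ismember` helper checks row membership via NumPy broadcasting; a quick sketch with made-up values:

```python
import numpy as np

a = np.array([1, 2])
b = np.array([[0, 1], [1, 2]])
# a - b[:, None] has shape (len(b), 1, 2); a row of zeros marks a match.
rows_close = np.all(np.round(a - b[:, None], 5) == 0, axis=-1)
print(np.any(rows_close))  # True: [1, 2] appears as a row of b
```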
```python
def mask_test_edges_dgl(graph, adj):
    # Same splitting logic as above, but driven by a DGL graph's edge list;
    # returns training edge IDs instead of a rebuilt adjacency matrix.
    src, dst = graph.edges()
    edges_all = torch.stack([src, dst], dim=0)
    edges_all = edges_all.t().numpy()
    num_test = int(np.floor(edges_all.shape[0] / 10.))
    num_val = int(np.floor(edges_all.shape[0] / 20.))

    all_edge_idx = list(range(edges_all.shape[0]))
    np.random.shuffle(all_edge_idx)
    val_edge_idx = all_edge_idx[:num_val]
    test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
    train_edge_idx = all_edge_idx[(num_val + num_test):]
    test_edges = edges_all[test_edge_idx]
    val_edges = edges_all[val_edge_idx]
    train_edges = np.delete(edges_all, np.hstack([test_edge_idx, val_edge_idx]), axis=0)

    def ismember(a, b, tol=5):
        rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
        return np.any(rows_close)

    # Sample negative test edges: random node pairs that are not edges.
    test_edges_false = []
    while len(test_edges_false) < len(test_edges):
        idx_i = np.random.randint(0, adj.shape[0])
        idx_j = np.random.randint(0, adj.shape[0])
        if idx_i == idx_j:
            continue
        if ismember([idx_i, idx_j], edges_all):
            continue
        if test_edges_false:
            if ismember([idx_j, idx_i], np.array(test_edges_false)):
                continue
            if ismember([idx_i, idx_j], np.array(test_edges_false)):
                continue
        test_edges_false.append([idx_i, idx_j])

    # Sample negative validation edges the same way.
    val_edges_false = []
    while len(val_edges_false) < len(val_edges):
        idx_i = np.random.randint(0, adj.shape[0])
        idx_j = np.random.randint(0, adj.shape[0])
        if idx_i == idx_j:
            continue
        if ismember([idx_i, idx_j], train_edges):
            continue
        if ismember([idx_j, idx_i], train_edges):
            continue
        if ismember([idx_i, idx_j], val_edges):
            continue
        if ismember([idx_j, idx_i], val_edges):
            continue
        if val_edges_false:
            if ismember([idx_j, idx_i], np.array(val_edges_false)):
                continue
            if ismember([idx_i, idx_j], np.array(val_edges_false)):
                continue
        val_edges_false.append([idx_i, idx_j])

    assert not ismember(test_edges_false, edges_all)
    assert not ismember(val_edges_false, edges_all)
    assert not ismember(val_edges, train_edges)
    assert not ismember(test_edges, train_edges)
    assert not ismember(val_edges, test_edges)

    # NOTE: these edge lists only contain a single direction of each edge!
    return train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false
```
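A sketch of how these helpers fit together on the DGL data path. Calling `dgl.edge_subgraph` with `preserve_nodes=True` is an assumption based on the DGL 0.6 API (newer releases spell this `relabel_nodes=False`):

```python
# Sketch only: the exact wiring in the training script may differ.
import dgl
import torch
from dgl.data import CoraGraphDataset

graph = CoraGraphDataset()[0]
adj_orig = graph.adjacency_matrix().to_dense()

train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = \
    mask_test_edges_dgl(graph, adj_orig)

# Keep all nodes but only the training edges for message passing.
train_graph = dgl.edge_subgraph(graph, torch.tensor(train_edge_idx),
                                preserve_nodes=True)
```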
```python
def sparse_to_tuple(sparse_mx):
    # Convert a scipy sparse matrix to (coords, values, shape) arrays.
    if not sp.isspmatrix_coo(sparse_mx):
        sparse_mx = sparse_mx.tocoo()
    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
    values = sparse_mx.data
    shape = sparse_mx.shape
    return coords, values, shape


def preprocess_graph(adj):
    # Symmetrically normalize the adjacency matrix with self-loops added.
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    return adj_normalized, sparse_to_tuple(adj_normalized)
```
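For reference, `preprocess_graph` computes the symmetric normalization used by GCNs:

$$\tilde{A} = D^{-1/2} (A + I_N) D^{-1/2}, \qquad D_{ii} = \sum_j (A + I_N)_{ij}$$

Adding the identity inserts self-loops, so each node retains its own features during convolution.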