[Example] Variational Graph Auto-Encoders (dmlc#2587)
* [Example] Variational Graph Auto-Encoders

* change dgl dataset to single directional graph

* clean code

* refresh

Co-authored-by: Tianjun Xiao <[email protected]>
yusun-nlp and sneakerkg authored Mar 2, 2021
1 parent d0638b1 commit f793864
Showing 6 changed files with 603 additions and 5 deletions.
9 changes: 4 additions & 5 deletions examples/README.md
@@ -80,7 +80,7 @@ The folder contains example implementations of selected research papers related
| [Supervised Community Detection with Line Graph Neural Networks](#lgnn) | | | | | |
| [Text Generation from Knowledge Graphs with Graph Transformers](#graphwriter) | | | | | |
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |

## 2020

@@ -309,22 +309,21 @@ The folder contains example implementations of selected research papers related
- <a name="ggnn"></a> Li et al. Gated Graph Sequence Neural Networks. [Paper link](https://arxiv.org/abs/1511.05493).
- Example code: [PyTorch](../examples/pytorch/ggnn)
- Tags: question answering

- <a name="chebnet"></a> Defferrard et al. Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering. [Paper link](https://arxiv.org/abs/1606.09375).
- Example code: [PyTorch on image classification](../examples/pytorch/model_zoo/geometric), [PyTorch on node classification](../examples/pytorch/model_zoo/citation_network)
- Tags: image classification, graph classification, node classification

- <a name="monet"></a> Monti et al. Geometric deep learning on graphs and manifolds using mixture model CNNs. [Paper link](https://arxiv.org/abs/1611.08402).
- Example code: [PyTorch on image classification](../examples/pytorch/model_zoo/geometric), [PyTorch on node classification](../examples/pytorch/monet), [MXNet on node classification](../examples/mxnet/monet)
- Tags: image classification, graph classification, node classification

- <a name="weave"></a> Kearnes et al. Molecular Graph Convolutions: Moving Beyond Fingerprints. [Paper link](https://arxiv.org/abs/1603.00856).
- Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/moleculenet), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
- Tags: molecular property prediction

- <a name="complex"></a> Trouillon et al. Complex Embeddings for Simple Link Prediction. [Paper link](http://proceedings.mlr.press/v48/trouillon16.pdf).
- Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)
- Tags: knowledge graph embedding
- <a name="vgae"></a> Thomas et al. Variational Graph Auto-Encoders. [Paper link](https://arxiv.org/abs/1611.07308).
- Example code: [PyTorch](../examples/pytorch/vgae)
- Tags: link prediction

## 2015

61 changes: 61 additions & 0 deletions examples/pytorch/vgae/README.md
@@ -0,0 +1,61 @@
# Variational Graph Auto-Encoders

- Paper link: https://arxiv.org/abs/1611.07308
- Author's code repo: https://github.com/tkipf/gae

## Requirements

- PyTorch
- Python 3.x
- DGL 0.6
- scikit-learn

## Run the demo

Run with the following command (available datasets: "cora", "citeseer", "pubmed"):

```
python train.py
```

## Dataset

In this example, we use two kinds of data sources: DGL's built-in datasets (CoraGraphDataset, CiteseerGraphDataset and PubmedGraphDataset), and the Planetoid data files from https://github.com/kimiyoung/planetoid.

You can specify a dataset as follows:

```
python train.py --datasrc dgl --dataset cora      # from DGL's built-in datasets
python train.py --datasrc website --dataset cora  # from the downloaded website data
```

**Note**: To train on the datasets from the website, download the folder https://github.com/kimiyoung/planetoid/tree/master/data and put it under the project folder.
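
For convenience, here is a minimal download sketch using Python's standard library. The helper name `download_planetoid` and the raw-URL layout of the kimiyoung/planetoid repository are assumptions, not part of this example; the five file suffixes match what `input_data.py` reads:

```
# Hedged download sketch: fetch the planetoid files this example reads
# (see input_data.py) from the assumed raw GitHub URLs into ./data.
import os
import urllib.request

def download_planetoid(dataset='cora', out_dir='data'):
    os.makedirs(out_dir, exist_ok=True)
    base = 'https://raw.githubusercontent.com/kimiyoung/planetoid/master/data'
    for suffix in ['x', 'tx', 'allx', 'graph', 'test.index']:
        name = 'ind.{}.{}'.format(dataset, suffix)
        urllib.request.urlretrieve('{}/{}'.format(base, name),
                                   os.path.join(out_dir, name))

download_planetoid('cora')
```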

## Results

We report *area under the ROC curve* (AUC) and *average precision* (AP) scores for each model on the test set. Numbers show the mean and standard error of 10 runs with random initializations on fixed dataset splits.
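
As a reference point, a minimal sketch of how such scores can be computed with scikit-learn (a listed requirement). The names `adj_rec`, `pos_edges` and `neg_edges` are assumed to be the reconstructed adjacency matrix and the positive/negative test pairs produced in `preprocess.py`:

```
# Hedged scoring sketch: AUC and AP over positive and negative test edges.
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

def get_scores(adj_rec, pos_edges, neg_edges):
    # Predicted edge probability for each candidate pair (i, j).
    preds_pos = [adj_rec[i, j].item() for i, j in pos_edges]
    preds_neg = [adj_rec[i, j].item() for i, j in neg_edges]
    preds = np.hstack([preds_pos, preds_neg])
    labels = np.hstack([np.ones(len(preds_pos)), np.zeros(len(preds_neg))])
    return roc_auc_score(labels, preds), average_precision_score(labels, preds)
```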

### Dataset from DGL

| Dataset  | AUC             | AP              |
| -------- | --------------- | --------------- |
| Cora     | 91.8 $\pm$ 0.01 | 92.5 $\pm$ 0.01 |
| Citeseer | 89.2 $\pm$ 0.02 | 90.8 $\pm$ 0.01 |
| Pubmed   | 94.5 $\pm$ 0.01 | 94.6 $\pm$ 0.01 |

### Dataset from website

| Dataset  | AUC             | AP              |
| -------- | --------------- | --------------- |
| Cora     | 90.9 $\pm$ 0.01 | 92.1 $\pm$ 0.01 |
| Citeseer | 90.3 $\pm$ 0.01 | 91.8 $\pm$ 0.01 |
| Pubmed   | 94.4 $\pm$ 0.01 | 94.6 $\pm$ 0.01 |

### Results reported in the paper

| Dataset  | AUC             | AP              |
| -------- | --------------- | --------------- |
| Cora     | 91.4 $\pm$ 0.01 | 92.6 $\pm$ 0.01 |
| Citeseer | 90.8 $\pm$ 0.02 | 92.0 $\pm$ 0.02 |
| Pubmed   | 94.4 $\pm$ 0.02 | 94.7 $\pm$ 0.02 |

48 changes: 48 additions & 0 deletions examples/pytorch/vgae/input_data.py
@@ -0,0 +1,48 @@
'''
****************NOTE*****************
CREDITS: Thomas Kipf
Since the datasets are the same as those in Kipf's implementation,
his preprocessing code is used as-is.
*************************************
'''
import numpy as np
import sys
import pickle as pkl
import networkx as nx
import scipy.sparse as sp


def parse_index_file(filename):
    # Read the test-index file: one integer node index per line.
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index


def load_data(dataset):
# load the data: x, tx, allx, graph
names = ['x', 'tx', 'allx', 'graph']
objects = []
for i in range(len(names)):
with open("data/ind.{}.{}".format(dataset, names[i]), 'rb') as f:
if sys.version_info > (3, 0):
objects.append(pkl.load(f, encoding='latin1'))
else:
objects.append(pkl.load(f))
x, tx, allx, graph = tuple(objects)
test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset))
test_idx_range = np.sort(test_idx_reorder)

if dataset == 'citeseer':
# Fix citeseer dataset (there are some isolated nodes in the graph)
# Find isolated nodes, add them as zero-vecs into the right position
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
tx_extended[test_idx_range - min(test_idx_range), :] = tx
tx = tx_extended

features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

return adj, features
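

if __name__ == '__main__':
    # Hedged usage sketch: assumes the planetoid files ind.cora.* have been
    # downloaded into ./data (see the README of this example).
    adj, features = load_data('cora')
    print('adj:', adj.shape, 'features:', features.shape)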
34 changes: 34 additions & 0 deletions examples/pytorch/vgae/model.py
@@ -0,0 +1,34 @@
from dgl.nn.pytorch import GraphConv
import torch
import torch.nn as nn
import torch.nn.functional as F


class VGAEModel(nn.Module):
def __init__(self, in_dim, hidden1_dim, hidden2_dim):
super(VGAEModel, self).__init__()
self.in_dim = in_dim
self.hidden1_dim = hidden1_dim
self.hidden2_dim = hidden2_dim

        # One shared GCN layer, followed by two parallel GCN heads producing
        # the posterior mean and log standard deviation of the latent codes.
        layers = [GraphConv(self.in_dim, self.hidden1_dim, activation=F.relu, allow_zero_in_degree=True),
                  GraphConv(self.hidden1_dim, self.hidden2_dim, activation=lambda x: x, allow_zero_in_degree=True),
                  GraphConv(self.hidden1_dim, self.hidden2_dim, activation=lambda x: x, allow_zero_in_degree=True)]
        self.layers = nn.ModuleList(layers)

    def encoder(self, g, features):
        h = self.layers[0](g, features)
        self.mean = self.layers[1](g, h)
        self.log_std = self.layers[2](g, h)
        # Reparameterization trick: z = mu + eps * sigma with eps ~ N(0, I).
        gaussian_noise = torch.randn(features.size(0), self.hidden2_dim)
        sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std)
        return sampled_z

    def decoder(self, z):
        # Inner-product decoder: edge probabilities are sigmoid(z @ z^T).
        adj_rec = torch.sigmoid(torch.matmul(z, z.t()))
        return adj_rec

def forward(self, g, features):
z = self.encoder(g, features)
adj_rec = self.decoder(z)
return adj_rec
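

if __name__ == '__main__':
    # Hedged smoke-test sketch: one forward pass on a small random graph to
    # check the output shape; the sizes below are arbitrary assumptions.
    import dgl

    g = dgl.rand_graph(10, 40)           # 10 nodes, 40 random edges
    feats = torch.randn(10, 16)          # random node features
    model = VGAEModel(in_dim=16, hidden1_dim=32, hidden2_dim=16)
    adj_rec = model(g, feats)
    print(adj_rec.shape)                 # expected: torch.Size([10, 10])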
169 changes: 169 additions & 0 deletions examples/pytorch/vgae/preprocess.py
@@ -0,0 +1,169 @@
import numpy as np
import scipy.sparse as sp
import torch


def mask_test_edges(adj):
# Function to build test set with 10% positive links
# NOTE: Splits are randomized and results might slightly deviate from reported numbers in the paper.
# TODO: Clean up.

# Remove diagonal elements
adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj.eliminate_zeros()
# Check that diag is zero:
assert np.diag(adj.todense()).sum() == 0

adj_triu = sp.triu(adj)
adj_tuple = sparse_to_tuple(adj_triu)
edges = adj_tuple[0]
edges_all = sparse_to_tuple(adj)[0]
num_test = int(np.floor(edges.shape[0] / 10.))
num_val = int(np.floor(edges.shape[0] / 20.))

all_edge_idx = list(range(edges.shape[0]))
np.random.shuffle(all_edge_idx)
val_edge_idx = all_edge_idx[:num_val]
test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
test_edges = edges[test_edge_idx]
val_edges = edges[val_edge_idx]
train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)

    def ismember(a, b, tol=5):
        # Return True if the pair `a` appears as a row of `b` (up to rounding).
        rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
        return np.any(rows_close)

test_edges_false = []
while len(test_edges_false) < len(test_edges):
idx_i = np.random.randint(0, adj.shape[0])
idx_j = np.random.randint(0, adj.shape[0])
if idx_i == idx_j:
continue
if ismember([idx_i, idx_j], edges_all):
continue
if test_edges_false:
if ismember([idx_j, idx_i], np.array(test_edges_false)):
continue
if ismember([idx_i, idx_j], np.array(test_edges_false)):
continue
test_edges_false.append([idx_i, idx_j])

val_edges_false = []
while len(val_edges_false) < len(val_edges):
idx_i = np.random.randint(0, adj.shape[0])
idx_j = np.random.randint(0, adj.shape[0])
if idx_i == idx_j:
continue
if ismember([idx_i, idx_j], train_edges):
continue
if ismember([idx_j, idx_i], train_edges):
continue
if ismember([idx_i, idx_j], val_edges):
continue
if ismember([idx_j, idx_i], val_edges):
continue
if val_edges_false:
if ismember([idx_j, idx_i], np.array(val_edges_false)):
continue
if ismember([idx_i, idx_j], np.array(val_edges_false)):
continue
val_edges_false.append([idx_i, idx_j])

assert ~ismember(test_edges_false, edges_all)
assert ~ismember(val_edges_false, edges_all)
assert ~ismember(val_edges, train_edges)
assert ~ismember(test_edges, train_edges)
assert ~ismember(val_edges, test_edges)

data = np.ones(train_edges.shape[0])

# Re-build adj matrix
adj_train = sp.csr_matrix((data, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
adj_train = adj_train + adj_train.T

# NOTE: these edge lists only contain single direction of edge!
return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false


def mask_test_edges_dgl(graph, adj):
    # Same splitting logic as above, but driven by a DGLGraph's edge list.
    src, dst = graph.edges()
    edges_all = torch.stack([src, dst], dim=0)
    edges_all = edges_all.t().numpy()
num_test = int(np.floor(edges_all.shape[0] / 10.))
num_val = int(np.floor(edges_all.shape[0] / 20.))

all_edge_idx = list(range(edges_all.shape[0]))
np.random.shuffle(all_edge_idx)
val_edge_idx = all_edge_idx[:num_val]
test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
train_edge_idx = all_edge_idx[(num_val + num_test):]
test_edges = edges_all[test_edge_idx]
val_edges = edges_all[val_edge_idx]
train_edges = np.delete(edges_all, np.hstack([test_edge_idx, val_edge_idx]), axis=0)

    def ismember(a, b, tol=5):
        # Return True if the pair `a` appears as a row of `b` (up to rounding).
        rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
        return np.any(rows_close)

test_edges_false = []
while len(test_edges_false) < len(test_edges):
idx_i = np.random.randint(0, adj.shape[0])
idx_j = np.random.randint(0, adj.shape[0])
if idx_i == idx_j:
continue
if ismember([idx_i, idx_j], edges_all):
continue
if test_edges_false:
if ismember([idx_j, idx_i], np.array(test_edges_false)):
continue
if ismember([idx_i, idx_j], np.array(test_edges_false)):
continue
test_edges_false.append([idx_i, idx_j])

val_edges_false = []
while len(val_edges_false) < len(val_edges):
idx_i = np.random.randint(0, adj.shape[0])
idx_j = np.random.randint(0, adj.shape[0])
if idx_i == idx_j:
continue
if ismember([idx_i, idx_j], train_edges):
continue
if ismember([idx_j, idx_i], train_edges):
continue
if ismember([idx_i, idx_j], val_edges):
continue
if ismember([idx_j, idx_i], val_edges):
continue
if val_edges_false:
if ismember([idx_j, idx_i], np.array(val_edges_false)):
continue
if ismember([idx_i, idx_j], np.array(val_edges_false)):
continue
val_edges_false.append([idx_i, idx_j])

assert ~ismember(test_edges_false, edges_all)
assert ~ismember(val_edges_false, edges_all)
assert ~ismember(val_edges, train_edges)
assert ~ismember(test_edges, train_edges)
assert ~ismember(val_edges, test_edges)

# NOTE: these edge lists only contain single direction of edge!
return train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false


def sparse_to_tuple(sparse_mx):
if not sp.isspmatrix_coo(sparse_mx):
sparse_mx = sparse_mx.tocoo()
coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
values = sparse_mx.data
shape = sparse_mx.shape
return coords, values, shape


def preprocess_graph(adj):
    # Symmetric normalization with self-loops: A_norm = D^{-1/2}(A + I)D^{-1/2}.
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    return adj_normalized, sparse_to_tuple(adj_normalized)
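

if __name__ == '__main__':
    # Hedged sanity-check sketch: split a small random symmetric adjacency
    # matrix and symmetrically normalize the resulting training graph.
    # The graph size and density below are arbitrary assumptions.
    upper = sp.triu(sp.random(20, 20, density=0.3, random_state=0), k=1)
    adj = (upper + upper.T).tocsr()
    adj.data[:] = 1.0                    # binarize edge weights
    adj_train, train_e, val_e, _, test_e, _ = mask_test_edges(adj)
    adj_norm, _ = preprocess_graph(adj_train)
    print(train_e.shape, val_e.shape, test_e.shape, adj_norm.nnz)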