Skip to content

Commit

Permalink
[bugfix] Fix bugs in vgae (dmlc#2727)
Browse files Browse the repository at this point in the history
* [Example]Variational Graph Auto-Encoders

* change dgl dataset to single directional graph

* clean code

* refresh

* fix bug

* fix bug

* fix bug

* add gpu

Co-authored-by: Tianjun Xiao <[email protected]>
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
3 people authored Apr 3, 2021
1 parent b2e35e6 commit cba5af2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
6 changes: 4 additions & 2 deletions examples/pytorch/vgae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch.nn as nn
import torch.nn.functional as F

from train import device


class VGAEModel(nn.Module):
def __init__(self, in_dim, hidden1_dim, hidden2_dim):
Expand All @@ -20,8 +22,8 @@ def encoder(self, g, features):
h = self.layers[0](g, features)
self.mean = self.layers[1](g, h)
self.log_std = self.layers[2](g, h)
gaussian_noise = torch.randn(features.size(0), self.hidden2_dim)
sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std)
gaussian_noise = torch.randn(features.size(0), self.hidden2_dim).to(device)
sampled_z = self.mean + gaussian_noise * torch.exp(self.log_std).to(device)
return sampled_z

def decoder(self, z):
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/vgae/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def ismember(a, b, tol=5):
def mask_test_edges_dgl(graph, adj):
src, dst = graph.edges()
edges_all = torch.stack([src, dst], dim=0)
edges_all = edges_all.t().numpy()
edges_all = edges_all.t().cpu().numpy()
num_test = int(np.floor(edges_all.shape[0] / 10.))
num_val = int(np.floor(edges_all.shape[0] / 20.))

Expand Down
20 changes: 11 additions & 9 deletions examples/pytorch/vgae/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,23 @@
parser.add_argument('--hidden2', '-h2', type=int, default=16, help='Number of units in hidden layer 2.')
parser.add_argument('--datasrc', '-s', type=str, default='dgl',
help='Dataset download from dgl Dataset or website.')
parser.add_argument('--dataset', '-d', type=str, default='pubmed', help='Dataset string.')
parser.add_argument('--dataset', '-d', type=str, default='cora', help='Dataset string.')
parser.add_argument('--gpu_id', type=int, default=0, help='GPU id to use.')
args = parser.parse_args()


# check device
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")
# device = "cpu"

# roc_means = []
# ap_means = []

def compute_loss_para(adj):
pos_weight = ((adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum())
norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
weight_mask = adj.view(-1) == 1
weight_tensor = torch.ones(weight_mask.size(0))
weight_tensor = torch.ones(weight_mask.size(0)).to(device)
weight_tensor[weight_mask] = pos_weight
return weight_tensor, norm

Expand All @@ -51,6 +55,7 @@ def get_scores(edges_pos, edges_neg, adj_rec):
def sigmoid(x):
return 1 / (1 + np.exp(-x))

adj_rec = adj_rec.cpu()
# Predict on test set of edges
preds = []
for e in edges_pos:
Expand Down Expand Up @@ -80,21 +85,18 @@ def dgl_main():
raise NotImplementedError
graph = dataset[0]

# check device
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")

# Extract node features
feats = graph.ndata.pop('feat').to(device)
in_dim = feats.shape[-1]

graph = graph.to(device)

# generate input
adj_orig = graph.adjacency_matrix().to_dense().to(device)
adj_orig = graph.adjacency_matrix().to_dense()

# build test set with 10% positive links
train_edge_idx, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges_dgl(graph, adj_orig)

graph = graph.to(device)

# create train graph
train_edge_idx = torch.tensor(train_edge_idx).to(device)
train_graph = dgl.edge_subgraph(graph, train_edge_idx, preserve_nodes=True)
Expand All @@ -119,7 +121,7 @@ def dgl_main():
# Training and validation using a full graph
vgae_model.train()

logits = vgae_model.forward(train_graph, feats)
logits = vgae_model.forward(graph, feats)

# compute loss
loss = norm * F.binary_cross_entropy(logits.view(-1), adj.view(-1), weight=weight_tensor)
Expand Down

0 comments on commit cba5af2

Please sign in to comment.