[Model] Add DGI Model (dmlc#501)
xavierzw authored and yzh119 committed Apr 22, 2019
1 parent fe7d5e9 commit ad9da36
Showing 5 changed files with 323 additions and 1 deletion.
38 changes: 38 additions & 0 deletions examples/pytorch/dgi/README.md
@@ -0,0 +1,38 @@
Deep Graph Infomax (DGI)
========================

- Paper link: [https://arxiv.org/abs/1809.10341](https://arxiv.org/abs/1809.10341)
- Author's code repo (in PyTorch):
[https://github.com/PetarV-/DGI](https://github.com/PetarV-/DGI)

Dependencies
------------
- PyTorch 0.4.1+
- requests

```bash
pip install torch requests
```

How to run
----------

Run with one of the following commands:

```bash
python train.py --dataset=cora --gpu=0 --self-loop
```

```bash
python train.py --dataset=citeseer --gpu=0
```

```bash
python train.py --dataset=pubmed --gpu=0
```

Results
-------
* cora: ~81.6 (81.2-82.1) (paper: 82.3)
* citeseer: ~69.4 (paper: 71.8)
* pubmed: ~76.1 (paper: 76.8)
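
Using the learned embeddings
----------------------------

`train.py` runs the whole pipeline: unsupervised DGI training first, then a
linear classifier (`Classifier` in `dgi.py`) trained on the frozen node
embeddings. A minimal sketch of the same two stages, assuming `g` is a
prepared `DGLGraph` and `features` an `N x F` node-feature tensor
(hyperparameters follow the defaults in `train.py`):

```python
import torch
import torch.nn as nn

from dgi import DGI

# Stage 1: unsupervised DGI training; forward() returns the loss directly.
dgi = DGI(g, features.shape[1], 512, 1, nn.PReLU(512), 0.)
optimizer = torch.optim.Adam(dgi.parameters(), lr=1e-3)
for _ in range(300):
    optimizer.zero_grad()
    loss = dgi(features)
    loss.backward()
    optimizer.step()

# Stage 2: freeze the encoder and read out embeddings for the classifier.
embeds = dgi.encoder(features, corrupt=False).detach()
```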
81 changes: 81 additions & 0 deletions examples/pytorch/dgi/dgi.py
@@ -0,0 +1,81 @@
"""
Deep Graph Infomax in DGL
References
----------
Papers: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI
"""

import torch
import torch.nn as nn
import math
from gcn import GCN

class Encoder(nn.Module):
def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
super(Encoder, self).__init__()
self.g = g
self.conv = GCN(g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout)

def forward(self, features, corrupt=False):
if corrupt:
perm = torch.randperm(self.g.number_of_nodes())
features = features[perm]
features = self.conv(features)
return features


class Discriminator(nn.Module):
def __init__(self, n_hidden):
super(Discriminator, self).__init__()
self.weight = nn.Parameter(torch.Tensor(n_hidden, n_hidden))
self.reset_parameters()

def uniform(self, size, tensor):
bound = 1.0 / math.sqrt(size)
if tensor is not None:
tensor.data.uniform_(-bound, bound)

def reset_parameters(self):
size = self.weight.size(0)
self.uniform(size, self.weight)

def forward(self, features, summary):
features = torch.matmul(features, torch.matmul(self.weight, summary))
return features


class DGI(nn.Module):
def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
super(DGI, self).__init__()
self.encoder = Encoder(g, in_feats, n_hidden, n_layers, activation, dropout)
self.discriminator = Discriminator(n_hidden)
self.loss = nn.BCEWithLogitsLoss()

def forward(self, features):
positive = self.encoder(features, corrupt=False)
negative = self.encoder(features, corrupt=True)
summary = torch.sigmoid(positive.mean(dim=0))

positive = self.discriminator(positive, summary)
negative = self.discriminator(negative, summary)

l1 = self.loss(positive, torch.ones_like(positive))
l2 = self.loss(negative, torch.zeros_like(negative))

return l1 + l2
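        # Note: with nn.BCEWithLogitsLoss, l1 + l2 above is the binary
        # cross-entropy form of the DGI objective,
        #   L = -(1/N) * sum_i [log sigma(h_i^T W s) + log(1 - sigma(h~_i^T W s))],
        # where h_i / h~_i are encoder outputs on the real / corrupted graph
        # and s = sigma(mean_i h_i) is the summary vector.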


class Classifier(nn.Module):
    def __init__(self, n_hidden, n_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(n_hidden, n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc.reset_parameters()

    def forward(self, features):
        features = self.fc(features)
        return torch.log_softmax(features, dim=-1)
35 changes: 35 additions & 0 deletions examples/pytorch/dgi/gcn.py
@@ -0,0 +1,35 @@
"""
This code was copied from the GCN implementation in DGL examples.
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv

class GCN(nn.Module):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__()
self.g = g
self.layers = nn.ModuleList()
# input layer
self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
# output layer
self.layers.append(GraphConv(n_hidden, n_classes))
self.dropout = nn.Dropout(p=dropout)

def forward(self, features):
h = features
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
return h
168 changes: 168 additions & 0 deletions examples/pytorch/dgi/train.py
@@ -0,0 +1,168 @@
import argparse, time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

from dgi import DGI, Classifier


def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # graph preprocess
    g = data.graph
    # add self loop
    if args.self_loop:
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
    g = DGLGraph(g)
    n_edges = g.number_of_edges()

    # create DGI model
    dgi = DGI(g,
              in_feats,
              args.n_hidden,
              args.n_layers,
              nn.PReLU(args.n_hidden),
              args.dropout)

    if cuda:
        dgi.cuda()

    dgi_optimizer = torch.optim.Adam(dgi.parameters(),
                                     lr=args.dgi_lr,
                                     weight_decay=args.weight_decay)

    # train deep graph infomax
    cnt_wait = 0
    best = 1e9
    best_t = 0
    dur = []
    for epoch in range(args.n_dgi_epochs):
        dgi.train()
        if epoch >= 3:
            t0 = time.time()

        dgi_optimizer.zero_grad()
        loss = dgi(features)
        loss.backward()
        dgi_optimizer.step()

        # early stopping on the unsupervised loss
        if loss < best:
            best = loss
            best_t = epoch
            cnt_wait = 0
            torch.save(dgi.state_dict(), 'best_dgi.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping!')
            break

        if epoch >= 3:
            dur.append(time.time() - t0)

        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                            n_edges / np.mean(dur) / 1000))

    # create classifier model
    classifier = Classifier(args.n_hidden, n_classes)
    if cuda:
        classifier.cuda()

    classifier_optimizer = torch.optim.Adam(classifier.parameters(),
                                            lr=args.classifier_lr,
                                            weight_decay=args.weight_decay)

    # train classifier
    print('Loading {}th epoch'.format(best_t))
    dgi.load_state_dict(torch.load('best_dgi.pkl'))
    embeds = dgi.encoder(features, corrupt=False)
    # freeze the encoder output; only the classifier is trained below
    embeds = embeds.detach()
    dur = []
    for epoch in range(args.n_classifier_epochs):
        classifier.train()
        if epoch >= 3:
            t0 = time.time()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_mask], labels[train_mask])
        loss.backward()
        classifier_optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(classifier, embeds, labels, val_mask)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                            acc, n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(classifier, embeds, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGI')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu device id (-1 for cpu)")
    parser.add_argument("--dgi-lr", type=float, default=1e-3,
                        help="dgi learning rate")
    parser.add_argument("--classifier-lr", type=float, default=1e-2,
                        help="classifier learning rate")
    parser.add_argument("--n-dgi-epochs", type=int, default=300,
                        help="number of DGI training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=300,
                        help="number of classifier training epochs")
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=0.,
                        help="weight for L2 loss")
    parser.add_argument("--patience", type=int, default=20,
                        help="early stop patience condition")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)
2 changes: 1 addition & 1 deletion python/dgl/data/citation_graph.py
@@ -383,7 +383,7 @@ def _normalize(mx):
"""Row-normalize sparse matrix"""
rowsum = np.array(mx.sum(1))
r_inv = np.power(rowsum, -1).flatten()
r_inv[np.isinf(r_inv)] = np.inf
r_inv[np.isinf(r_inv)] = 0.
r_mat_inv = sp.diags(r_inv)
mx = r_mat_inv.dot(mx)
return mx
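
The fix replaces a no-op (`inf` entries were reassigned `np.inf`) with the intended behavior: `np.power(0., -1)` yields `inf` for zero-sum rows, and leaving `inf` on the normalization diagonal can surface as `inf`/`nan` once the matrix is used in dense arithmetic, whereas `0.` keeps such rows all-zero. A standalone sketch of the normalization (not part of this commit):

```python
import numpy as np
import scipy.sparse as sp

mx = sp.csr_matrix(np.array([[1., 1.],
                             [0., 0.]]))    # second row sums to zero
rowsum = np.array(mx.sum(1))                # [[2.], [0.]]
r_inv = np.power(rowsum, -1).flatten()      # [0.5, inf] -- divide by zero
r_inv[np.isinf(r_inv)] = 0.                 # the fix: zero out inf entries
print(sp.diags(r_inv).dot(mx).toarray())    # [[0.5 0.5]
                                            #  [0.  0. ]]
```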
