[NN] Add TAGCN nn.module and example (dmlc#788)
* upd

* fix edgebatch edges

* add test

* trigger

* Update README.md for pytorch PinSage example.

Add a note that the PinSage model example under
example/pytorch/recommendation only works with Python 3.6+,
as its dataset loader depends on the stanfordnlp package,
which works only with Python 3.6+.

* Provide a frame-agnostic API to test nn modules on both the CPU and CUDA sides.

1. make dgl.nn.xxx frame-agnostic
2. make tests.backend include dgl.nn modules
3. modify test_edge_softmax in tests/mxnet/test_nn.py and
    tests/pytorch/test_nn.py to work on both CPU and GPU

* Fix style

* Delete unused code

* Make agnostic test only related to tests/backend

1. clear all agnostic related code in dgl.nn
2. make test_graph_conv agnostic to cpu/gpu

* Fix code style

* fix

* doc

* Make all test code under tests/mxnet/test_nn.py and
tests/pytorch/test_nn.py work on both CPU and GPU.

* Fix syntax

* Remove rand

* Add TAGCN nn.module and example

* Now tagcn can run on CPU.

* Add unit test for TGConv

* Fix style

* For the pubmed dataset, --lr=0.005 achieves better accuracy

* Fix style

* Fix some descriptions

* trigger

* Fix doc
classicsong authored and yzh119 committed Aug 25, 2019
1 parent 708765f commit 11fb217
Showing 5 changed files with 332 additions and 1 deletion.
24 changes: 24 additions & 0 deletions examples/pytorch/tagcn/README.md
@@ -0,0 +1,24 @@
Topology Adaptive Graph Convolutional networks (TAGCN)
============

- Paper link: [https://arxiv.org/abs/1710.10370](https://arxiv.org/abs/1710.10370)

Dependencies
------------
- PyTorch 0.4.1+
- requests

```bash
pip install torch requests
```

Results
-------
Run with the following (available datasets: "cora", "citeseer", "pubmed"):
```bash
python3 train.py --dataset cora --gpu 0 --self-loop
```

* cora: ~0.812 (0.804-0.823) (paper: 0.833)
* citeseer: ~0.715 (paper: 0.714)
* pubmed: ~0.794 (paper: 0.811)
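
Per the commit notes, the pubmed dataset benefits from a lower learning rate:

```bash
python3 train.py --dataset pubmed --gpu 0 --self-loop --lr 0.005
```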
39 changes: 39 additions & 0 deletions examples/pytorch/tagcn/tagcn.py
@@ -0,0 +1,39 @@
"""GCN using DGL nn package
References:
- Semi-Supervised Classification with Graph Convolutional Networks
- Paper: https://arxiv.org/abs/1609.02907
- Code: https://github.com/tkipf/gcn
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch.conv import TGConv

class TAGCN(nn.Module):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(TAGCN, self).__init__()
self.g = g
self.layers = nn.ModuleList()
# input layer
self.layers.append(TGConv(in_feats, n_hidden, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.append(TGConv(n_hidden, n_hidden, activation=activation))
# output layer
self.layers.append(TGConv(n_hidden, n_classes)) #activation=None
self.dropout = nn.Dropout(p=dropout)

def forward(self, features):
h = features
for i, layer in enumerate(self.layers):
if i != 0:
h = self.dropout(h)
h = layer(h, self.g)
return h
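
For reference, a minimal sketch of how this module is wired up, mirroring the training script below; the toy graph, feature sizes, and hyperparameters are illustrative only:

```python
import networkx as nx
import torch
import torch.nn.functional as F
from dgl import DGLGraph

from tagcn import TAGCN

# Toy 3-node path graph; DGLGraph adds both edge directions.
g = DGLGraph(nx.path_graph(3))
features = torch.ones(3, 5)

model = TAGCN(g, in_feats=5, n_hidden=16, n_classes=2,
              n_layers=1, activation=F.relu, dropout=0.5)
logits = model(features)  # shape: (3, 2)
```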
131 changes: 131 additions & 0 deletions examples/pytorch/tagcn/train.py
@@ -0,0 +1,131 @@
import argparse, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

from tagcn import TAGCN

def evaluate(model, features, labels, mask):
model.eval()
with torch.no_grad():
logits = model(features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)

def main(args):
# load and preprocess dataset
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().item(),
val_mask.sum().item(),
test_mask.sum().item()))

if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
train_mask = train_mask.cuda()
val_mask = val_mask.cuda()
test_mask = test_mask.cuda()

# graph preprocess and calculate normalization factor
g = data.graph
# add self loop
if args.self_loop:
g.remove_edges_from(g.selfloop_edges())
g.add_edges_from(zip(g.nodes(), g.nodes()))
g = DGLGraph(g)
n_edges = g.number_of_edges()

# create TAGCN model
model = TAGCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout)

if cuda:
model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss()

# use optimizer
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)

# initialize graph
dur = []
for epoch in range(args.n_epochs):
model.train()
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
loss = loss_fcn(logits[train_mask], labels[train_mask])

optimizer.zero_grad()
loss.backward()
optimizer.step()

if epoch >= 3:
dur.append(time.time() - t0)

acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
acc, n_edges / np.mean(dur) / 1000))

print()
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='TAGCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden tagcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden tagcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.set_defaults(self_loop=False)
args = parser.parse_args()
print(args)

main(args)
91 changes: 90 additions & 1 deletion python/dgl/nn/pytorch/conv.py
@@ -7,7 +7,7 @@
from . import utils
from ... import function as fn

__all__ = ['GraphConv', 'RelGraphConv']
__all__ = ['GraphConv', 'TGConv', 'RelGraphConv']

class GraphConv(nn.Module):
r"""Apply graph convolution over an input signal.
@@ -150,6 +150,95 @@ def extra_repr(self):
summary += ', activation={_activation}'
return summary.format(**self.__dict__)

class TGConv(nn.Module):
r"""Apply Topology Adaptive Graph Convolutional Network
.. math::
\mathbf{X}^{\prime} = \sum_{k=0}^K \mathbf{D}^{-1/2} \mathbf{A}
\mathbf{D}^{-1/2}\mathbf{X} \mathbf{\Theta}_{k},
where :math:`\mathbf{A}` denotes the adjacency matrix and
:math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.
Parameters
----------
in_feats : int
Number of input features.
out_feats : int
Number of output features.
k : int, optional
Number of hops :math:`K`. (default: 2)
bias: bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
activation: callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
Attributes
----------
lin : torch.nn.Linear
The learnable linear module.
"""
def __init__(self,
in_feats,
out_feats,
k=2,
bias=True,
activation=None):
super(TGConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self._k = k
self._activation = activation
self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)

self.reset_parameters()

def reset_parameters(self):
"""Reinitialize learnable parameters."""
self.lin.reset_parameters()

def forward(self, feat, graph):
r"""Compute graph convolution
Parameters
----------
feat : torch.Tensor
The input feature
graph : DGLGraph
The graph.
Returns
-------
torch.Tensor
The output feature
"""
graph = graph.local_var()

norm = th.pow(graph.in_degrees().float(), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)

# D^-1/2 A D^-1/2 X
fstack = [feat]
for _ in range(self._k):
rst = fstack[-1] * norm
graph.ndata['h'] = rst

graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.ndata['h']
rst = rst * norm
fstack.append(rst)

rst = self.lin(th.cat(fstack, dim=-1))

if self._activation is not None:
rst = self._activation(rst)

return rst

class RelGraphConv(nn.Module):
r"""Relational graph convolution layer.
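
For intuition, here is a dense-matrix sketch of what the propagation loop in `TGConv.forward` computes: it stacks X, S·X, …, S^k·X with S = D^-1/2 A D^-1/2 and feeds the concatenation to the linear layer. This is an illustrative equivalent, not part of DGL; `dense_tagcn_propagate` is a hypothetical helper, and it assumes a symmetric adjacency matrix, matching the in-degree normalization above.

```python
import torch as th

def dense_tagcn_propagate(A, X, k=2):
    """Illustrative dense equivalent of TGConv's message passing (not part of DGL)."""
    # Symmetric normalization S = D^-1/2 A D^-1/2; clamp avoids division by zero.
    norm = th.pow(A.sum(dim=1).clamp(min=1), -0.5)
    S = norm.unsqueeze(1) * A * norm.unsqueeze(0)
    fstack = [X]
    for _ in range(k):
        fstack.append(S @ fstack[-1])  # one more propagation hop
    # TGConv feeds this concatenation to nn.Linear(in_feats * (k + 1), out_feats).
    return th.cat(fstack, dim=-1)
```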
48 changes: 48 additions & 0 deletions tests/pytorch/test_nn.py
@@ -69,6 +69,54 @@ def test_graph_conv():
new_weight = conv.weight.data
assert not F.allclose(old_weight, new_weight)

def _S2AXWb(A, N, X, W, b):
X1 = X * N
X1 = th.matmul(A, X1.view(X1.shape[0], -1))
X1 = X1 * N
X2 = X1 * N
X2 = th.matmul(A, X2.view(X2.shape[0], -1))
X2 = X2 * N
X = th.cat([X, X1, X2], dim=-1)
Y = th.matmul(X, W.t())  # nn.Linear computes X @ W^T + b

return Y + b

def test_tgconv():
g = dgl.DGLGraph(nx.path_graph(3))
ctx = F.ctx()
adj = g.adjacency_matrix(ctx=ctx)
norm = th.pow(g.in_degrees().float(), -0.5)

conv = nn.TGConv(5, 2, bias=True)
if F.gpu_ctx():
conv.cuda()
print(conv)

# test#1: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0
shp = norm.shape + (1,) * (h0.dim() - 1)
norm = th.reshape(norm, shp).to(ctx)

assert F.allclose(h1, _S2AXWb(adj, norm, h0, conv.lin.weight, conv.lin.bias))

conv = nn.TGConv(5, 2)
if F.gpu_ctx():
conv.cuda()
# test#2: basic
h0 = F.ones((3, 5))
h1 = conv(h0, g)
assert len(g.ndata) == 0
assert len(g.edata) == 0

# test reset_parameters
old_weight = deepcopy(conv.lin.weight.data)
conv.reset_parameters()
new_weight = conv.lin.weight.data
assert not F.allclose(old_weight, new_weight)

def test_set2set():
g = dgl.DGLGraph(nx.path_graph(10))

