[Model] Update GAT model code (dmlc#622)
* fix gat code to use latest edge softmax module

* avoid transpose

* update README

* use edge_softmax op

* mxnet edge softmax op

* mxnet gat

* update README

* fix unittest

* fix ci

* fix mxnet nn test; relax criteria for prod reducer
lingfanyu authored and jermainewang committed Jun 9, 2019
1 parent 993fd3f commit 74e13ee
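
The gist of the change: the GAT code now calls DGL's functional `edge_softmax` op directly, as `gat.py` below does. A minimal sketch of the new call pattern (the toy graph and scores here are made up for illustration; the real usage is in the file below):

```python
# Toy sketch of the edge_softmax op this commit switches to.
# The graph and scores are made up; gat.py below shows the real usage.
import mxnet.ndarray as nd
import dgl
from dgl.nn.mxnet import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [2, 2, 2])         # three edges, all into node 2
scores = nd.array([[1.0], [2.0], [3.0]])  # one unnormalized score per edge
alpha = edge_softmax(g, scores)           # normalized over each node's in-edges
print(alpha.sum())                        # scores of edges into node 2 sum to 1
```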
Showing 16 changed files with 483 additions and 362 deletions.
README.md (21 changes: 11 additions & 10 deletions)
@@ -13,16 +13,17 @@ maintaining high computation efficiency.

A summary of model accuracy and training speed with the PyTorch backend on an Amazon EC2 p3.2xlarge instance (with a V100 GPU), compared against the best open-source implementations:

| Model | Reported <br> Accuracy | DGL <br> Accuracy | Author's training speed (epoch time) | DGL speed (epoch time) | Improvement |
| ----- | ----------------- | ------------ | ------------------------------------ | ---------------------- | ----------- |
| [GCN](https://arxiv.org/abs/1609.02907) | 81.5% | 81.0% | [0.0051s (TF)](https://github.com/tkipf/gcn) | 0.0038s | 1.34x |
| [GAT](https://arxiv.org/abs/1710.10903) | 83.0% | 83.9% | [0.0982s (TF)](https://github.com/PetarV-/GAT) | 0.0076s | 12.9x |
| [SGC](https://arxiv.org/abs/1902.07153) | 81.0% | 81.9% | n/a | 0.0008s | n/a |
| [TreeLSTM](http://arxiv.org/abs/1503.00075) | 51.0% | 51.72% | [14.02s (DyNet)](https://github.com/clab/dynet/tree/master/examples/treelstm) | 3.18s | 4.3x |
| [R-GCN <br> (classification)](https://arxiv.org/abs/1703.06103) | 73.23% | 73.53% | [0.2853s (Theano)](https://github.com/tkipf/relational-gcn) | 0.0097s | 29.4x |
| [R-GCN <br> (link prediction)](https://arxiv.org/abs/1703.06103) | 0.158 | 0.151 | [2.204s (TF)](https://github.com/MichSchli/RelationPrediction) | 0.453s | 4.86x |
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |

With the MXNet/Gluon backend, we scaled a graph of 50M nodes and 150M edges on a P3.8xlarge instance,
with 160s per epoch, on SSE ([Stochastic Steady-state Embedding](https://www.cc.gatech.edu/~hdai8/pdf/equilibrium_embedding.pdf)),
examples/mxnet/gat/README.md (2 changes: 1 addition & 1 deletion)
@@ -19,5 +19,5 @@ pip install requests

### Usage (make sure DGLBACKEND is set to mxnet)
```bash
DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --num-heads 8
```
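
Since the backend is resolved when `dgl` is first imported, a quick sanity check can confirm the setting took effect (an illustrative snippet, not part of the example; `dgl.backend.backend_name` is assumed to be exposed, as in DGL of this era):

```python
# Sanity-check the active DGL backend (illustrative snippet;
# dgl.backend.backend_name is assumed to be available).
import os
os.environ.setdefault('DGLBACKEND', 'mxnet')  # must be set before importing dgl
import dgl
import dgl.backend as F
print(F.backend_name)  # expect: mxnet
```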
examples/mxnet/gat/gat.py (135 changes: 135 additions & 0 deletions)
@@ -0,0 +1,135 @@
"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""

import mxnet as mx
from mxnet import gluon
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn
import dgl.function as fn
from dgl.nn.mxnet import edge_softmax


class GraphAttention(gluon.Block):
    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual=False):
        super(GraphAttention, self).__init__()
        self.g = g
        self.num_heads = num_heads
        self.fc = nn.Dense(num_heads * out_dim, use_bias=False,
                           weight_initializer=mx.init.Xavier())
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x: x
        if attn_drop:
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            self.attn_drop = lambda x: x
        self.attn_l = self.params.get("left_att", grad_req="add",
                                      shape=(1, num_heads, out_dim),
                                      init=mx.init.Xavier())
        self.attn_r = self.params.get("right_att", grad_req="add",
                                      shape=(1, num_heads, out_dim),
                                      init=mx.init.Xavier())
        self.alpha = alpha
        self.softmax = edge_softmax
        self.residual = residual
        if residual:
            if in_dim != out_dim:
                self.res_fc = nn.Dense(num_heads * out_dim, use_bias=False,
                                       weight_initializer=mx.init.Xavier())
            else:
                self.res_fc = None

    def forward(self, inputs):
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
        a1 = (ft * self.attn_l.data(ft.context)).sum(axis=-1).expand_dims(-1)  # N x H x 1
        a2 = (ft * self.attn_r.data(ft.context)).sum(axis=-1).expand_dims(-1)  # N x H x 1
        self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax
        self.edge_softmax()
        # 3. compute the aggregated node features
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
                          fn.sum('ft', 'ft'))
        ret = self.g.ndata['ft']
        # 4. residual
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape(
                    (h.shape[0], self.num_heads, -1))  # NxHxD'
            else:
                resval = nd.expand_dims(h, axis=1)  # Nx1xD'
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # an edge UDF to compute unnormalized attention values from src and dst
        a = nd.LeakyReLU(edges.src['a1'] + edges.dst['a2'], slope=self.alpha)
        return {'a': a}

    def edge_softmax(self):
        attention = self.softmax(self.g, self.g.edata.pop('a'))
        # Dropout attention scores and save them
        self.g.edata['a_drop'] = self.attn_drop(attention)

class GAT(nn.Block):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = []
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GraphAttention(
            g, in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, alpha, False))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GraphAttention(
                g, num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, alpha, residual))
        # output projection
        self.gat_layers.append(GraphAttention(
            g, num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, alpha, residual))
        for i, layer in enumerate(self.gat_layers):
            self.register_child(layer, "gat_layer_{}".format(i))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gat_layers[l](h).flatten()
            h = self.activation(h)
        # output projection
        logits = self.gat_layers[-1](h).mean(1)
        return logits
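
For reference, a minimal sketch of how the `GAT` class above could be driven end to end. The hyper-parameters below are illustrative, and the data loading uses the `dgl.data.citation_graph` API of this era; the commit's actual entry point is `train.py`:

```python
# Minimal usage sketch for the GAT class above (hyper-parameters are
# illustrative; the commit's real training script is train.py).
import mxnet as mx
from mxnet import nd
import dgl
from dgl.data import citation_graph

data = citation_graph.load_cora()
g = dgl.DGLGraph(data.graph)
g.add_edges(g.nodes(), g.nodes())    # add self-loops, as GAT expects

features = nd.array(data.features)
elu = lambda x: nd.LeakyReLU(x, act_type='elu')

model = GAT(g,
            num_layers=1,            # one hidden GAT layer
            in_dim=features.shape[1],
            num_hidden=8,
            num_classes=data.num_labels,
            heads=[8, 1],            # 8 hidden heads, 1 output head
            activation=elu,
            feat_drop=0.6,
            attn_drop=0.6,
            alpha=0.2,               # negative slope of the LeakyReLU
            residual=False)
model.initialize(ctx=mx.cpu())
logits = model(features)             # shape: (num_nodes, num_classes)
```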
