Skip to content

Commit

Permalink
[MXNet][API] move to the new API (dmlc#123)
Browse files Browse the repository at this point in the history
* move gat to the new api.

* fix gcn.

* update sse.

* fix dgl core.

* update sse.

* fix small bugs in dgl core.

* fix mxnet tests.

* retrigger

* address comments and fix more bugs.

* fix

* fix tests.
  • Loading branch information
zheng-da authored Nov 8, 2018
1 parent 1eb17bb commit bd0e4fa
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 120 deletions.
20 changes: 10 additions & 10 deletions examples/mxnet/gat/gat_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu')

def gat_message(src, edge):
return {'ft' : src['ft'], 'a2' : src['a2']}
def gat_message(edges):
return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}

class GATReduce(gluon.Block):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop

def forward(self, node, msgs):
a1 = mx.nd.expand_dims(node['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D)
def forward(self, nodes):
a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a))
Expand All @@ -48,13 +48,13 @@ def __init__(self, headid, indim, hiddendim, activation, residual):
if indim != hiddendim:
self.residual_fc = gluon.nn.Dense(hiddendim)

def forward(self, node):
ret = node['accum']
def forward(self, nodes):
ret = nodes.data['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
ret = self.residual_fc(nodes.data['h']) + ret
else:
ret = node['h'] + ret
ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}

class GATPrepare(gluon.Block):
Expand Down
18 changes: 9 additions & 9 deletions examples/mxnet/gcn/gcn_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(src, edge):
return src
def gcn_msg(edge):
return {'m': edge.src['h']}

def gcn_reduce(node, msgs):
return mx.nd.sum(msgs, 1)
def gcn_reduce(node):
return {'accum': mx.nd.sum(node.mailbox['m'], 1)}

class NodeUpdateModule(gluon.Block):
def __init__(self, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = gluon.nn.Dense(out_feats, activation=activation)

def forward(self, node):
return self.linear(node)
return {'h': self.linear(node.data['accum'])}

class GCN(gluon.Block):
def __init__(self,
Expand All @@ -50,14 +50,14 @@ def __init__(self,
self.layers.add(NodeUpdateModule(n_classes))

def forward(self, features):
self.g.set_n_repr(features)
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
val = F.dropout(self.g.ndata['h'], p=self.dropout)
self.g.ndata['h'] = val
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr()
return self.g.ndata.pop('h')

def main(args):
# load and preprocess dataset
Expand Down
Loading

0 comments on commit bd0e4fa

Please sign in to comment.