[Refactor] Interface of nn modules (dmlc#798)
* refactor

* upd mpnn
yzh119 authored Aug 27, 2019
1 parent 650f6ee commit 9314aab
Showing 18 changed files with 194 additions and 184 deletions.
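Every hunk below makes the same interface change: DGL nn modules and readout layers now take the graph as the first positional argument and the node features as the second, rather than the other way around. As a quick orientation, here is a minimal, hypothetical sketch of the new calling convention using the PyTorch GraphConv layer on a toy graph (the sizes and the graph are illustrative only, assuming a DGL build from around this commit):

import torch
import dgl
from dgl.nn.pytorch import GraphConv

# Toy 4-node cycle; any DGLGraph works the same way.
g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

feat = torch.randn(4, 5)   # one 5-dim feature vector per node
conv = GraphConv(5, 2)

# Before this commit: h = conv(feat, g)
h = conv(g, feat)          # after: graph first, then features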
2 changes: 1 addition & 1 deletion examples/mxnet/gcn/gcn.py
@@ -36,5 +36,5 @@ def forward(self, features):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
2 changes: 1 addition & 1 deletion examples/pytorch/appnp/appnp.py
@@ -53,5 +53,5 @@ def forward(self, features):
             h = self.activation(layer(h))
         h = self.layers[-1](self.feat_drop(h))
         # propagation step
-        h = self.propagate(h, self.g)
+        h = self.propagate(self.g, h)
         return h
4 changes: 2 additions & 2 deletions examples/pytorch/gat/gat.py
@@ -49,7 +49,7 @@ def __init__(self,
     def forward(self, inputs):
         h = inputs
         for l in range(self.num_layers):
-            h = self.gat_layers[l](h, self.g).flatten(1)
+            h = self.gat_layers[l](self.g, h).flatten(1)
         # output projection
-        logits = self.gat_layers[-1](h, self.g).mean(1)
+        logits = self.gat_layers[-1](self.g, h).mean(1)
         return logits
2 changes: 1 addition & 1 deletion examples/pytorch/gcn/gcn.py
@@ -35,5 +35,5 @@ def forward(self, features):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
2 changes: 1 addition & 1 deletion examples/pytorch/gin/gin.py
@@ -155,7 +155,7 @@ def forward(self, g):
         hidden_rep = [h]

         for layer in range(self.num_layers - 1):
-            h = self.ginlayers[layer](h, g)
+            h = self.ginlayers[layer](g, h)
             hidden_rep.append(h)

         score_over_layer = 0
2 changes: 1 addition & 1 deletion examples/pytorch/graphsage/graphsage.py
@@ -41,7 +41,7 @@ def __init__(self,
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h

2 changes: 1 addition & 1 deletion examples/pytorch/model_zoo/citation_network/conf.py
@@ -50,7 +50,7 @@
 }

 CHEBNET_CONFIG = {
-    'extra_args': [16, 1, 3, True],
+    'extra_args': [32, 1, 2, True],
     'lr': 1e-2,
     'weight_decay': 5e-4,
 }
20 changes: 10 additions & 10 deletions examples/pytorch/model_zoo/citation_network/models.py
@@ -30,7 +30,7 @@ def forward(self, features):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h


@@ -70,9 +70,9 @@ def __init__(self,
     def forward(self, inputs):
         h = inputs
         for l in range(self.num_layers):
-            h = self.gat_layers[l](h, self.g).flatten(1)
+            h = self.gat_layers[l](self.g, h).flatten(1)
         # output projection
-        logits = self.gat_layers[-1](h, self.g).mean(1)
+        logits = self.gat_layers[-1](self.g, h).mean(1)
         return logits


@@ -101,7 +101,7 @@ def __init__(self,
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h


@@ -148,7 +148,7 @@ def forward(self, features):
             h = self.activation(layer(h))
         h = self.layers[-1](self.feat_drop(h))
         # propagation step
-        h = self.propagate(h, self.g)
+        h = self.propagate(self.g, h)
         return h


@@ -178,7 +178,7 @@ def forward(self, features):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h


@@ -210,7 +210,7 @@ def __init__(self,
     def forward(self, features):
         h = self.proj(features)
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return self.cls(h)


@@ -231,7 +231,7 @@ def __init__(self,
                            bias=bias)

     def forward(self, features):
-        return self.net(features, self.g)
+        return self.net(self.g, features)


class GIN(nn.Module):
@@ -286,7 +286,7 @@ def __init__(self,
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h

class ChebNet(nn.Module):
@@ -316,5 +316,5 @@ def __init__(self,
     def forward(self, features):
         h = features
         for layer in self.layers:
-            h = layer(h, self.g)
+            h = layer(self.g, h, [2])
         return h
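Note that besides swapping the argument order, the ChebNet forward now also passes an explicit lambda_max of [2] (one value per graph), which lines up with the CHEBNET_CONFIG change in conf.py above. A hedged sketch of the resulting call, assuming the layer is dgl.nn.pytorch.conv.ChebConv with a forward(graph, feat, lambda_max=None) signature (the toy graph and feature sizes are illustrative only):

import torch
import dgl
from dgl.nn.pytorch.conv import ChebConv

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

conv = ChebConv(5, 2, 2)              # in_feats=5, out_feats=2, k=2
h = conv(g, torch.randn(4, 5), [2])   # graph, features, lambda_max per graph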
4 changes: 2 additions & 2 deletions examples/pytorch/sgc/sgc.py
@@ -19,7 +19,7 @@
 def evaluate(model, g, features, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features, g)[mask] # only compute the evaluation set
+        logits = model(g, features)[mask] # only compute the evaluation set
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)
@@ -86,7 +86,7 @@ def main(args):
         if epoch >= 3:
             t0 = time.time()
         # forward
-        logits = model(features, g) # only compute the train set
+        logits = model(g, features) # only compute the train set
         loss = loss_fcn(logits[train_mask], labels[train_mask])

         optimizer.zero_grad()
4 changes: 2 additions & 2 deletions examples/pytorch/sgc/sgc_reddit.py
@@ -21,7 +21,7 @@ def normalize(h):
 def evaluate(model, features, graph, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features, graph)[mask] # only compute the evaluation set
+        logits = model(graph, features)[mask] # only compute the evaluation set
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
         correct = torch.sum(indices == labels)
@@ -82,7 +82,7 @@ def main(args):
     # define loss closure
     def closure():
         optimizer.zero_grad()
-        output = model(features, g)[train_mask]
+        output = model(g, features)[train_mask]
         loss_train = F.cross_entropy(output, labels[train_mask])
         loss_train.backward()
         return loss_train
2 changes: 1 addition & 1 deletion examples/pytorch/tagcn/tagcn.py
@@ -35,5 +35,5 @@ def forward(self, features):
         for i, layer in enumerate(self.layers):
             if i != 0:
                 h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
         return h
2 changes: 1 addition & 1 deletion python/dgl/model_zoo/chem/mpnn.py
@@ -145,7 +145,7 @@ def forward(self, g):
             out, h = self.gru(m.unsqueeze(0), h)
             out = out.squeeze(0)

-        out = self.set2set(out, g)
+        out = self.set2set(g, out)
         out = F.relu(self.lin1(out))
         out = self.lin2(out)
         return out
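The readout modules follow the same rule: set2set(out, g) becomes set2set(g, out). A standalone sketch of the PyTorch Set2Set readout under the new convention (the 8-dim features and iteration counts are arbitrary choices for illustration):

import torch
import dgl
from dgl.nn.pytorch.glob import Set2Set

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

feat = torch.randn(3, 8)
s2s = Set2Set(8, n_iters=2, n_layers=1)   # output dim is 2 * input dim
out = s2s(g, feat)                        # was s2s(feat, g) before this commit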
6 changes: 3 additions & 3 deletions python/dgl/nn/mxnet/conv.py
@@ -83,7 +83,7 @@ def __init__(self,

         self._activation = activation

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute graph convolution.

         Notes
@@ -95,10 +95,10 @@ def forward(self, feat, graph):
         Parameters
         ----------
-        feat : mxnet.NDArray
-            The input feature
         graph : DGLGraph
             The graph.
+        feat : mxnet.NDArray
+            The input feature

         Returns
         -------
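A matching sketch for the MXNet GraphConv under the new signature (toy sizes; the explicit initialize() call is the usual Gluon requirement, not something this diff adds):

import mxnet as mx
import dgl
from dgl.nn.mxnet.conv import GraphConv

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

conv = GraphConv(5, 2)
conv.initialize()          # Gluon blocks need parameter initialization
feat = mx.nd.random.uniform(shape=(4, 5))
h = conv(g, feat)          # was conv(feat, g) before this commit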
36 changes: 18 additions & 18 deletions python/dgl/nn/mxnet/glob.py
@@ -19,16 +19,16 @@ class SumPooling(nn.Block):
     def __init__(self):
         super(SumPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sum pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
@@ -56,16 +56,16 @@ class AvgPooling(nn.Block):
     def __init__(self):
         super(AvgPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute average pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
@@ -93,16 +93,16 @@ class MaxPooling(nn.Block):
     def __init__(self):
         super(MaxPooling, self).__init__()

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute max pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, *)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
@@ -134,16 +134,16 @@ def __init__(self, k):
         super(SortPooling, self).__init__()
         self.k = k

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute sort pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
@@ -190,16 +190,16 @@ def __init__(self, gate_nn, feat_nn=None):
         self.gate_nn = gate_nn
         self.feat_nn = feat_nn

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute global attention pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
@@ -258,16 +258,16 @@ def __init__(self, input_dim, n_iters, n_layers):
         self.lstm = gluon.rnn.LSTM(
             self.input_dim, num_layers=n_layers, input_size=self.output_dim)

-    def forward(self, feat, graph):
+    def forward(self, graph, feat):
         r"""Compute set2set pooling.

         Parameters
         ----------
+        graph : DGLGraph or BatchedDGLGraph
+            The graph.
         feat : mxnet.NDArray
             The input feature with shape :math:`(N, D)` where
             :math:`N` is the number of nodes in the graph.
-        graph : DGLGraph or BatchedDGLGraph
-            The graph.

         Returns
         -------
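The six MXNet pooling modules above all change identically, so one example stands in for all of them. A hedged sketch of SumPooling after the reorder (toy graph and feature sizes, assuming a DGL build from around this commit):

import mxnet as mx
import dgl
from dgl.nn.mxnet.glob import SumPooling

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

feat = mx.nd.random.uniform(shape=(3, 4))
pool = SumPooling()
out = pool(g, feat)   # was pool(feat, g); sums features over all nodes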
(Diffs for the remaining 4 of the 18 changed files did not load and are not shown.)
