[NN] JumpingKnowledge (dmlc#3512)
* Update

* Fix
mufeili authored Nov 19, 2021
1 parent 3aef467 commit 9e7fbf9
Showing 5 changed files with 159 additions and 23 deletions.
7 changes: 7 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
@@ -310,6 +310,13 @@ SegmentedKNNGraph
:members:
:show-inheritance:

JumpingKnowledge
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.utils.JumpingKnowledge
:members: forward, reset_parameters
:show-inheritance:

NodeEmbedding Module
----------------------------------------

32 changes: 10 additions & 22 deletions examples/pytorch/jknet/model.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch.conv import GraphConv
from dgl.nn import GraphConv, JumpingKnowledge

class JKNet(nn.Module):
def __init__(self,
@@ -13,19 +13,21 @@ def __init__(self,
mode='cat',
dropout=0.):
super(JKNet, self).__init__()

self.mode = mode
self.dropout = nn.Dropout(dropout)
self.layers = nn.ModuleList()
self.layers.append(GraphConv(in_dim, hid_dim, activation=F.relu))
for _ in range(num_layers):
self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))

if self.mode == 'lstm':
self.jump = JumpingKnowledge(mode, hid_dim, num_layers)
else:
self.jump = JumpingKnowledge(mode)

if self.mode == 'cat':
hid_dim = hid_dim * (num_layers + 1)
elif self.mode == 'lstm':
self.lstm = nn.LSTM(hid_dim, (num_layers * hid_dim) // 2, bidirectional=True, batch_first=True)
self.attn = nn.Linear(2 * ((num_layers * hid_dim) // 2), 1)

self.output = nn.Linear(hid_dim, out_dim)
self.reset_params()
@@ -34,29 +36,15 @@ def reset_params(self):
self.output.reset_parameters()
for layers in self.layers:
layers.reset_parameters()
if self.mode == 'lstm':
self.lstm.reset_parameters()
self.attn.reset_parameters()
self.jump.reset_parameters()

def forward(self, g, feats):
feat_lst = []
for layer in self.layers:
feats = self.dropout(layer(g, feats))
feat_lst.append(feats)

if self.mode == 'cat':
out = torch.cat(feat_lst, dim=-1)
elif self.mode == 'max':
out = torch.stack(feat_lst, dim=-1).max(dim=-1)[0]
else:
# lstm
x = torch.stack(feat_lst, dim=1)
alpha, _ = self.lstm(x)
alpha = self.attn(alpha).squeeze(-1)
alpha = torch.softmax(alpha, dim=-1).unsqueeze(-1)
out = (x * alpha).sum(dim=1)

g.ndata['h'] = out

g.ndata['h'] = self.jump(feat_lst)
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))

return self.output(g.ndata['h'])
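
For context, here is a minimal smoke test of the refactored JKNet (an illustrative sketch, not part of the commit; the constructor arguments in_dim, hid_dim, out_dim, and num_layers are taken from the method bodies above, while the graph, feature sizes, and import path are assumptions):

import dgl
import torch
from model import JKNet  # hypothetical import of the example module above

# Small random graph; self-loops avoid GraphConv's zero-in-degree error.
g = dgl.add_self_loop(dgl.rand_graph(10, 30))
feats = torch.randn(10, 16)

# 1 input layer + num_layers hidden layers, aggregated with LSTM attention.
model = JKNet(in_dim=16, hid_dim=32, out_dim=7, num_layers=3, mode='lstm')
logits = model(g, feats)
print(logits.shape)  # expected: torch.Size([10, 7])
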
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/__init__.py
@@ -5,5 +5,5 @@
from .softmax import *
from .factory import *
from .hetero import *
from .utils import Sequential, WeightBasis
from .utils import Sequential, WeightBasis, JumpingKnowledge
from .sparse_emb import NodeEmbedding
121 changes: 121 additions & 0 deletions python/dgl/nn/pytorch/utils.py
@@ -282,3 +282,124 @@ def forward(self):
# generate all weights from bases
weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
return weight.view(self.num_outputs, *self.shape)

class JumpingKnowledge(nn.Module):
r"""
Description
-----------
The Jumping Knowledge aggregation module introduced in `Representation Learning on
Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It
aggregates the output representations of multiple GNN layers with

**concatenation**

.. math::

    h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}

or **max pooling**

.. math::

    \max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)

or **LSTM**

.. math::

    \sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}

with attention scores :math:`\alpha_i^{(t)}` obtained from a BiLSTM.

Parameters
----------
mode : str
The aggregation to apply. It can be 'cat', 'max', or 'lstm',
corresponding to the equations above in order.
in_feats : int, optional
This argument is only required if :attr:`mode` is ``'lstm'``.
The output representation size of a single GNN layer. Note that
all GNN layers need to have the same output representation size.
num_layers : int, optional
This argument is only required if :attr:`mode` is ``'lstm'``.
The number of GNN layers for output aggregation.

Examples
--------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import JumpingKnowledge
>>> # Output representations of two GNN layers
>>> num_nodes = 3
>>> in_feats = 4
>>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]
>>> # Case 1: default mode ('cat')
>>> model = JumpingKnowledge()
>>> model(feat_list).shape
torch.Size([3, 8])
>>> # Case 2: mode='max'
>>> model = JumpingKnowledge(mode='max')
>>> model(feat_list).shape
torch.Size([3, 4])
>>> # Case 3: mode='lstm'
>>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
>>> model(feat_list).shape
torch.Size([3, 4])
"""
def __init__(self, mode='cat', in_feats=None, num_layers=None):
super(JumpingKnowledge, self).__init__()
assert mode in ['cat', 'max', 'lstm'], \
"Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
self.mode = mode

if mode == 'lstm':
assert in_feats is not None, 'in_feats is required for lstm mode'
assert num_layers is not None, 'num_layers is required for lstm mode'
hidden_size = (num_layers * in_feats) // 2
self.lstm = nn.LSTM(in_feats, hidden_size, bidirectional=True, batch_first=True)
self.att = nn.Linear(2 * hidden_size, 1)

def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters. This comes into effect only for the lstm mode.
"""
if self.mode == 'lstm':
self.lstm.reset_parameters()
self.att.reset_parameters()

def forward(self, feat_list):
r"""
Description
-----------
Aggregate output representations across multiple GNN layers.

Parameters
----------
feat_list : list[Tensor]
    feat_list[i] is the output representations of the i-th GNN layer.

Returns
-------
Tensor
    The aggregated representations.
"""
if self.mode == 'cat':
return th.cat(feat_list, dim=-1)
elif self.mode == 'max':
return th.stack(feat_list, dim=-1).max(dim=-1)[0]
else:
# LSTM
stacked_feat_list = th.stack(feat_list, dim=1) # (N, num_layers, in_feats)
alpha, _ = self.lstm(stacked_feat_list)
alpha = self.att(alpha).squeeze(-1) # (N, num_layers)
alpha = th.softmax(alpha, dim=-1)
return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)
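
As a quick sanity check of the lstm branch above, here is a hypothetical shape walkthrough using the sizes from the docstring example (in_feats=4, num_layers=2, N nodes, T = len(feat_list) layers of outputs):

# hidden_size = (num_layers * in_feats) // 2 = (2 * 4) // 2 = 4
# stacked_feat_list: (N, T, 4)
# lstm output:       (N, T, 2 * hidden_size) = (N, T, 8)   (bidirectional)
# att + squeeze(-1): (N, T) attention logits, softmax over the T layers
# weighted sum:      (N, 4), i.e. one in_feats-sized vector per node
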
20 changes: 20 additions & 0 deletions tests/pytorch/test_nn.py
@@ -1229,6 +1229,26 @@ def forward(self, graph, feat, eweight=None):
explainer = nn.GNNExplainer(model, num_hops=1)
feat_mask, edge_mask = explainer.explain_graph(g, feat)

def test_jumping_knowledge():
ctx = F.ctx()
num_layers = 2
num_nodes = 3
num_feats = 4

feat_list = [th.randn((num_nodes, num_feats)).to(ctx) for _ in range(num_layers)]

model = nn.JumpingKnowledge('cat').to(ctx)
model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_layers * num_feats)

model = nn.JumpingKnowledge('max').to(ctx)
model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_feats)

model = nn.JumpingKnowledge('lstm', num_feats, num_layers).to(ctx)
model.reset_parameters()
assert model(feat_list).shape == (num_nodes, num_feats)
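
The new test can presumably be run in isolation with pytest, assuming the repository's standard test setup:

python -m pytest tests/pytorch/test_nn.py::test_jumping_knowledge
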

if __name__ == '__main__':
test_graph_conv()
test_graph_conv_e_weight()
