[NN] Grouped reversible residual connections for GNNs (dmlc#3842)
* Update

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
mufeili authored Mar 23, 2022
1 parent a3fd059 commit 8005978
Showing 4 changed files with 254 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/nn-pytorch.rst
@@ -35,6 +35,7 @@ Conv Layers
    ~dgl.nn.pytorch.conv.TWIRLSUnfoldingAndAttention
    ~dgl.nn.pytorch.conv.GCN2Conv
    ~dgl.nn.pytorch.conv.HGTConv
    ~dgl.nn.pytorch.conv.GroupRevRes

Dense Conv Layers
----------------------------------------
3 changes: 2 additions & 1 deletion python/dgl/nn/pytorch/conv/__init__.py
@@ -26,9 +26,10 @@
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
from .hgtconv import HGTConv
from .grouprevres import GroupRevRes

__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
           'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
           'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
           'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
           'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv']
           'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv', 'GroupRevRes']
236 changes: 236 additions & 0 deletions python/dgl/nn/pytorch/conv/grouprevres.py
@@ -0,0 +1,236 @@
"""Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn

class InvertibleCheckpoint(torch.autograd.Function):
    r"""Extension of torch.autograd"""
    @staticmethod
    def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.weights = inputs_and_weights[num_inputs:]
        inputs = inputs_and_weights[:num_inputs]
        ctx.input_requires_grad = []

        with torch.no_grad():
            # Make a detached copy, which shares the storage
            x = []
            for element in inputs:
                if isinstance(element, torch.Tensor):
                    x.append(element.detach())
                    ctx.input_requires_grad.append(element.requires_grad)
                else:
                    x.append(element)
                    ctx.input_requires_grad.append(None)
            # Detach the output, which then allows discarding the intermediary results
            outputs = ctx.fn(*x).detach_()

        # clear memory of input node features
        inputs[1].storage().resize_(0)

        # store for backward pass
        ctx.inputs = [inputs]
        ctx.outputs = [outputs]

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), \
                               please use .backward() if possible")
        # retrieve input and output tensor nodes
        if len(ctx.outputs) == 0:
            raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint \
                               for more than once.")
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        # reconstruct input node features
        with torch.no_grad():
            # inputs[0] is DGLGraph and inputs[1] is input node features
            inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs)+inputs[2:]))
            # clear memory of outputs
            outputs.storage().resize_(0)

            x = inputs[1]
            x.storage().resize_(int(np.prod(x.size())))
            x.set_(inputs_inverted)

        # compute gradients
        with torch.set_grad_enabled(True):
            detached_inputs = []
            for i, element in enumerate(inputs):
                if isinstance(element, torch.Tensor):
                    element = element.detach()
                    element.requires_grad = ctx.input_requires_grad[i]
                detached_inputs.append(element)

            detached_inputs = tuple(detached_inputs)
            temp_output = ctx.fn(*detached_inputs)

        # Only tensor inputs can carry gradients; non-tensor inputs such as the
        # DGLGraph are skipped here.
        filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
                                                detached_inputs))
        gradients = torch.autograd.grad(outputs=(temp_output,),
                                        inputs=filtered_detached_inputs + ctx.weights,
                                        grad_outputs=grad_outputs)

        input_gradients = []
        i = 0
        for rg in ctx.input_requires_grad:
            if rg:
                input_gradients.append(gradients[i])
                i += 1
            else:
                input_gradients.append(None)

        gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]

        return (None, None, None) + gradients


class GroupRevRes(nn.Module):
    r"""Grouped reversible residual connections for GNNs, as introduced in
    `Training Graph Neural Networks with 1000 Layers <https://arxiv.org/abs/2106.07476>`__

    It uniformly partitions an input node feature :math:`X` into :math:`C` groups
    :math:`X_1, X_2, \cdots, X_C` across the channel dimension. It also makes
    :math:`C` copies of the input GNN module :math:`f_{w1}, \cdots, f_{wC}`. In the
    forward pass, each GNN module only takes the corresponding group of node features.

    The output node representations :math:`X^{'}` are computed as follows.

    .. math::

        X_0^{'} = \sum_{i=2}^{C}X_i

        X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}

        X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}

    where :math:`g` is the input graph, :math:`U` denotes arbitrary additional input
    arguments like edge features, and :math:`\, \Vert \,` is concatenation.

    Parameters
    ----------
    gnn_module : nn.Module
        GNN module for message passing. :attr:`GroupRevRes` will clone the module
        :attr:`groups` - 1 times, yielding :attr:`groups` copies in total.
        The input and output node representation sizes need to be the same. Its forward
        function needs to take a DGLGraph and the associated input node features in order,
        optionally followed by additional arguments like edge features.
    groups : int, optional
        The number of groups. Default: 2.

    Examples
    --------

    >>> import dgl
    >>> import torch
    >>> import torch.nn as nn
    >>> from dgl.nn import GraphConv, GroupRevRes

    >>> class GNNLayer(nn.Module):
    ...     def __init__(self, feats, dropout=0.2):
    ...         super(GNNLayer, self).__init__()
    ...         # Use BatchNorm and dropout to prevent gradient vanishing,
    ...         # in particular if you use a large number of GNN layers
    ...         self.norm = nn.BatchNorm1d(feats)
    ...         self.conv = GraphConv(feats, feats)
    ...         self.dropout = nn.Dropout(dropout)
    ...
    ...     def forward(self, g, x):
    ...         x = self.norm(x)
    ...         x = self.dropout(x)
    ...         return self.conv(g, x)

    >>> num_nodes = 5
    >>> num_edges = 20
    >>> feats = 32
    >>> groups = 2
    >>> g = dgl.rand_graph(num_nodes, num_edges)
    >>> x = torch.randn(num_nodes, feats)
    >>> conv = GNNLayer(feats // groups)
    >>> model = GroupRevRes(conv, groups)
    >>> out = model(g, x)
    """
    def __init__(self, gnn_module, groups=2):
        super(GroupRevRes, self).__init__()
        self.gnn_modules = nn.ModuleList()
        for i in range(groups):
            if i == 0:
                self.gnn_modules.append(gnn_module)
            else:
                self.gnn_modules.append(deepcopy(gnn_module))
        self.groups = groups

    def _forward(self, g, x, *args):
        # Split the node features (and any additional arguments) into `groups`
        # chunks along the channel dimension.
        xs = torch.chunk(x, self.groups, dim=-1)

        if len(args) == 0:
            args_chunks = [()] * self.groups
        else:
            chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
            args_chunks = list(zip(*chunked_args))
        y_in = sum(xs[1:])

        ys = []
        for i in range(self.groups):
            # Group-wise residual update: X_i' = X_i + f_{wi}(X_{i-1}', g, U)
            y_in = xs[i] + self.gnn_modules[i](g, y_in, *args_chunks[i])
            ys.append(y_in)

        out = torch.cat(ys, dim=-1)

        return out

    def _inverse(self, g, y, *args):
        ys = torch.chunk(y, self.groups, dim=-1)

        if len(args) == 0:
            args_chunks = [()] * self.groups
        else:
            chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
            args_chunks = list(zip(*chunked_args))

        xs = []
        for i in range(self.groups-1, -1, -1):
            if i != 0:
                y_in = ys[i-1]
            else:
                y_in = sum(xs)

            # Undo the residual update of group i: X_i = X_i' - f_{wi}(X_{i-1}', g, U)
            x = ys[i] - self.gnn_modules[i](g, y_in, *args_chunks[i])
            xs.append(x)

        x = torch.cat(xs[::-1], dim=-1)

        return x

    def forward(self, g, x, *args):
        r"""Apply the GNN module with grouped reversible residual connection.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        x : torch.Tensor
            The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}` is the
            size of the input feature and :math:`N` is the number of nodes.
        args
            Additional arguments to pass to :attr:`gnn_module`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{in})`.
        """
        args = (g, x) + args
        y = InvertibleCheckpoint.apply(
            self._forward,
            self._inverse,
            len(args),
            *(args + tuple([p for p in self.parameters() if p.requires_grad])))

        return y
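
The checkpoint above relies on `_inverse` exactly undoing `_forward` when the wrapped module is deterministic. The following is a minimal sanity check of that pair, given only as a sketch (it is not part of this commit) and assuming `GroupRevRes` is importable from `dgl.nn` as registered above:

# Sketch only: verify that _inverse reconstructs the input of _forward
# for a deterministic wrapped module.
import dgl
import torch
from dgl.nn import GraphConv, GroupRevRes

g = dgl.add_self_loop(dgl.rand_graph(5, 20))  # self-loops avoid 0-in-degree nodes
x = torch.randn(5, 32)
# With groups=2, each group sees 32 // 2 = 16 channels, so the module maps 16 -> 16.
model = GroupRevRes(GraphConv(16, 16), groups=2)
model.eval()  # only needed if the wrapped module has dropout/normalization

with torch.no_grad():
    y = model._forward(g, x)
    x_rec = model._inverse(g, y)
assert torch.allclose(x, x_rec, atol=1e-5)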
16 changes: 15 additions & 1 deletion tests/pytorch/test_nn.py
@@ -1308,7 +1308,7 @@ def test_hgt(idtype, in_size, num_heads):
    etype = th.tensor([i % num_etypes for i in range(g.num_edges())]).to(dev)
    ntype = th.tensor([i % num_ntypes for i in range(g.num_nodes())]).to(dev)
    x = th.randn(g.num_nodes(), in_size).to(dev)

    m = nn.HGTConv(in_size, head_size, num_heads, num_ntypes, num_etypes).to(dev)

    y = m(g, x, ntype, etype)
@@ -1329,3 +1329,17 @@ def test_hgt(idtype, in_size, num_heads):
    assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
    # TODO(minjie): enable the following check
    #assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)

@parametrize_dtype
def test_group_rev_res(idtype):
    dev = F.ctx()

    num_nodes = 5
    num_edges = 20
    feats = 32
    groups = 2
    g = dgl.rand_graph(num_nodes, num_edges).to(dev)
    h = th.randn(num_nodes, feats).to(dev)
    conv = nn.GraphConv(feats // groups, feats // groups)
    model = nn.GroupRevRes(conv, groups).to(dev)
    model(g, h)
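
A possible follow-up check (again a sketch, not part of this commit) would also exercise the custom backward pass of `InvertibleCheckpoint`, confirming that a scalar loss backpropagates to every parameter of the cloned modules. It reuses the objects defined in the test above with a fresh feature tensor:

    # Sketch of a possible extension: exercise the custom backward pass.
    # A fresh feature tensor is used because the forward pass releases the
    # storage of its input node features.
    h2 = th.randn(num_nodes, feats).to(dev)
    out = model(g, h2)
    assert out.shape == (num_nodes, feats)
    out.sum().backward()
    assert all(p.grad is not None for p in model.parameters())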
