[NN] Grouped reversible residual connections for GNNs (dmlc#3842)
Showing 4 changed files with 254 additions and 2 deletions.
@@ -0,0 +1,236 @@
"""Torch module for grouped reversible residual connections for GNNs""" | ||
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728 | ||
from copy import deepcopy | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
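# InvertibleCheckpoint adapts torch's checkpointing idea to invertible
# modules: instead of caching the inputs for backward, it frees them and later
# reconstructs them from the outputs with the module's inverse function, so
# activation memory stays constant in the number of layers.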
class InvertibleCheckpoint(torch.autograd.Function):
    r"""Extension of torch.autograd"""
    @staticmethod
    def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.weights = inputs_and_weights[num_inputs:]
        inputs = inputs_and_weights[:num_inputs]
        ctx.input_requires_grad = []

        with torch.no_grad():
            # Make a detached copy, which shares the storage
            x = []
            for element in inputs:
                if isinstance(element, torch.Tensor):
                    x.append(element.detach())
                    ctx.input_requires_grad.append(element.requires_grad)
                else:
                    x.append(element)
                    ctx.input_requires_grad.append(None)
            # Detach the output, which then allows discarding the intermediary results
            outputs = ctx.fn(*x).detach_()

        # clear memory of input node features
        inputs[1].storage().resize_(0)
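        # Freeing the storage is safe: the features are rebuilt from the
        # outputs via fn_inverse in backward(), which is where the memory
        # saving of reversible connections comes from.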

        # store for backward pass
        ctx.inputs = [inputs]
        ctx.outputs = [outputs]

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), "
                               "please use .backward() if possible")
        # retrieve input and output tensor nodes
        if len(ctx.outputs) == 0:
            raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint "
                               "more than once.")
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        # reconstruct input node features
        with torch.no_grad():
            # inputs[0] is DGLGraph and inputs[1] is input node features
            inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs) + inputs[2:]))
            # clear memory of outputs
            outputs.storage().resize_(0)

            x = inputs[1]
            x.storage().resize_(int(np.prod(x.size())))
            x.set_(inputs_inverted)
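            # x is the original input tensor object: its storage has been
            # re-grown to its original size and refilled in place with the
            # reconstructed node features.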

        # compute gradients
        with torch.set_grad_enabled(True):
            detached_inputs = []
            for i, element in enumerate(inputs):
                if isinstance(element, torch.Tensor):
                    element = element.detach()
                    element.requires_grad = ctx.input_requires_grad[i]
                detached_inputs.append(element)

            detached_inputs = tuple(detached_inputs)
            temp_output = ctx.fn(*detached_inputs)
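
        # Differentiating through this freshly recorded local graph yields
        # gradients for both the surviving inputs and the module weights, in
        # the spirit of torch.utils.checkpoint.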
        # non-tensor inputs such as the DGLGraph carry no requires_grad flag
        filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
                                                detached_inputs))
        gradients = torch.autograd.grad(outputs=(temp_output,),
                                        inputs=filtered_detached_inputs + ctx.weights,
                                        grad_outputs=grad_outputs)

        input_gradients = []
        i = 0
        for rg in ctx.input_requires_grad:
            if rg:
                input_gradients.append(gradients[i])
                i += 1
            else:
                input_gradients.append(None)

        gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]
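        # The three leading None entries line up with the fn, fn_inverse and
        # num_inputs arguments of forward(), which need no gradients.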
        return (None, None, None) + gradients


class GroupRevRes(nn.Module):
    r"""Grouped reversible residual connections for GNNs, as introduced in
    `Training Graph Neural Networks with 1000 Layers <https://arxiv.org/abs/2106.07476>`__

    It uniformly partitions an input node feature :math:`X` into :math:`C` groups
    :math:`X_1, X_2, \cdots, X_C` across the channel dimension. It also makes
    :math:`C` copies of the input GNN module :math:`f_{w1}, \cdots, f_{wC}`. In the
    forward pass, each GNN module only takes the corresponding group of node features.
    The output node representations :math:`X^{'}` are computed as follows.

    .. math::

        X_0^{'} = \sum_{i=2}^{C} X_i

        X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, \quad i \in \{1, \cdots, C\}

        X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}

    where :math:`g` is the input graph, :math:`U` stands for arbitrary additional
    input arguments like edge features, and :math:`\, \Vert \,` denotes concatenation.

    Parameters
    ----------
    gnn_module : nn.Module
        GNN module for message passing. :class:`GroupRevRes` will clone the module
        :attr:`groups` - 1 times, yielding :attr:`groups` copies in total. The input
        and output node representation sizes of the module must be the same. Its
        forward function must take a DGLGraph and the associated input node features
        in order, optionally followed by additional arguments like edge features.
    groups : int, optional
        The number of groups.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> import torch.nn as nn
    >>> from dgl.nn import GraphConv, GroupRevRes
    >>> class GNNLayer(nn.Module):
    ...     def __init__(self, feats, dropout=0.2):
    ...         super(GNNLayer, self).__init__()
    ...         # Use BatchNorm and dropout to prevent gradient vanishing
    ...         # In particular if you use a large number of GNN layers
    ...         self.norm = nn.BatchNorm1d(feats)
    ...         self.conv = GraphConv(feats, feats)
    ...         self.dropout = nn.Dropout(dropout)
    ...
    ...     def forward(self, g, x):
    ...         x = self.norm(x)
    ...         x = self.dropout(x)
    ...         return self.conv(g, x)
    >>> num_nodes = 5
    >>> num_edges = 20
    >>> feats = 32
    >>> groups = 2
    >>> g = dgl.rand_graph(num_nodes, num_edges)
    >>> x = torch.randn(num_nodes, feats)
    >>> conv = GNNLayer(feats // groups)
    >>> model = GroupRevRes(conv, groups)
    >>> out = model(g, x)
    """
    def __init__(self, gnn_module, groups=2):
        super(GroupRevRes, self).__init__()
        self.gnn_modules = nn.ModuleList()
        for i in range(groups):
            if i == 0:
                self.gnn_modules.append(gnn_module)
            else:
                self.gnn_modules.append(deepcopy(gnn_module))
        self.groups = groups

    def _forward(self, g, x, *args):
        xs = torch.chunk(x, self.groups, dim=-1)

        if len(args) == 0:
            args_chunks = [()] * self.groups
        else:
            chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
            args_chunks = list(zip(*chunked_args))
        y_in = sum(xs[1:])
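        # y_in implements X_0' = \sum_{i=2}^{C} X_i from the class docstring:
        # the input to the first group is the sum of all the other groups.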

        ys = []
        for i in range(self.groups):
            y_in = xs[i] + self.gnn_modules[i](g, y_in, *args_chunks[i])
            ys.append(y_in)

        out = torch.cat(ys, dim=-1)

        return out

    def _inverse(self, g, y, *args):
        ys = torch.chunk(y, self.groups, dim=-1)

        if len(args) == 0:
            args_chunks = [()] * self.groups
        else:
            chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
            args_chunks = list(zip(*chunked_args))

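        # Undo the forward recurrence group by group, last group first:
        # X_i = X_i' - f_{wi}(X_{i-1}') for i > 0, while the first group uses
        # the sum of the already recovered groups to rebuild its input.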
        xs = []
        for i in range(self.groups-1, -1, -1):
            if i != 0:
                y_in = ys[i-1]
            else:
                y_in = sum(xs)

            x = ys[i] - self.gnn_modules[i](g, y_in, *args_chunks[i])
            xs.append(x)

        x = torch.cat(xs[::-1], dim=-1)

        return x

    def forward(self, g, x, *args):
        r"""Apply the GNN module with grouped reversible residual connection.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        x : torch.Tensor
            The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}`
            is the size of the input feature and :math:`N` is the number of nodes.
        args
            Additional arguments to pass to :attr:`gnn_module`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{in})`.
        """
        args = (g, x) + args
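        # Trainable parameters are handed to InvertibleCheckpoint explicitly:
        # the wrapped functions run under no_grad in its forward, so autograd
        # must be told up front which weights will need gradients.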
        y = InvertibleCheckpoint.apply(
            self._forward,
            self._inverse,
            len(args),
            *(args + tuple([p for p in self.parameters() if p.requires_grad])))

        return y
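
A quick way to see the reversibility in action (a sketch for illustration, not part of this commit): round-trip a feature tensor through the private _forward and _inverse helpers and confirm the input is reconstructed. The sketch calls the helpers directly because the public forward frees the input tensor's storage via InvertibleCheckpoint; allow_zero_in_degree is set only so a random graph cannot trip GraphConv's in-degree check.

import dgl
import torch
from dgl.nn import GraphConv, GroupRevRes

num_nodes, feats, groups = 5, 32, 2
g = dgl.rand_graph(num_nodes, 20)
x = torch.randn(num_nodes, feats)

# Each group sees feats // groups channels, so the module maps 16 -> 16.
conv = GraphConv(feats // groups, feats // groups, allow_zero_in_degree=True)
model = GroupRevRes(conv, groups)

with torch.no_grad():
    y = model._forward(g, x)      # model(g, x) would free x's storage instead
    x_rec = model._inverse(g, y)

assert torch.allclose(x, x_rec, atol=1e-5)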