Skip to content

Commit

Permalink
[performance] Optimize the association order of AXW in GraphSAGE. (dm…
Browse files Browse the repository at this point in the history
…lc#2747)

* upd

* lint

* upd

* upd

* compatibility

* upd

* upd

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
yzh119 and VoVAllen authored Mar 18, 2021
1 parent 366cc7e commit edf6446
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 17 deletions.
66 changes: 50 additions & 16 deletions python/dgl/nn/pytorch/conv/sageconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn import functional as F

from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
from ....utils import expand_as_pair, check_eq_shape, dgl_warning


class SAGEConv(nn.Module):
Expand Down Expand Up @@ -119,8 +119,12 @@ def __init__(self,
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
if bias:
self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()

def reset_parameters(self):
Expand All @@ -144,6 +148,19 @@ def reset_parameters(self):
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

def _compatibility_check(self):
    """Address the backward compatibility issue brought by #2747.

    Modules created before that change have no ``bias`` attribute on the
    layer itself; the bias terms live on ``fc_neigh`` (and on ``fc_self``
    when that sub-module exists).  When such a module is detected, the
    per-linear biases are removed and folded into a single ``self.bias``
    so the rest of ``forward`` can add the bias once at the end.

    The check is a no-op when ``self.bias`` already exists (including when
    it is ``None``).
    """
    if hasattr(self, 'bias'):
        return

    dgl_warning("You are loading a GraphSAGE model trained from a old version of DGL, "
                "DGL automatically convert it to be compatible with latest version.")
    bias = self.fc_neigh.bias
    self.fc_neigh.bias = None
    if hasattr(self, 'fc_self') and bias is not None:
        # The sum of two parameters is a plain non-leaf tensor.  Re-wrap it
        # as a Parameter so that it is registered on the module (visible to
        # ``parameters()``/``state_dict()`` and moved by ``.to()``), exactly
        # like the Parameter that is reused when ``fc_self`` is absent.
        bias = nn.Parameter((bias + self.fc_self.bias).detach())
        self.fc_self.bias = None
    self.bias = bias

def _lstm_reducer(self, nodes):
"""LSTM reducer
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
Expand Down Expand Up @@ -183,6 +200,7 @@ def forward(self, graph, feat, edge_weight=None):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
self._compatibility_check()
with graph.local_scope():
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
Expand All @@ -191,11 +209,11 @@ def forward(self, graph, feat, edge_weight=None):
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
aggregate_fn = fn.copy_src('h', 'm')
msg_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

h_self = feat_dst

Expand All @@ -204,34 +222,50 @@ def forward(self, graph, feat, edge_weight=None):
graph.dstdata['neigh'] = torch.zeros(
feat_dst.shape[0], self._in_src_feats).to(feat_dst)

# Determine whether to apply linear transformation before message passing A(XW)
lin_before_mp = self._in_src_feats > self._out_feats

# Message Passing
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.update_all(msg_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
if isinstance(feat, tuple): # heterogeneous
graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
else:
graph.dstdata['h'] = graph.srcdata['h']
graph.update_all(msg_fn, fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
graph.update_all(msg_fn, fn.max('m', 'neigh'))
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._lstm_reducer)
h_neigh = graph.dstdata['neigh']
graph.update_all(msg_fn, self._lstm_reducer)
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
rst = h_neigh
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
rst = self.fc_self(h_self) + h_neigh

# bias term
if self.bias is not None:
rst = rst + self.bias

# activation
if self.activation is not None:
rst = self.activation(rst)
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def test_dense_sage_conv(g, idtype, out_dim):
sage = nn.SAGEConv(5, out_dim, 'gcn')
dense_sage = nn.DenseSAGEConv(5, out_dim)
dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data
dense_sage.fc.bias.data = sage.bias.data
if len(g.ntypes) == 2:
feat = (
F.randn((g.number_of_src_nodes(), 5)),
Expand Down

0 comments on commit edf6446

Please sign in to comment.