[NN] Enhance EGATConv branch (dmlc#4062)
* enhance EGATConv: allow nfeats as tuples

* egatconv modified for bipartite graphs

* modified docstrings

* added/modified unittests for EGATConv

* Update egatconv.py

* rectified lint errors

Co-authored-by: rijulizer <[email protected]>
Co-authored-by: Mufei Li <[email protected]>
3 people authored Jun 3, 2022
1 parent 92063d8 commit efd909e
Showing 2 changed files with 98 additions and 17 deletions.
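In short, this change lets EGATConv take `in_node_feats` as a single int or a pair of ints, and `nfeats` as a single tensor or a (source, destination) pair, so the layer also works on unidirectional bipartite graphs. A minimal sketch of the new call pattern (the graph and sizes here are illustrative, not taken from the commit):

    import dgl
    import torch as th
    from dgl.nn import EGATConv

    # toy bipartite graph: 2 source ('A') nodes, 3 destination ('B') nodes, 3 edges
    g = dgl.heterograph({('A', 'r', 'B'): ([0, 1, 1], [0, 1, 2])})
    egat = EGATConv(in_node_feats=(8, 6), in_edge_feats=4,
                    out_node_feats=5, out_edge_feats=3, num_heads=2)
    h, f = egat(g, (th.rand(2, 8), th.rand(3, 6)), th.rand(3, 4))
    # h: (3, 2, 5) destination-node feats; f: (3, 2, 3) edge feats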
83 changes: 68 additions & 15 deletions python/dgl/nn/pytorch/conv/egatconv.py
@@ -7,6 +7,7 @@
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ....utils import expand_as_pair

# pylint: enable=W0235
class EGATConv(nn.Module):
@@ -27,8 +28,14 @@ class EGATConv:
Parameters
----------
    in_node_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_{i}`.
        EGATConv can be applied on a homogeneous graph or a unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is to be applied to a unidirectional bipartite graph,
        ``in_node_feats`` specifies the input feature sizes on both the source
        and destination nodes. If a scalar is given, the source and destination
        node feature sizes take the same value.
in_edge_feats : int
Input edge feature size :math:`f_{ij}`.
out_node_feats : int
@@ -46,10 +53,10 @@ class EGATConv:
>>> import torch as th
>>> from dgl.nn import EGATConv
>>> # Case 1: Homogeneous graph
>>> num_nodes, num_edges = 8, 30
>>> # generate a graph
    >>> graph = dgl.rand_graph(num_nodes, num_edges)
>>> node_feats = th.rand((num_nodes, 20))
>>> edge_feats = th.rand((num_edges, 12))
>>> egat = EGATConv(in_node_feats=20,
@@ -61,8 +68,33 @@ class EGATConv:
>>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
>>> new_node_feats.shape, new_edge_feats.shape
torch.Size([8, 3, 15]) torch.Size([30, 3, 10])
"""
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
    >>> u_feat = th.rand(2, 25)
    >>> v_feat = th.rand(4, 30)
    >>> nfeats = (u_feat, v_feat)
    >>> efeats = th.rand(5, 15)
    >>> in_node_feats = (25, 30)
>>> in_edge_feats = 15
>>> out_node_feats = 10
>>> out_edge_feats = 5
>>> num_heads = 3
>>> egat_model = EGATConv(in_node_feats,
... in_edge_feats,
... out_node_feats,
... out_edge_feats,
... num_heads,
... bias=True)
    >>> # forward pass
    >>> new_node_feats, new_edge_feats, attentions = egat_model(
    ...     g, nfeats, efeats, get_attention=True)
>>> new_node_feats.shape, new_edge_feats.shape, attentions.shape
(torch.Size([4, 3, 10]), torch.Size([5, 3, 5]), torch.Size([5, 3, 1]))
"""
def __init__(self,
in_node_feats,
in_edge_feats,
@@ -73,12 +105,25 @@ def __init__(self,

super().__init__()
self._num_heads = num_heads
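        # expand_as_pair splits in_node_feats into (src, dst) sizes; a scalar is duplicated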
self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats)
self._out_node_feats = out_node_feats
self._out_edge_feats = out_edge_feats
        # expand_as_pair already gives a scalar in_node_feats identical source and
        # destination sizes, so one set of projections covers both the homogeneous
        # and the bipartite case
        self.fc_node_src = nn.Linear(
            self._in_src_node_feats, out_node_feats * num_heads, bias=False)
        self.fc_ni = nn.Linear(
            self._in_src_node_feats, out_edge_feats * num_heads, bias=False)
        self.fc_nj = nn.Linear(
            self._in_dst_node_feats, out_edge_feats * num_heads, bias=False)

self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats*num_heads, bias=False)
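        # per-head attention vector applied to the combined edge representation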
self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_edge_feats)))
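        # optional additive bias on the per-head edge outputs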
if bias:
self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_edge_feats,)))
@@ -91,7 +136,7 @@ def reset_parameters(self):
Reinitialize learnable parameters.
"""
gain = init.calculate_gain('relu')
init.xavier_normal_(self.fc_node_src.weight, gain=gain)
init.xavier_normal_(self.fc_ni.weight, gain=gain)
init.xavier_normal_(self.fc_fij.weight, gain=gain)
init.xavier_normal_(self.fc_nj.weight, gain=gain)
@@ -106,11 +151,14 @@ def forward(self, graph, nfeats, efeats, get_attention=False):
----------
graph : DGLGraph
The graph.
        nfeats : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it is the input node feature of shape
            :math:`(N, D_{in})`, where :math:`D_{in}` is the size of the input
            node feature and :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors
            of shape :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        efeats : torch.Tensor
The input edge feature of shape :math:`(E, F_{in})`
where:
@@ -144,13 +192,18 @@ def forward(self, graph, nfeats, efeats, get_attention=False):
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue.')

# calc edge attention
        # same trick as in dgl.nn.pytorch.GATConv, but also including edge feats
# https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py
if isinstance(nfeats, tuple):
nfeats_src, nfeats_dst = nfeats
else:
nfeats_src = nfeats_dst = nfeats

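        # project src/dst node feats and edge feats to (num_heads * out_edge_feats)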
f_ni = self.fc_ni(nfeats_src)
f_nj = self.fc_nj(nfeats_dst)
f_fij = self.fc_fij(efeats)

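        # stage the projections on the graph so they can be combined per edge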
graph.srcdata.update({'f_ni': f_ni})
graph.dstdata.update({'f_nj': f_nj})
# add ni, nj factors
@@ -164,13 +217,13 @@ def forward(self, graph, nfeats, efeats, get_attention=False):
# compute attention factor
e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
graph.edata['a'] = edge_softmax(graph, e)
        graph.srcdata['h_out'] = self.fc_node_src(nfeats_src).view(
            -1, self._num_heads, self._out_node_feats)
# calc weighted sum
graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
fn.sum('m', 'h_out'))

h_out = graph.dstdata['h_out'].view(-1, self._num_heads, self._out_node_feats)
if get_attention:
return h_out, f_out, graph.edata.pop('a')
else:
            return h_out, f_out
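For reference, the scalar-vs-pair handling above leans on expand_as_pair; a simplified sketch of the behaviour relied on here (the real dgl.utils.expand_as_pair also handles feature tensors and block graphs):

    def expand_as_pair(input_, g=None):
        # pass a 2-tuple through as (src, dst); duplicate anything else
        if isinstance(input_, tuple):
            return input_
        return input_, input_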
32 changes: 30 additions & 2 deletions tests/pytorch/test_nn.py
@@ -506,13 +506,41 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
num_heads=num_heads)
nfeat = F.randn((g.number_of_nodes(), 10))
efeat = F.randn((g.number_of_edges(), 5))

egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat)

th.save(egat, tmp_buffer)

assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
assert attn.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
@pytest.mark.parametrize('out_edge_feats', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
    egat = nn.EGATConv(in_node_feats=(10, 15),
in_edge_feats=7,
out_node_feats=out_node_feats,
out_edge_feats=out_edge_feats,
num_heads=num_heads)
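    # source nodes carry 10-dim features, destination nodes 15-dim features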
nfeat = (F.randn((g.number_of_src_nodes(), 10)), F.randn((g.number_of_dst_nodes(), 15)))
efeat = F.randn((g.number_of_edges(), 7))
egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat)

th.save(egat, tmp_buffer)

assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
assert attn.shape == (g.number_of_edges(), num_heads, 1)


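To exercise just these cases locally, something like pytest tests/pytorch/test_nn.py -k egat should work, assuming the repository's usual pytest setup.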