[NN] add multihead in DotGatConv (dmlc#2549)
* add multihead in DotGatConv

* Fix spacing issue

* Add Unit test for dotgat

* Modified Unit test for dotgat

* Add transformer like divisor

* Update dotgatconv.py

Co-authored-by: Chen <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
3 people authored Jan 28, 2021
1 parent 4ca706e commit e4ddafe
Showing 2 changed files with 72 additions and 21 deletions.
66 changes: 45 additions & 21 deletions python/dgl/nn/pytorch/conv/dotgatconv.py
@@ -41,6 +41,8 @@ class DotGatConv(nn.Module):
same value.
out_feats : int
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_heads : int
Number of heads in Multi-Head Attention.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
@@ -75,44 +77,66 @@ class DotGatConv(nn.Module):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> gatconv = DotGatConv(10, 2)
>>> res = gatconv(g, feat)
>>> dotgatconv = DotGatConv(10, 2, num_heads=3)
>>> res = dotgatconv(g, feat)
>>> res
tensor([[-0.6958, -0.8752],
[-0.6958, -0.8752],
[-0.6958, -0.8752],
[-0.6958, -0.8752],
[-0.6958, -0.8752],
[-0.6958, -0.8752]], grad_fn=<CopyReduceBackward>)
tensor([[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]],
[[ 3.4570, 1.8634],
[ 1.3805, -0.0762],
[ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.bipartite((u, v))
>>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
>>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
>>> gatconv = DotGatConv((5,10), 2)
>>> res = gatconv(g, (u_feat, v_feat))
>>> dotgatconv = DotGatConv((5,10), 2, 3)
>>> res = dotgatconv(g, (u_feat, v_feat))
>>> res
tensor([[ 0.4718, 0.0864],
[ 0.7099, -0.0335],
[ 0.5869, 0.0284],
[ 0.4718, 0.0864]], grad_fn=<CopyReduceBackward>)
tensor([[[-0.6066, 1.0268],
[-0.5945, -0.4801],
[ 0.1594, 0.3825]],
[[ 0.0268, 1.0783],
[ 0.5041, -1.3025],
[ 0.6568, 0.7048]],
[[-0.2688, 1.0543],
[-0.0315, -0.9016],
[ 0.3943, 0.5347]],
[[-0.6066, 1.0268],
[-0.5945, -0.4801],
[ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
allow_zero_in_degree=False):
super(DotGatConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._num_heads = num_heads

if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats, bias=False)
self.fc_src = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats*self._num_heads, bias=False)
else:
self.fc = nn.Linear(self._in_src_feats, self._out_feats, bias=False)
self.fc = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)

def forward(self, graph, feat, get_attention=False):
r"""
@@ -168,11 +192,11 @@ def forward(self, graph, feat, get_attention=False):
if isinstance(feat, tuple):
h_src = feat[0]
h_dst = feat[1]
feat_src = self.fc_src(h_src)
feat_dst = self.fc_dst(h_dst)
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else:
h_src = feat
feat_src = feat_dst = self.fc(h_src)
feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]

@@ -184,7 +208,7 @@ def forward(self, graph, feat, get_attention=False):
graph.apply_edges(fn.u_dot_v('ft', 'ft', 'a'))

# Step 2. edge softmax to compute attention scores
graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])
graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])/(self._out_feats**0.5)

# Step 3. Broadcast softmax value to each edge, and aggregate dst node
graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
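For readers skimming the diff, here is a minimal standalone sketch of what the new multi-head forward pass computes, written against DGL's message-passing primitives (fn.u_dot_v, edge_softmax, fn.u_mul_e). The graph, feature sizes, and variable names below are illustrative assumptions, not part of the commit.

import torch as th
import dgl
import dgl.function as fn
from dgl.nn.functional import edge_softmax

# Illustrative sizes: 6 nodes, 10-dim input, 3 heads, 2-dim output per head.
num_heads, out_feats = 3, 2
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])))
feat = th.ones(6, 10)
fc = th.nn.Linear(10, out_feats * num_heads, bias=False)

# Project and split into heads: (N, num_heads, out_feats), mirroring the new .view() calls.
g.ndata['ft'] = fc(feat).view(-1, num_heads, out_feats)

# Step 1: per-head dot-product score on every edge -> shape (E, num_heads, 1).
g.apply_edges(fn.u_dot_v('ft', 'ft', 'a'))

# Step 2: softmax over each node's incoming edges, followed by the commit's
# transformer-like 1/sqrt(out_feats) scaling (applied after the softmax).
g.edata['sa'] = edge_softmax(g, g.edata['a']) / (out_feats ** 0.5)

# Step 3: per-head weighted sum of neighbour features -> (N, num_heads, out_feats).
g.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))
print(g.ndata['agg_u'].shape)  # torch.Size([6, 3, 2])

The final shape, (num_nodes, num_heads, out_feats), matches the updated docstring example above.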
27 changes: 27 additions & 0 deletions tests/pytorch/test_nn.py
@@ -779,6 +779,32 @@ def test_edge_conv_bi(g, idtype):
x0 = F.randn((g.number_of_dst_nodes(), 5))
h1 = edge_conv(g, (h0, x0))
assert h1.shape == (g.number_of_dst_nodes(), 2)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_dotgat_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
dotgat = nn.DotGatConv(5, 2, 4)
feat = F.randn((g.number_of_nodes(), 5))
dotgat = dotgat.to(ctx)
h = dotgat(g, feat)
assert h.shape == (g.number_of_nodes(), 4, 2)
_, a = dotgat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), 4, 1)

@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_dotgat_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
dotgat = nn.DotGatConv((5, 5), 2, 4)
feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
dotgat = dotgat.to(ctx)
h = dotgat(g, feat)
assert h.shape == (g.number_of_dst_nodes(), 4, 2)
_, a = dotgat(g, feat, get_attention=True)
assert a.shape == (g.number_of_edges(), 4, 1)

def test_dense_cheb_conv():
for k in range(1, 4):
@@ -1016,6 +1042,7 @@ def forward(self, g, h, arg1=None, *, arg2=None):
test_gated_graph_conv()
test_nn_conv()
test_gmm_conv()
test_dotgat_conv()
test_dense_graph_conv()
test_dense_sage_conv()
test_dense_cheb_conv()
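As a quick usage check mirroring the assertions in the new tests, the following sketch shows the expected output and attention shapes; the toy graph and dimensions are invented for illustration, assuming DotGatConv is importable from dgl.nn.pytorch as in this repository.

import torch as th
import dgl
from dgl.nn.pytorch import DotGatConv

# Tiny stand-in for the parametrized graph fixtures used in the tests above.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
conv = DotGatConv(5, 2, num_heads=4)
feat = th.randn(g.number_of_nodes(), 5)

h, attn = conv(g, feat, get_attention=True)
assert h.shape == (g.number_of_nodes(), 4, 2)     # (N, num_heads, out_feats)
assert attn.shape == (g.number_of_edges(), 4, 1)  # one score per edge and head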
