[Feature] Edge softmax on a subset of edges in the graph. (dmlc#842)
* upd

* add test

* fix

* upd

* merge

* hotfix

* upd

* fix
yzh119 authored Sep 9, 2019
1 parent bcd33e0 commit 6a4b5ae
Showing 5 changed files with 105 additions and 16 deletions.
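Before the per-file diffs, a minimal sketch of the new call pattern for the PyTorch frontend, reconstructed from the docstring examples added in this commit (the 3-node example graph below is an assumption made to match those docstring outputs; it is illustrative only):

import torch as th
import dgl
from dgl.nn.pytorch.softmax import edge_softmax

# 3-node graph with 6 edges and all-ones scores, as in the docstring example.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
edata = th.ones(6, 1).float()

# Existing behaviour: softmax over all edges, normalized per destination node.
edge_softmax(g, edata)                               # [[1.], [.5], [.3333], [.5], [.3333], [.3333]]

# New in this commit: softmax restricted to the edges listed in eids;
# the score tensor carries one row per selected edge.
edge_softmax(g, edata[:4], th.Tensor([0, 1, 2, 3]))  # [[1.], [.5], [1.], [.5]]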
1 change: 0 additions & 1 deletion examples/mxnet/gat/utils.py
@@ -1,5 +1,4 @@
import numpy as np
import torch

class EarlyStopping:
def __init__(self, patience=10):
29 changes: 22 additions & 7 deletions python/dgl/nn/mxnet/softmax.py
@@ -3,6 +3,7 @@
import mxnet as mx

from ... import function as fn
from ...base import ALL, is_all

__all__ = ['edge_softmax']

@@ -24,8 +25,10 @@ class EdgeSoftmax(mx.autograd.Function):
the attention weights are computed with such an edge softmax operation.
"""

def __init__(self, g):
def __init__(self, g, eids):
super(EdgeSoftmax, self).__init__()
if not is_all(eids):
g = g.edge_subgraph(eids.astype('int64'))
self.g = g

def forward(self, score):
@@ -78,7 +81,7 @@ def backward(self, grad_out):
grad_score = g.edata['grad_score'] - g.edata['out']
return grad_score

def edge_softmax(graph, logits):
def edge_softmax(graph, logits, eids=ALL):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
@@ -98,8 +101,11 @@ def edge_softmax(graph, logits):
----------
graph : DGLGraph
The graph to perform edge softmax
logits : torch.Tensor
logits : mxnet.NDArray
The input edge feature
eids : mxnet.NDArray or ALL, optional
Edges on which to apply edge softmax. If ALL, apply edge softmax
on all edges in the graph. Default: ALL.
Returns
-------
@@ -108,9 +114,10 @@ def edge_softmax(graph, logits):
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of
additional dimensions, :math:`N` is the number of edges.
* Return shape: :math:`(N, *, 1)`
* Input shape: :math:`(E, *, 1)` where * means any number of
additional dimensions, :math:`E` equals the length of eids.
If eids is ALL, :math:`E` equals number of edges in the graph.
* Return shape: :math:`(E, *, 1)`
Examples
--------
@@ -143,6 +150,14 @@ def edge_softmax(graph, logits):
[0.33333334]
[0.33333334]]
<NDArray 6x1 @cpu(0)>
Apply edge softmax on the first 4 edges of g:
>>> edge_softmax(g, edata[:4], nd.array([0,1,2,3], dtype='int64'))
[[1. ]
[0.5]
[1. ]
[0.5]]
<NDArray 4x1 @cpu(0)>
"""
softmax_op = EdgeSoftmax(graph)
softmax_op = EdgeSoftmax(graph, eids)
return softmax_op(logits)
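For the MXNet frontend the call mirrors the docstring example above; a short sketch under the same assumed 3-node example graph. Note that eids should be an int64 NDArray, since the implementation routes it through g.edge_subgraph:

import mxnet as mx
from mxnet import nd
import dgl
from dgl.nn.mxnet.softmax import edge_softmax

# Same assumed 3-node, 6-edge graph as in the docstring example.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
edata = nd.ones((6, 1))

# Softmax over all edges.
edge_softmax(g, edata)

# Softmax over the first four edges only; slice the scores to match eids.
edge_softmax(g, edata[:4], nd.array([0, 1, 2, 3], dtype='int64'))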
28 changes: 21 additions & 7 deletions python/dgl/nn/pytorch/softmax.py
@@ -3,6 +3,7 @@
import torch as th

from ... import function as fn
from ...base import ALL, is_all

__all__ = ['edge_softmax']

@@ -25,7 +26,7 @@ class EdgeSoftmax(th.autograd.Function):
"""

@staticmethod
def forward(ctx, g, score):
def forward(ctx, g, score, eids):
"""Forward function.
Pseudo-code:
@@ -41,6 +42,8 @@ def forward(ctx, g, score):
"""
# remember to save the graph to backward cache before making it
# a local variable
if not is_all(eids):
g = g.edge_subgraph(eids.long())
ctx.backward_cache = g
g = g.local_var()
g.edata['s'] = score
@@ -79,10 +82,10 @@ def backward(ctx, grad_out):
g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
grad_score = g.edata['grad_s'] - g.edata['out']
return None, grad_score
return None, grad_score, None


def edge_softmax(graph, logits):
def edge_softmax(graph, logits, eids=ALL):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
@@ -104,6 +107,9 @@ def edge_softmax(graph, logits):
The graph to perform edge softmax
logits : torch.Tensor
The input edge feature
eids : torch.Tensor or ALL, optional
Edges on which to apply edge softmax. If ALL, apply edge
softmax on all edges in the graph. Default: ALL.
Returns
-------
@@ -112,9 +118,10 @@ def edge_softmax(graph, logits):
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of
additional dimensions, :math:`N` is the number of edges.
* Return shape: :math:`(N, *, 1)`
* Input shape: :math:`(E, *, 1)` where * means any number of
additional dimensions, :math:`E` equals the length of eids.
If eids is ALL, :math:`E` equals number of edges in the graph.
* Return shape: :math:`(E, *, 1)`
Examples
--------
@@ -145,5 +152,12 @@ def edge_softmax(graph, logits):
[0.5000],
[0.3333],
[0.3333]])
Apply edge softmax on the first 4 edges of g:
>>> edge_softmax(g, edata[:4], th.Tensor([0,1,2,3]))
tensor([[1.0000],
[0.5000],
[1.0000],
[0.5000]])
"""
return EdgeSoftmax.apply(graph, logits)
return EdgeSoftmax.apply(graph, logits, eids)
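The tests added below check that restricting the softmax via eids is equivalent to running the full edge softmax on the corresponding edge subgraph; a condensed, standalone version of that check (same complete-graph setup as the tests, shown here only as a sketch):

import numpy as np
import torch as th
import dgl
from dgl.nn.pytorch.softmax import edge_softmax

# Complete graph on 30 nodes -> 900 edges, as in the new tests.
g = dgl.DGLGraph()
g.add_nodes(30)
for i in range(30):
    for j in range(30):
        g.add_edge(i, j)

eids = th.from_numpy(np.random.choice(900, 300, replace=False).astype('int64'))
score = th.randn(300, 1)

# Partial softmax on the full graph vs. full softmax on the edge subgraph:
# the implementation maps eids through g.edge_subgraph, so these should agree.
y_partial = edge_softmax(g, score, eids)
y_subgraph = edge_softmax(g.edge_subgraph(eids), score)
assert th.allclose(y_partial, y_subgraph)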
31 changes: 31 additions & 0 deletions tests/mxnet/test_nn.py
@@ -223,6 +223,36 @@ def test_edge_softmax():
assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
1e-4, 1e-4)

def test_partial_edge_softmax():
g = dgl.DGLGraph()
g.add_nodes(30)
# build a complete graph
for i in range(30):
for j in range(30):
g.add_edge(i, j)

score = F.randn((300, 1))
score.attach_grad()
grad = F.randn((300, 1))
import numpy as np
eids = np.random.choice(900, 300, replace=False).astype('int64')
eids = F.zerocopy_from_numpy(eids)
# compute partial edge softmax
with mx.autograd.record():
y_1 = nn.edge_softmax(g, score, eids)
y_1.backward(grad)
grad_1 = score.grad.copy()  # copy: the second backward below overwrites score.grad in place

# compute edge softmax on edge subgraph
subg = g.edge_subgraph(eids)
with mx.autograd.record():
y_2 = nn.edge_softmax(subg, score)
y_2.backward(grad)
grad_2 = score.grad

assert F.allclose(y_1, y_2)
assert F.allclose(grad_1, grad_2)

def test_rgcn():
ctx = F.ctx()
etype = []
@@ -277,6 +307,7 @@ def test_rgcn():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_partial_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
32 changes: 31 additions & 1 deletion tests/pytorch/test_nn.py
@@ -319,7 +319,36 @@ def generate_rand_graph(n):
assert len(g.ndata) == 0
assert len(g.edata) == 2
assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend


def test_partial_edge_softmax():
g = dgl.DGLGraph()
g.add_nodes(30)
# build a complete graph
for i in range(30):
for j in range(30):
g.add_edge(i, j)

score = F.randn((300, 1))
score.requires_grad_()
grad = F.randn((300, 1))
import numpy as np
eids = np.random.choice(900, 300, replace=False).astype('int64')
eids = F.zerocopy_from_numpy(eids)
# compute partial edge softmax
y_1 = nn.edge_softmax(g, score, eids)
y_1.backward(grad)
grad_1 = score.grad.clone()  # clone: zero_() below would otherwise wipe the saved gradient
score.grad.zero_()
# compute edge softmax on edge subgraph
subg = g.edge_subgraph(eids)
y_2 = nn.edge_softmax(subg, score)
y_2.backward(grad)
grad_2 = score.grad.clone()  # clone before zeroing so the gradient comparison is meaningful
score.grad.zero_()

assert F.allclose(y_1, y_2)
assert F.allclose(grad_1, grad_2)

def test_rgcn():
ctx = F.ctx()
etype = []
@@ -570,6 +599,7 @@ def test_dense_cheb_conv():
if __name__ == '__main__':
test_graph_conv()
test_edge_softmax()
test_partial_edge_softmax()
test_set2set()
test_glob_att_pool()
test_simple_pool()
