Commit: upd (dmlc#741)
yzh119 authored Aug 6, 2019
1 parent 5d3f470 commit 742d79a
Showing 4 changed files with 124 additions and 40 deletions.
6 changes: 2 additions & 4 deletions docs/source/api/python/nn.mxnet.rst
@@ -17,12 +17,10 @@ dgl.nn.mxnet.glob

.. automodule:: dgl.nn.mxnet.glob
:members:
:show-inheritance:

dgl.nn.mxnet.softmax
--------------------

.. automodule:: dgl.nn.mxnet.softmax

.. autoclass:: dgl.nn.mxnet.softmax.EdgeSoftmax
:members: forward
:show-inheritance:
:members: edge_softmax
6 changes: 1 addition & 5 deletions docs/source/api/python/nn.pytorch.rst
@@ -14,7 +14,6 @@ dgl.nn.pytorch.conv

dgl.nn.pytorch.glob
-------------------

.. automodule:: dgl.nn.pytorch.glob

.. autoclass:: dgl.nn.pytorch.glob.SumPooling
@@ -53,7 +52,4 @@ dgl.nn.pytorch.softmax
----------------------

.. automodule:: dgl.nn.pytorch.softmax

.. autoclass:: dgl.nn.pytorch.softmax.EdgeSoftmax
:members: forward
:show-inheritance:
:members: edge_softmax
76 changes: 61 additions & 15 deletions python/dgl/nn/mxnet/softmax.py
@@ -32,12 +32,15 @@ def forward(self, score):
"""Forward function.
Pseudo-code:
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score = score.exp() # edge-wise exponentiation
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score = score.exp() # edge-wise exponentiation
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
g = self.g.local_var()
g.edata['s'] = score
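Concretely, the pseudo-code above is a softmax over the incoming edges of each destination node, with the usual max-shift for numerical stability. A minimal NumPy reference sketch of the same computation (``dst`` and ``edge_softmax_ref`` are illustrative names, not DGL API):

import numpy as np

def edge_softmax_ref(dst, score):
    # dst: (E,) destination node id of each edge; score: (E, 1) edge logits
    out = np.zeros_like(score, dtype=float)
    for v in np.unique(dst):
        mask = dst == v                            # edges pointing to node v
        shifted = score[mask] - score[mask].max()  # subtract the per-dst max
        e = np.exp(shifted)
        out[mask] = e / e.sum()                    # divide by the per-dst sum
    return out

# destination ids of the toy graph used in the docstring example further down
print(edge_softmax_ref(np.array([0, 1, 2, 1, 2, 2]), np.ones((6, 1))))
# rows: 1., 0.5, 0.33333333, 0.5, 0.33333333, 0.33333333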
@@ -54,12 +57,15 @@ def backward(self, grad_out):
"""Backward function.
Pseudo-code:
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
"""
g = self.g.local_var()
out, = self.saved_tensors # pylint: disable=access-member-before-definition, unpacking-non-sequence
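For reference, this is the standard softmax backward applied independently per destination node. Writing :math:`a_{ij}` for the forward output (``out``) and :math:`g_{ij}` for ``grad_out`` on edge :math:`j\rightarrow i`,

.. math::
  \frac{\partial a_{ij}}{\partial z_{ik}} = a_{ij}\,(\delta_{jk} - a_{ik}),
  \qquad
  \frac{\partial L}{\partial z_{ik}}
    = \sum_{j\in\mathcal{N}(i)} g_{ij}\, a_{ij}\,(\delta_{jk} - a_{ik})
    = a_{ik}\, g_{ik} - a_{ik} \sum_{j\in\mathcal{N}(i)} a_{ij}\, g_{ij},

which is ``sds - out * sds_sum`` with ``sds = out * grad_out`` and ``sds_sum`` its per-destination sum.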
@@ -75,6 +81,19 @@ def backward(self, grad_out):
def edge_softmax(graph, logits):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing

.. math::
  a_{ij} = \frac{\exp(z_{ij})}{\sum_{k\in\mathcal{N}(i)}\exp(z_{ik})}

where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edge softmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edge softmax operation.
Parameters
----------
graph : DGLGraph
@@ -90,13 +109,40 @@ def edge_softmax(graph, logits):
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of
  additional dimensions, :math:`N` is the number of edges.
* Return shape: :math:`(N, *, 1)`
Examples
--------
>>> import dgl.function as fn
>>> attention = EdgeSoftmax(logits, graph)
>>> from dgl.nn.mxnet.softmax import edge_softmax
>>> import dgl
>>> from mxnet import nd
Create a :code:`DGLGraph` object and initialize its edge features.
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
>>> edata = nd.ones((6, 1))
>>> edata
[[1.]
[1.]
[1.]
[1.]
[1.]
[1.]]
<NDArray 6x1 @cpu(0)>
Apply edge softmax on g:
>>> edge_softmax(g, edata)
[[1. ]
[0.5 ]
[0.33333334]
[0.5 ]
[0.33333334]
[0.33333334]]
<NDArray 6x1 @cpu(0)>
"""
softmax_op = EdgeSoftmax(graph)
return softmax_op(logits)
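As the shape note above says, the trailing dimensions are carried through unchanged, so multi-head attention logits of shape :math:`(N, H, 1)` (``N`` edges, ``H`` heads) are normalised in one call, one softmax per head per destination node. A small sketch on the same toy graph (the 4-head shape is illustrative):

import dgl
from mxnet import nd
from dgl.nn.mxnet.softmax import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])

logits = nd.ones((6, 4, 1))        # 6 edges, 4 attention heads
alpha = edge_softmax(g, logits)    # shape (6, 4, 1); for every node and head,
                                   # the weights of its incoming edges sum to 1
print(alpha[:, 0, 0])              # head 0: [1. 0.5 0.333... 0.5 0.333... 0.333...]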
76 changes: 60 additions & 16 deletions python/dgl/nn/pytorch/softmax.py
@@ -29,12 +29,15 @@ def forward(ctx, g, score):
"""Forward function.
Pseudo-code:
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score = score.exp() # edge-wise exponentiation
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score = score.exp() # edge-wise exponentiation
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
# remember to save the graph to backward cache before making it
# a local variable
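The comment above is about the ``torch.autograd.Function`` idiom this class relies on: tensors go through ``ctx.save_for_backward``, while non-tensor state such as the graph is stashed as a plain attribute on ``ctx`` before ``local_var()`` rebinds the local name. A minimal standalone sketch of that pattern (``TimesTwo`` is made up for illustration):

import torch as th

class TimesTwo(th.autograd.Function):
    @staticmethod
    def forward(ctx, meta, x):
        ctx.backward_cache = meta    # non-tensor state: plain ctx attribute
        ctx.save_for_backward(x)     # tensors: save_for_backward
        return 2 * x

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        # one gradient per forward input; non-tensor inputs get None
        return None, 2 * grad_out

y = TimesTwo.apply({'note': 'any non-tensor object'}, th.ones(3, requires_grad=True))
y.sum().backward()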
@@ -55,13 +58,16 @@ def backward(ctx, grad_out):
"""Backward function.
Pseudo-code:
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
"""
g = ctx.backward_cache
g = g.local_var()
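A quick way to sanity-check this backward is PyTorch's numerical gradient checker; a sketch on the toy graph from the docstring example further down (``gradcheck`` wants double precision):

import dgl
import torch as th
from dgl.nn.pytorch.softmax import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])

logits = th.randn(6, 1, dtype=th.float64, requires_grad=True)
# compares the analytic backward above against finite differences
print(th.autograd.gradcheck(lambda x: edge_softmax(g, x), (logits,)))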
@@ -79,6 +85,19 @@ def backward(ctx, grad_out):
def edge_softmax(graph, logits):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing

.. math::
  a_{ij} = \frac{\exp(z_{ij})}{\sum_{k\in\mathcal{N}(i)}\exp(z_{ik})}

where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edge softmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edge softmax operation.
Parameters
----------
graph : DGLGraph
@@ -94,12 +113,37 @@ def edge_softmax(graph, logits):
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of
  additional dimensions, :math:`N` is the number of edges.
* Return shape: :math:`(N, *, 1)`
Examples
--------
>>> import dgl.function as fn
>>> attention = EdgeSoftmax(logits, graph)
>>> from dgl.nn.pytorch.softmax import edge_softmax
>>> import dgl
>>> import torch as th
Create a :code:`DGLGraph` object and initialize its edge features.
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
>>> edata = th.ones(6, 1).float()
>>> edata
tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
Apply edge softmax on g:
>>> edge_softmax(g, edata)
tensor([[1.0000],
[0.5000],
[0.3333],
[0.5000],
[0.3333],
[0.3333]])
"""
return EdgeSoftmax.apply(graph, logits)
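The Graph Attention Network paper cited in the docstring is the typical consumer of this operator: raw attention logits are computed per edge and then normalised with ``edge_softmax``. A hedged single-head sketch (feature sizes and tensor names are illustrative; this is not DGL's ``GATConv``):

import dgl
import torch as th
import torch.nn.functional as F
from dgl.nn.pytorch.softmax import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])

h = th.randn(3, 4)                       # projected node features
attn_l, attn_r = th.randn(4, 1), th.randn(4, 1)
g.ndata['h'] = h
g.ndata['el'] = h @ attn_l               # source-side part of the logit
g.ndata['er'] = h @ attn_r               # destination-side part of the logit

# raw logit z_ij = LeakyReLU(el_j + er_i) on every edge j -> i
g.apply_edges(lambda edges: {'z': F.leaky_relu(edges.src['el'] + edges.dst['er'])})
g.edata['a'] = edge_softmax(g, g.edata['z'])   # attention weights, shape (6, 1)

# aggregate neighbour features weighted by the attention
g.update_all(lambda edges: {'m': edges.src['h'] * edges.data['a']},
             lambda nodes: {'h_new': nodes.mailbox['m'].sum(dim=1)})
print(g.ndata['h_new'].shape)            # torch.Size([3, 4])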
