Skip to content

Commit

Permalink
[Doc] new nn api doc (dmlc#2019)
Browse files Browse the repository at this point in the history
* Add dotproduct attention

* [Feature] Add dotproduct attention

* [Feature] Add dotproduct attention

* [Feature] Add dotproduct attention

* [New] Update landing page

* [New] Update landing page

* [New] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Doc] Update landing page

* [Improvement] use dgl build-in in dotgatconv

* [Doc] review API doc string bottom up

* [Doc] Add doc of input and output features

* [Doc] Update doc string for pooling and transformer.

* [Doc] Reformat doc string and change some wordings.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

* [Doc] Doc string refactoring.

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
zhjwy9343 and VoVAllen authored Aug 14, 2020
1 parent f4f7880 commit 2fa7f71
Show file tree
Hide file tree
Showing 5 changed files with 415 additions and 136 deletions.
8 changes: 2 additions & 6 deletions python/dgl/nn/pytorch/conv/dotgatconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,8 @@ def forward(self, graph, feat):
# Step 2. edge softmax to compute attention scores
graph.edata['sa'] = edge_softmax(graph, graph.edata['a'])

# Step 3. Broadcast softmax value to each edge, and then attention is done
graph.apply_edges(lambda edges: {'attn': edges.src['ft'] * \
edges.data['sa'].unsqueeze(dim=0).T})

# Step 4. Aggregate attention to dst,user nodes, so formula 7 is done
graph.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'))
# Step 3. Broadcast softmax value to each edge, and aggregate dst node
graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))

# output results to the destination nodes
rst = graph.dstdata['agg_u']
Expand Down
100 changes: 85 additions & 15 deletions python/dgl/nn/pytorch/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,23 @@ def pairwise_squared_distance(x):


class KNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
r"""
Description
-----------
Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs.
If a batch of point set is provided, then the point :math:`j` in point
set :math:`i` is mapped to graph node ID :math:`i \times M + j`, where
The KNNGraph is implemented in the following steps:
1. Compute an NxN matrix of pairwise distance for all points.
2. Pick the k points with the smallest distance for each point as their k-nearest neighbors.
3. Construct a graph with edges to each point as a node from its k-nearest neighbors.
The overall computational complexity is :math:`O(N^2(logN + D)`.
If a batch of point sets is provided, the point :math:`j` in point
set :math:`i` is mapped to graph node ID: :math:`i \times M + j`, where
:math:`M` is the number of nodes in each point set.
The predecessors of each node are the k-nearest neighbors of the
Expand All @@ -25,15 +37,40 @@ class KNNGraph(nn.Module):
Parameters
----------
k : int
The number of neighbors
The number of neighbors.
Notes
-----
The nearest neighbors found for a node include the node itself.
Examples
--------
The following example uses PyTorch backend.
>>> import torch
>>> from dgl.nn.pytorch.factory import KNNGraph
>>>
>>> kg = KNNGraph(2)
>>> x = torch.tensor([[0,1],
[1,2],
[1,3],
[100, 101],
[101, 102],
[50, 50]])
>>> g = kg(x)
>>> print(g.edges())
(tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
"""
def __init__(self, k):
super(KNNGraph, self).__init__()
self.k = k

#pylint: disable=invalid-name
def forward(self, x):
"""Forward computation.
"""
Forward computation.
Parameters
----------
Expand All @@ -45,48 +82,81 @@ def forward(self, x):
Returns
-------
DGLGraph
A DGLGraph with no features.
A DGLGraph without features.
"""
return knn_graph(x, self.k)


class SegmentedKNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
r"""
Description
-----------
Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs.
If a batch of point set is provided, then the point :math:`j` in point
set :math:`i` is mapped to graph node ID
If a batch of point sets is provided, then the point :math:`j` in the point
set :math:`i` is mapped to graph node ID:
:math:`\sum_{p<i} |V_p| + j`, where :math:`|V_p|` means the number of
points in point set :math:`p`.
points in the point set :math:`p`.
The predecessors of each node are the k-nearest neighbors of the
corresponding point.
Parameters
----------
k : int
The number of neighbors
The number of neighbors.
Notes
-----
The nearest neighbors found for a node include the node itself.
Examples
--------
The following example uses PyTorch backend.
>>> import torch
>>> from dgl.nn.pytorch.factory import SegmentedKNNGraph
>>>
>>> kg = SegmentedKNNGraph(2)
>>> x = torch.tensor([[0,1],
... [1,2],
... [1,3],
... [100, 101],
... [101, 102],
... [50, 50],
... [24,25],
... [25,24]])
>>> g = kg(x, [3,3,2])
>>> print(g.edges())
(tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]),
tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]))
>>>
"""
def __init__(self, k):
super(SegmentedKNNGraph, self).__init__()
self.k = k

#pylint: disable=invalid-name
def forward(self, x, segs):
"""Forward computation.
r"""Forward computation.
Parameters
----------
x : Tensor
:math:`(M, D)` where :math:`M` means the total number of points
in all point sets.
in all point sets, and :math:`D` means the size of features.
segs : iterable of int
:math:`(N)` integers where :math:`N` means the number of point
sets. The elements must sum up to :math:`M`.
sets. The number of elements must sum up to :math:`M`. And any
:math:`N` should :math:`\ge k`
Returns
-------
DGLGraph
A DGLGraph with no features.
A DGLGraph without features.
"""

return segmented_knn_graph(x, self.k, segs)
Loading

0 comments on commit 2fa7f71

Please sign in to comment.