Skip to content

Commit

Permalink
[Doc] Re-organize the code for dgl.geometry, and expose it in the doc (
Browse files Browse the repository at this point in the history
…dmlc#2982)

* reorg and expose dgl.geometry

* fix lint

* fix test

* fix
  • Loading branch information
hetong007 authored Jun 7, 2021
1 parent e20d895 commit 972a9f1
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 260 deletions.
26 changes: 26 additions & 0 deletions docs/source/api/python/dgl.geometry.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. _api-geometry:

dgl.geometry
=================================

.. automodule:: dgl.geometry

.. _api-geometry-farthest-point-sampler:

Farthest Point Sampler
-----------

Farthest point sampling is a greedy algorithm that samples from a point cloud
data iteratively. It starts from a random single sample of point. In each iteration,
it samples from the rest points that is the farthest from the set of sampled points.

.. autoclass:: farthest_point_sampler

.. _api-geometry-neighbor-matching:

Neighbor Matching
-----------------------------

Neighbor matching is an important module in the Graclus clustering algorithm.

.. autoclass:: neighbor_matching
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
api/python/dgl.DGLGraph
api/python/dgl.distributed
api/python/dgl.function
api/python/dgl.geometry
api/python/nn
api/python/nn.functional
api/python/dgl.ops
Expand Down
10 changes: 5 additions & 5 deletions examples/pytorch/pointcloud/pointnet/pointnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import dgl
import dgl.function as fn
from dgl.geometry.pytorch import FarthestPointSampler
from dgl.geometry.pytorch import farthest_point_sampler

'''
Part of the code are adapted from
Expand Down Expand Up @@ -167,7 +167,7 @@ def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
super(SAModule, self).__init__()
self.group_all = group_all
if not group_all:
self.fps = FarthestPointSampler(npoints)
self.npoints = npoints
self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
self.message = RelativePositionMessage(n_neighbor)
self.conv = PointNetConv(mlp_sizes, batch_size)
Expand All @@ -177,7 +177,7 @@ def forward(self, pos, feat):
if self.group_all:
return self.conv.group_all(pos, feat)

centroids = self.fps(pos)
centroids = farthest_point_sampler(pos, self.npoints)
g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv)

Expand All @@ -197,7 +197,7 @@ def __init__(self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_
self.batch_size = batch_size
self.group_size = len(radius_list)

self.fps = FarthestPointSampler(npoints)
self.npoints = npoints
self.frnn_graph_list = nn.ModuleList()
self.message_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
Expand All @@ -208,7 +208,7 @@ def __init__(self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_
self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))

def forward(self, pos, feat):
centroids = self.fps(pos)
centroids = farthest_point_sampler(pos, self.npoints)
feat_res_list = []

for i in range(self.group_size):
Expand Down
19 changes: 9 additions & 10 deletions python/dgl/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Package for geometry common components."""
import importlib
import sys
from ..backend import backend_name
"""The ``dgl.geometry`` package contains geometry operations:
* Farthest point sampling for point cloud sampling
def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
* Neighbor matching module for graclus pooling
_load_backend(backend_name)
.. note::
This package is experimental and the interfaces may be subject
to changes in future releases.
"""
from .fps import *
from .edge_coarsening import *
2 changes: 1 addition & 1 deletion python/dgl/geometry/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .. import ndarray as nd


def farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result):
def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result):
r"""Farthest Point Sampler
Parameters
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
"""Edge coarsening procedure used in Metis and Graclus, for pytorch"""
# pylint: disable=no-member, invalid-name, W0613
import dgl
import torch as th
from ..capi import _neighbor_matching
from .. import remove_self_loop
from .capi import _neighbor_matching

__all__ = ['neighbor_matching']


class NeighborMatchingFn(th.autograd.Function):
r"""
Description
-----------
AutoGrad function for neighbor matching
"""
@staticmethod
def forward(ctx, gidx, num_nodes, e_weights, relabel_idx):
r"""
Description
-----------
Perform forward computation
"""
return _neighbor_matching(gidx, num_nodes, e_weights, relabel_idx)

@staticmethod
def backward(ctx):
r"""
Description
-----------
Perform backward computation
"""
pass # pylint: disable=unnecessary-pass


def neighbor_matching(graph, e_weights=None, relabel_idx=True):
r"""
Description
Expand Down Expand Up @@ -63,14 +36,25 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True):
relabel_idx : bool, optional
If true, relabel resulting node labels to have consecutive node ids.
default: :obj:`True`
Examples
--------
The following example uses PyTorch backend.
>>> import torch, dgl
>>> from dgl.geometry import neighbor_matching
>>>
>>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
>>> res = neighbor_matching(g)
tensor([0, 1, 1])
"""
assert graph.is_homogeneous, \
"The graph used in graph node matching must be homogeneous"
if e_weights is not None:
graph.edata['e_weights'] = e_weights
graph = dgl.remove_self_loop(graph)
graph = remove_self_loop(graph)
e_weights = graph.edata['e_weights']
graph.edata.pop('e_weights')
else:
graph = dgl.remove_self_loop(graph)
return NeighborMatchingFn.apply(graph._graph, graph.num_nodes(), e_weights, relabel_idx)
graph = remove_self_loop(graph)
return _neighbor_matching(graph._graph, graph.num_nodes(), e_weights, relabel_idx)
60 changes: 60 additions & 0 deletions python/dgl/geometry/fps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name

from .. import backend as F

from ..base import DGLError
from .capi import _farthest_point_sampler

__all__ = ['farthest_point_sampler']

def farthest_point_sampler(pos, npoints, start_idx=None):
"""Farthest Point Sampler without the need to compute all pairs of distance.
In each batch, the algorithm starts with the sample index specified by ``start_idx``.
Then for each point, we maintain the minimum to-sample distance.
Finally, we pick the point with the maximum such distance.
This process will be repeated for ``sample_points`` - 1 times.
Parameters
----------
pos : tensor
The positional tensor of shape (B, N, C)
npoints : int
The number of points to sample in each batch.
start_idx : int, optional
If given, appoint the index of the starting point,
otherwise randomly select a point as the start point.
(default: None)
Returns
-------
tensor of shape (B, npoints)
The sampled indices in each batch.
Examples
--------
The following exmaple uses PyTorch backend.
>>> import torch
>>> from dgl.geometry import farthest_point_sampler
>>> x = torch.rand((2, 10, 3))
>>> point_idx = farthest_point_sampler(x, 2)
>>> print(point_idx)
tensor([[5, 6],
[7, 8]])
"""
ctx = F.context(pos)
B, N, C = pos.shape
pos = pos.reshape(-1, C)
dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
if start_idx is None:
start_idx = F.randint(shape=(B, ), dtype=F.int64, ctx=ctx, low=0, high=N-1)
else:
if start_idx >= N or start_idx < 0:
raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
N, start_idx))
start_idx = F.full_1d((B, ), start_idx, dtype=F.int64, ctx=ctx)
result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
_farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
return result.reshape(B, npoints)
3 changes: 0 additions & 3 deletions python/dgl/geometry/mxnet/__init__.py

This file was deleted.

84 changes: 0 additions & 84 deletions python/dgl/geometry/mxnet/edge_coarsening.py

This file was deleted.

58 changes: 0 additions & 58 deletions python/dgl/geometry/mxnet/fps.py

This file was deleted.

3 changes: 0 additions & 3 deletions python/dgl/geometry/pytorch/__init__.py

This file was deleted.

Loading

0 comments on commit 972a9f1

Please sign in to comment.