Skip to content

Commit

Permalink
[Op] Farthest Point Sampler in Cpp and CUDA (dmlc#1630)
Browse files Browse the repository at this point in the history
* working framework without actual algorithm logic

* rename

* fix

* fps passes compilation

* correct algorithm

* add cuda implementation

* update random start

* before refactor

* pass compilation but cuda not working

* working

* code working, will add docstring

* add mxnet support

* update docstring

* update doc and test

* cpplint

* cpcplint

* pylint

* temporary fix

* fix for win64

* fix unitetest

* fix

* fix

* remove comment

* move to geometry package

* remove redundant include

* add docstrings and comments

* add proof

* add validity check
  • Loading branch information
hetong007 authored Jun 21, 2020
1 parent c8b18b7 commit 3d47693
Show file tree
Hide file tree
Showing 14 changed files with 507 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ file(GLOB DGL_SRC
src/kernel/*.cc
src/kernel/cpu/*.cc
src/runtime/*.cc
src/geometry/*.cc
src/geometry/cpu/*.cc
)

file(GLOB_RECURSE DGL_SRC_1
Expand Down
1 change: 1 addition & 0 deletions cmake/modules/CUDA.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ macro(dgl_config_cuda out_variable)
src/kernel/cuda/*.cc
src/kernel/cuda/*.cu
src/runtime/cuda/*.cc
src/geometry/cuda/*.cu
)

dgl_select_nvcc_arch_flags(NVCC_FLAGS_ARCH)
Expand Down
12 changes: 12 additions & 0 deletions python/dgl/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Package for geometry common components."""
import importlib
import sys
from ..backend import backend_name

def _load_backend(mod_name):
mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)

_load_backend(backend_name)
37 changes: 37 additions & 0 deletions python/dgl/geometry/capi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Python interfaces to DGL farthest point sampler."""
from .._ffi.function import _init_api
from .. import backend as F

def farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result):
"""Farthest Point Sampler
Parameters
----------
data : tensor
A tensor of shape (N, d) where N is the number of points and d is the dimension.
batch_size : int
The number of batches in the ``data``. N should be divisible by batch_size.
sample_points : int
The number of points to sample in each batch.
dist : tensor
Pre-allocated tensor of shape (N, ) for to-sample distance.
start_idx : tensor of int
Pre-allocated tensor of shape (batch_size, ) for the starting sample in each batch.
result : tensor of int
Pre-allocated tensor of shape (sample_points * batch_size, ) for the sampled index.
Returns
-------
No return value. The input variable ``result`` will be overwriten with sampled indices.
"""
assert F.shape(data)[0] >= sample_points * batch_size
assert F.shape(data)[0] % batch_size == 0

_CAPI_FarthestPointSampler(F.zerocopy_to_dgl_ndarray(data),
batch_size, sample_points,
F.zerocopy_to_dgl_ndarray(dist),
F.zerocopy_to_dgl_ndarray(start_idx),
F.zerocopy_to_dgl_ndarray(result))

_init_api('dgl.geometry', __name__)
2 changes: 2 additions & 0 deletions python/dgl/geometry/mxnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Package for mxnet-specific Geometry modules."""
from .fps import *
47 changes: 47 additions & 0 deletions python/dgl/geometry/mxnet/fps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Farthest Point Sampler for mxnet Geometry package"""
#pylint: disable=no-member, invalid-name

from mxnet import nd
from mxnet.gluon import nn
import numpy as np

from ..capi import farthest_point_sampler

class FarthestPointSampler(nn.Block):
"""Farthest Point Sampler
In each batch, the algorithm starts with the sample index specified by ``start_idx``.
Then for each point, we maintain the minimum to-sample distance.
Finally, we pick the point with the maximum such distance.
This process will be repeated for ``sample_points`` - 1 times.
Parameters
----------
npoints : int
The number of points to sample in each batch.
"""
def __init__(self, npoints):
super(FarthestPointSampler, self).__init__()
self.npoints = npoints

def forward(self, pos):
r"""Memory allocation and sampling
Parameters
----------
pos : tensor
The positional tensor of shape (B, N, C)
Returns
-------
tensor of shape (B, self.npoints)
The sampled indices in each batch.
"""
ctx = pos.context
B, N, C = pos.shape
pos = pos.reshape(-1, C)
dist = nd.zeros((B * N), dtype=pos.dtype, ctx=ctx)
start_idx = nd.random.randint(0, N - 1, (B, ), dtype=np.int, ctx=ctx)
result = nd.zeros((self.npoints * B), dtype=np.int, ctx=ctx)
farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
return result.reshape(B, self.npoints)
2 changes: 2 additions & 0 deletions python/dgl/geometry/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Package for mxnet-specific Geometry modules."""
from .fps import *
46 changes: 46 additions & 0 deletions python/dgl/geometry/pytorch/fps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name

import torch as th
from torch import nn

from ..capi import farthest_point_sampler

class FarthestPointSampler(nn.Module):
"""Farthest Point Sampler without the need to compute all pairs of distance.
In each batch, the algorithm starts with the sample index specified by ``start_idx``.
Then for each point, we maintain the minimum to-sample distance.
Finally, we pick the point with the maximum such distance.
This process will be repeated for ``sample_points`` - 1 times.
Parameters
----------
npoints : int
The number of points to sample in each batch.
"""
def __init__(self, npoints):
super(FarthestPointSampler, self).__init__()
self.npoints = npoints

def forward(self, pos):
r"""Memory allocation and sampling
Parameters
----------
pos : tensor
The positional tensor of shape (B, N, C)
Returns
-------
tensor of shape (B, self.npoints)
The sampled indices in each batch.
"""
device = pos.device
B, N, C = pos.shape
pos = pos.reshape(-1, C)
dist = th.zeros((B * N), dtype=pos.dtype, device=device)
start_idx = th.randint(0, N - 1, (B, ), dtype=th.int, device=device)
result = th.zeros((self.npoints * B), dtype=th.int, device=device)
farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
return result.reshape(B, self.npoints)
100 changes: 100 additions & 0 deletions src/geometry/cpu/geometry_op_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/geometry_op_impl.cc
* \brief Geometry operator CPU implementation
*/
#include <dgl/array.h>
#include <numeric>
#include <vector>

namespace dgl {
using runtime::NDArray;
namespace geometry {
namespace impl {

/*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance.
*
* The input array has shape (N, d), where N is the number of points, and d is the dimension.
* It consists of a (flatten) batch of point clouds.
*
* In each batch, the algorithm starts with the sample index specified by ``start_idx``.
* Then for each point, we maintain the minimum to-sample distance.
* Finally, we pick the point with the maximum such distance.
* This process will be repeated for ``sample_points`` - 1 times.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) {
const FloatType* array_data = static_cast<FloatType*>(array->data);
const int64_t point_in_batch = array->shape[0] / batch_size;
const int64_t dim = array->shape[1];

// distance
FloatType* dist_data = static_cast<FloatType*>(dist->data);

// sample for each cloud in the batch
IdType* start_idx_data = static_cast<IdType*>(start_idx->data);

// return value
IdType* ret_data = static_cast<IdType*>(result->data);

int64_t array_start = 0, ret_start = 0;
// loop for each point cloud sample in this batch
for (auto b = 0; b < batch_size; b++) {
// random init start sample
int64_t sample_idx = (int64_t)start_idx_data[b];
ret_data[ret_start] = (IdType)(sample_idx);

// sample the rest `sample_points - 1` points
for (auto i = 0; i < sample_points - 1; i++) {
// re-init distance and the argmax
int64_t dist_argmax = 0;
FloatType dist_max = -1;

// update the distance
for (auto j = 0; j < point_in_batch; j++) {
// compute the distance on dimensions
FloatType one_dist = 0;
for (auto d = 0; d < dim; d++) {
FloatType tmp = array_data[(array_start + j) * dim + d] -
array_data[(array_start + sample_idx) * dim + d];
one_dist += tmp * tmp;
}

// for each out-of-set point, keep its nearest to-the-set distance
if (i == 0 || dist_data[j] > one_dist) {
dist_data[j] = one_dist;
}
// look for the farthest sample
if (dist_data[j] > dist_max) {
dist_argmax = j;
dist_max = dist_data[j];
}
}
// sample the `dist_argmax`-th point
sample_idx = dist_argmax;
ret_data[ret_start + i + 1] = (IdType)(sample_idx);
}

array_start += point_in_batch;
ret_start += sample_points;
}
}

template void FarthestPointSampler<kDLCPU, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);

} // namespace impl
} // namespace geometry
} // namespace dgl
Loading

0 comments on commit 3d47693

Please sign in to comment.