-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Op] Farthest Point Sampler in Cpp and CUDA (dmlc#1630)
* working framework without actual algorithm logic * rename * fix * fps passes compilation * correct algorithm * add cuda implementation * update random start * before refactor * pass compilation but cuda not working * working * code working, will add docstring * add mxnet support * update docstring * update doc and test * cpplint * cpcplint * pylint * temporary fix * fix for win64 * fix unitetest * fix * fix * remove comment * move to geometry package * remove redundant include * add docstrings and comments * add proof * add validity check
- Loading branch information
Showing
14 changed files
with
507 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Package for geometry common components.""" | ||
import importlib | ||
import sys | ||
from ..backend import backend_name | ||
|
||
def _load_backend(mod_name): | ||
mod = importlib.import_module('.%s' % mod_name, __name__) | ||
thismod = sys.modules[__name__] | ||
for api, obj in mod.__dict__.items(): | ||
setattr(thismod, api, obj) | ||
|
||
_load_backend(backend_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Python interfaces to DGL farthest point sampler.""" | ||
from .._ffi.function import _init_api | ||
from .. import backend as F | ||
|
||
def farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result): | ||
"""Farthest Point Sampler | ||
Parameters | ||
---------- | ||
data : tensor | ||
A tensor of shape (N, d) where N is the number of points and d is the dimension. | ||
batch_size : int | ||
The number of batches in the ``data``. N should be divisible by batch_size. | ||
sample_points : int | ||
The number of points to sample in each batch. | ||
dist : tensor | ||
Pre-allocated tensor of shape (N, ) for to-sample distance. | ||
start_idx : tensor of int | ||
Pre-allocated tensor of shape (batch_size, ) for the starting sample in each batch. | ||
result : tensor of int | ||
Pre-allocated tensor of shape (sample_points * batch_size, ) for the sampled index. | ||
Returns | ||
------- | ||
No return value. The input variable ``result`` will be overwriten with sampled indices. | ||
""" | ||
assert F.shape(data)[0] >= sample_points * batch_size | ||
assert F.shape(data)[0] % batch_size == 0 | ||
|
||
_CAPI_FarthestPointSampler(F.zerocopy_to_dgl_ndarray(data), | ||
batch_size, sample_points, | ||
F.zerocopy_to_dgl_ndarray(dist), | ||
F.zerocopy_to_dgl_ndarray(start_idx), | ||
F.zerocopy_to_dgl_ndarray(result)) | ||
|
||
_init_api('dgl.geometry', __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""Package for mxnet-specific Geometry modules.""" | ||
from .fps import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Farthest Point Sampler for mxnet Geometry package""" | ||
#pylint: disable=no-member, invalid-name | ||
|
||
from mxnet import nd | ||
from mxnet.gluon import nn | ||
import numpy as np | ||
|
||
from ..capi import farthest_point_sampler | ||
|
||
class FarthestPointSampler(nn.Block): | ||
"""Farthest Point Sampler | ||
In each batch, the algorithm starts with the sample index specified by ``start_idx``. | ||
Then for each point, we maintain the minimum to-sample distance. | ||
Finally, we pick the point with the maximum such distance. | ||
This process will be repeated for ``sample_points`` - 1 times. | ||
Parameters | ||
---------- | ||
npoints : int | ||
The number of points to sample in each batch. | ||
""" | ||
def __init__(self, npoints): | ||
super(FarthestPointSampler, self).__init__() | ||
self.npoints = npoints | ||
|
||
def forward(self, pos): | ||
r"""Memory allocation and sampling | ||
Parameters | ||
---------- | ||
pos : tensor | ||
The positional tensor of shape (B, N, C) | ||
Returns | ||
------- | ||
tensor of shape (B, self.npoints) | ||
The sampled indices in each batch. | ||
""" | ||
ctx = pos.context | ||
B, N, C = pos.shape | ||
pos = pos.reshape(-1, C) | ||
dist = nd.zeros((B * N), dtype=pos.dtype, ctx=ctx) | ||
start_idx = nd.random.randint(0, N - 1, (B, ), dtype=np.int, ctx=ctx) | ||
result = nd.zeros((self.npoints * B), dtype=np.int, ctx=ctx) | ||
farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result) | ||
return result.reshape(B, self.npoints) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""Package for mxnet-specific Geometry modules.""" | ||
from .fps import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Farthest Point Sampler for pytorch Geometry package""" | ||
#pylint: disable=no-member, invalid-name | ||
|
||
import torch as th | ||
from torch import nn | ||
|
||
from ..capi import farthest_point_sampler | ||
|
||
class FarthestPointSampler(nn.Module): | ||
"""Farthest Point Sampler without the need to compute all pairs of distance. | ||
In each batch, the algorithm starts with the sample index specified by ``start_idx``. | ||
Then for each point, we maintain the minimum to-sample distance. | ||
Finally, we pick the point with the maximum such distance. | ||
This process will be repeated for ``sample_points`` - 1 times. | ||
Parameters | ||
---------- | ||
npoints : int | ||
The number of points to sample in each batch. | ||
""" | ||
def __init__(self, npoints): | ||
super(FarthestPointSampler, self).__init__() | ||
self.npoints = npoints | ||
|
||
def forward(self, pos): | ||
r"""Memory allocation and sampling | ||
Parameters | ||
---------- | ||
pos : tensor | ||
The positional tensor of shape (B, N, C) | ||
Returns | ||
------- | ||
tensor of shape (B, self.npoints) | ||
The sampled indices in each batch. | ||
""" | ||
device = pos.device | ||
B, N, C = pos.shape | ||
pos = pos.reshape(-1, C) | ||
dist = th.zeros((B * N), dtype=pos.dtype, device=device) | ||
start_idx = th.randint(0, N - 1, (B, ), dtype=th.int, device=device) | ||
result = th.zeros((self.npoints * B), dtype=th.int, device=device) | ||
farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result) | ||
return result.reshape(B, self.npoints) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file array/cpu/geometry_op_impl.cc | ||
* \brief Geometry operator CPU implementation | ||
*/ | ||
#include <dgl/array.h> | ||
#include <numeric> | ||
#include <vector> | ||
|
||
namespace dgl { | ||
using runtime::NDArray; | ||
namespace geometry { | ||
namespace impl { | ||
|
||
/*! | ||
* \brief Farthest Point Sampler without the need to compute all pairs of distance. | ||
* | ||
* The input array has shape (N, d), where N is the number of points, and d is the dimension. | ||
* It consists of a (flatten) batch of point clouds. | ||
* | ||
* In each batch, the algorithm starts with the sample index specified by ``start_idx``. | ||
* Then for each point, we maintain the minimum to-sample distance. | ||
* Finally, we pick the point with the maximum such distance. | ||
* This process will be repeated for ``sample_points`` - 1 times. | ||
*/ | ||
template <DLDeviceType XPU, typename FloatType, typename IdType> | ||
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, | ||
NDArray dist, IdArray start_idx, IdArray result) { | ||
const FloatType* array_data = static_cast<FloatType*>(array->data); | ||
const int64_t point_in_batch = array->shape[0] / batch_size; | ||
const int64_t dim = array->shape[1]; | ||
|
||
// distance | ||
FloatType* dist_data = static_cast<FloatType*>(dist->data); | ||
|
||
// sample for each cloud in the batch | ||
IdType* start_idx_data = static_cast<IdType*>(start_idx->data); | ||
|
||
// return value | ||
IdType* ret_data = static_cast<IdType*>(result->data); | ||
|
||
int64_t array_start = 0, ret_start = 0; | ||
// loop for each point cloud sample in this batch | ||
for (auto b = 0; b < batch_size; b++) { | ||
// random init start sample | ||
int64_t sample_idx = (int64_t)start_idx_data[b]; | ||
ret_data[ret_start] = (IdType)(sample_idx); | ||
|
||
// sample the rest `sample_points - 1` points | ||
for (auto i = 0; i < sample_points - 1; i++) { | ||
// re-init distance and the argmax | ||
int64_t dist_argmax = 0; | ||
FloatType dist_max = -1; | ||
|
||
// update the distance | ||
for (auto j = 0; j < point_in_batch; j++) { | ||
// compute the distance on dimensions | ||
FloatType one_dist = 0; | ||
for (auto d = 0; d < dim; d++) { | ||
FloatType tmp = array_data[(array_start + j) * dim + d] - | ||
array_data[(array_start + sample_idx) * dim + d]; | ||
one_dist += tmp * tmp; | ||
} | ||
|
||
// for each out-of-set point, keep its nearest to-the-set distance | ||
if (i == 0 || dist_data[j] > one_dist) { | ||
dist_data[j] = one_dist; | ||
} | ||
// look for the farthest sample | ||
if (dist_data[j] > dist_max) { | ||
dist_argmax = j; | ||
dist_max = dist_data[j]; | ||
} | ||
} | ||
// sample the `dist_argmax`-th point | ||
sample_idx = dist_argmax; | ||
ret_data[ret_start + i + 1] = (IdType)(sample_idx); | ||
} | ||
|
||
array_start += point_in_batch; | ||
ret_start += sample_points; | ||
} | ||
} | ||
|
||
template void FarthestPointSampler<kDLCPU, float, int32_t>( | ||
NDArray array, int64_t batch_size, int64_t sample_points, | ||
NDArray dist, IdArray start_idx, IdArray result); | ||
template void FarthestPointSampler<kDLCPU, float, int64_t>( | ||
NDArray array, int64_t batch_size, int64_t sample_points, | ||
NDArray dist, IdArray start_idx, IdArray result); | ||
template void FarthestPointSampler<kDLCPU, double, int32_t>( | ||
NDArray array, int64_t batch_size, int64_t sample_points, | ||
NDArray dist, IdArray start_idx, IdArray result); | ||
template void FarthestPointSampler<kDLCPU, double, int64_t>( | ||
NDArray array, int64_t batch_size, int64_t sample_points, | ||
NDArray dist, IdArray start_idx, IdArray result); | ||
|
||
} // namespace impl | ||
} // namespace geometry | ||
} // namespace dgl |
Oops, something went wrong.