From 3d47693b1f919b5aeedeef6cf3f1e3848c8fe088 Mon Sep 17 00:00:00 2001 From: Tong He Date: Mon, 22 Jun 2020 00:52:20 +0800 Subject: [PATCH] [Op] Farthest Point Sampler in Cpp and CUDA (#1630) * working framework without actual algorithm logic * rename * fix * fps passes compilation * correct algorithm * add cuda implementation * update random start * before refactor * pass compilation but cuda not working * working * code working, will add docstring * add mxnet support * update docstring * update doc and test * cpplint * cpcplint * pylint * temporary fix * fix for win64 * fix unitetest * fix * fix * remove comment * move to geometry package * remove redundant include * add docstrings and comments * add proof * add validity check --- CMakeLists.txt | 2 + cmake/modules/CUDA.cmake | 1 + python/dgl/geometry/__init__.py | 12 ++ python/dgl/geometry/capi.py | 37 ++++++ python/dgl/geometry/mxnet/__init__.py | 2 + python/dgl/geometry/mxnet/fps.py | 47 ++++++++ python/dgl/geometry/pytorch/__init__.py | 2 + python/dgl/geometry/pytorch/fps.py | 46 ++++++++ src/geometry/cpu/geometry_op_impl.cc | 100 +++++++++++++++++ src/geometry/cuda/geometry_op_impl.cu | 143 ++++++++++++++++++++++++ src/geometry/geometry.cc | 49 ++++++++ src/geometry/geometry_op.h | 23 ++++ tests/mxnet/test_geometry.py | 22 ++++ tests/pytorch/test_geometry.py | 21 ++++ 14 files changed, 507 insertions(+) create mode 100644 python/dgl/geometry/__init__.py create mode 100644 python/dgl/geometry/capi.py create mode 100644 python/dgl/geometry/mxnet/__init__.py create mode 100644 python/dgl/geometry/mxnet/fps.py create mode 100644 python/dgl/geometry/pytorch/__init__.py create mode 100644 python/dgl/geometry/pytorch/fps.py create mode 100644 src/geometry/cpu/geometry_op_impl.cc create mode 100644 src/geometry/cuda/geometry_op_impl.cu create mode 100644 src/geometry/geometry.cc create mode 100644 src/geometry/geometry_op.h create mode 100644 tests/mxnet/test_geometry.py create mode 100644 tests/pytorch/test_geometry.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1cac8149e0e4..4222996e51b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,6 +103,8 @@ file(GLOB DGL_SRC src/kernel/*.cc src/kernel/cpu/*.cc src/runtime/*.cc + src/geometry/*.cc + src/geometry/cpu/*.cc ) file(GLOB_RECURSE DGL_SRC_1 diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 8bd856cc2657..63ba02639fd0 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -236,6 +236,7 @@ macro(dgl_config_cuda out_variable) src/kernel/cuda/*.cc src/kernel/cuda/*.cu src/runtime/cuda/*.cc + src/geometry/cuda/*.cu ) dgl_select_nvcc_arch_flags(NVCC_FLAGS_ARCH) diff --git a/python/dgl/geometry/__init__.py b/python/dgl/geometry/__init__.py new file mode 100644 index 000000000000..db790173c5a0 --- /dev/null +++ b/python/dgl/geometry/__init__.py @@ -0,0 +1,12 @@ +"""Package for geometry common components.""" +import importlib +import sys +from ..backend import backend_name + +def _load_backend(mod_name): + mod = importlib.import_module('.%s' % mod_name, __name__) + thismod = sys.modules[__name__] + for api, obj in mod.__dict__.items(): + setattr(thismod, api, obj) + +_load_backend(backend_name) diff --git a/python/dgl/geometry/capi.py b/python/dgl/geometry/capi.py new file mode 100644 index 000000000000..32fd43d08429 --- /dev/null +++ b/python/dgl/geometry/capi.py @@ -0,0 +1,37 @@ +"""Python interfaces to DGL farthest point sampler.""" +from .._ffi.function import _init_api +from .. import backend as F + +def farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result): + """Farthest Point Sampler + + Parameters + ---------- + data : tensor + A tensor of shape (N, d) where N is the number of points and d is the dimension. + batch_size : int + The number of batches in the ``data``. N should be divisible by batch_size. + sample_points : int + The number of points to sample in each batch. + dist : tensor + Pre-allocated tensor of shape (N, ) for to-sample distance. + start_idx : tensor of int + Pre-allocated tensor of shape (batch_size, ) for the starting sample in each batch. + result : tensor of int + Pre-allocated tensor of shape (sample_points * batch_size, ) for the sampled index. + + Returns + ------- + No return value. The input variable ``result`` will be overwriten with sampled indices. + + """ + assert F.shape(data)[0] >= sample_points * batch_size + assert F.shape(data)[0] % batch_size == 0 + + _CAPI_FarthestPointSampler(F.zerocopy_to_dgl_ndarray(data), + batch_size, sample_points, + F.zerocopy_to_dgl_ndarray(dist), + F.zerocopy_to_dgl_ndarray(start_idx), + F.zerocopy_to_dgl_ndarray(result)) + +_init_api('dgl.geometry', __name__) diff --git a/python/dgl/geometry/mxnet/__init__.py b/python/dgl/geometry/mxnet/__init__.py new file mode 100644 index 000000000000..589bb65252e8 --- /dev/null +++ b/python/dgl/geometry/mxnet/__init__.py @@ -0,0 +1,2 @@ +"""Package for mxnet-specific Geometry modules.""" +from .fps import * diff --git a/python/dgl/geometry/mxnet/fps.py b/python/dgl/geometry/mxnet/fps.py new file mode 100644 index 000000000000..ce4c721a70cd --- /dev/null +++ b/python/dgl/geometry/mxnet/fps.py @@ -0,0 +1,47 @@ +"""Farthest Point Sampler for mxnet Geometry package""" +#pylint: disable=no-member, invalid-name + +from mxnet import nd +from mxnet.gluon import nn +import numpy as np + +from ..capi import farthest_point_sampler + +class FarthestPointSampler(nn.Block): + """Farthest Point Sampler + + In each batch, the algorithm starts with the sample index specified by ``start_idx``. + Then for each point, we maintain the minimum to-sample distance. + Finally, we pick the point with the maximum such distance. + This process will be repeated for ``sample_points`` - 1 times. + + Parameters + ---------- + npoints : int + The number of points to sample in each batch. + """ + def __init__(self, npoints): + super(FarthestPointSampler, self).__init__() + self.npoints = npoints + + def forward(self, pos): + r"""Memory allocation and sampling + + Parameters + ---------- + pos : tensor + The positional tensor of shape (B, N, C) + + Returns + ------- + tensor of shape (B, self.npoints) + The sampled indices in each batch. + """ + ctx = pos.context + B, N, C = pos.shape + pos = pos.reshape(-1, C) + dist = nd.zeros((B * N), dtype=pos.dtype, ctx=ctx) + start_idx = nd.random.randint(0, N - 1, (B, ), dtype=np.int, ctx=ctx) + result = nd.zeros((self.npoints * B), dtype=np.int, ctx=ctx) + farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result) + return result.reshape(B, self.npoints) diff --git a/python/dgl/geometry/pytorch/__init__.py b/python/dgl/geometry/pytorch/__init__.py new file mode 100644 index 000000000000..589bb65252e8 --- /dev/null +++ b/python/dgl/geometry/pytorch/__init__.py @@ -0,0 +1,2 @@ +"""Package for mxnet-specific Geometry modules.""" +from .fps import * diff --git a/python/dgl/geometry/pytorch/fps.py b/python/dgl/geometry/pytorch/fps.py new file mode 100644 index 000000000000..fc14cc1f741f --- /dev/null +++ b/python/dgl/geometry/pytorch/fps.py @@ -0,0 +1,46 @@ +"""Farthest Point Sampler for pytorch Geometry package""" +#pylint: disable=no-member, invalid-name + +import torch as th +from torch import nn + +from ..capi import farthest_point_sampler + +class FarthestPointSampler(nn.Module): + """Farthest Point Sampler without the need to compute all pairs of distance. + + In each batch, the algorithm starts with the sample index specified by ``start_idx``. + Then for each point, we maintain the minimum to-sample distance. + Finally, we pick the point with the maximum such distance. + This process will be repeated for ``sample_points`` - 1 times. + + Parameters + ---------- + npoints : int + The number of points to sample in each batch. + """ + def __init__(self, npoints): + super(FarthestPointSampler, self).__init__() + self.npoints = npoints + + def forward(self, pos): + r"""Memory allocation and sampling + + Parameters + ---------- + pos : tensor + The positional tensor of shape (B, N, C) + + Returns + ------- + tensor of shape (B, self.npoints) + The sampled indices in each batch. + """ + device = pos.device + B, N, C = pos.shape + pos = pos.reshape(-1, C) + dist = th.zeros((B * N), dtype=pos.dtype, device=device) + start_idx = th.randint(0, N - 1, (B, ), dtype=th.int, device=device) + result = th.zeros((self.npoints * B), dtype=th.int, device=device) + farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result) + return result.reshape(B, self.npoints) diff --git a/src/geometry/cpu/geometry_op_impl.cc b/src/geometry/cpu/geometry_op_impl.cc new file mode 100644 index 000000000000..f8a2aa3d3339 --- /dev/null +++ b/src/geometry/cpu/geometry_op_impl.cc @@ -0,0 +1,100 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file array/cpu/geometry_op_impl.cc + * \brief Geometry operator CPU implementation + */ +#include +#include +#include + +namespace dgl { +using runtime::NDArray; +namespace geometry { +namespace impl { + +/*! + * \brief Farthest Point Sampler without the need to compute all pairs of distance. + * + * The input array has shape (N, d), where N is the number of points, and d is the dimension. + * It consists of a (flatten) batch of point clouds. + * + * In each batch, the algorithm starts with the sample index specified by ``start_idx``. + * Then for each point, we maintain the minimum to-sample distance. + * Finally, we pick the point with the maximum such distance. + * This process will be repeated for ``sample_points`` - 1 times. + */ +template +void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result) { + const FloatType* array_data = static_cast(array->data); + const int64_t point_in_batch = array->shape[0] / batch_size; + const int64_t dim = array->shape[1]; + + // distance + FloatType* dist_data = static_cast(dist->data); + + // sample for each cloud in the batch + IdType* start_idx_data = static_cast(start_idx->data); + + // return value + IdType* ret_data = static_cast(result->data); + + int64_t array_start = 0, ret_start = 0; + // loop for each point cloud sample in this batch + for (auto b = 0; b < batch_size; b++) { + // random init start sample + int64_t sample_idx = (int64_t)start_idx_data[b]; + ret_data[ret_start] = (IdType)(sample_idx); + + // sample the rest `sample_points - 1` points + for (auto i = 0; i < sample_points - 1; i++) { + // re-init distance and the argmax + int64_t dist_argmax = 0; + FloatType dist_max = -1; + + // update the distance + for (auto j = 0; j < point_in_batch; j++) { + // compute the distance on dimensions + FloatType one_dist = 0; + for (auto d = 0; d < dim; d++) { + FloatType tmp = array_data[(array_start + j) * dim + d] - + array_data[(array_start + sample_idx) * dim + d]; + one_dist += tmp * tmp; + } + + // for each out-of-set point, keep its nearest to-the-set distance + if (i == 0 || dist_data[j] > one_dist) { + dist_data[j] = one_dist; + } + // look for the farthest sample + if (dist_data[j] > dist_max) { + dist_argmax = j; + dist_max = dist_data[j]; + } + } + // sample the `dist_argmax`-th point + sample_idx = dist_argmax; + ret_data[ret_start + i + 1] = (IdType)(sample_idx); + } + + array_start += point_in_batch; + ret_start += sample_points; + } +} + +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); + +} // namespace impl +} // namespace geometry +} // namespace dgl diff --git a/src/geometry/cuda/geometry_op_impl.cu b/src/geometry/cuda/geometry_op_impl.cu new file mode 100644 index 000000000000..dc771ccbbcfe --- /dev/null +++ b/src/geometry/cuda/geometry_op_impl.cu @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file geometry/cuda/geometry_op_impl.cc + * \brief Geometry operator CUDA implementation + */ +#include + +#include "../../runtime/cuda/cuda_common.h" +#include "../../c_api_common.h" +#include "../geometry_op.h" + +#define THREADS 1024 + +namespace dgl { +namespace geometry { +namespace impl { + +/*! + * \brief Farthest Point Sampler without the need to compute all pairs of distance. + * + * The input array has shape (N, d), where N is the number of points, and d is the dimension. + * It consists of a (flatten) batch of point clouds. + * + * In each batch, the algorithm starts with the sample index specified by ``start_idx``. + * Then for each point, we maintain the minimum to-sample distance. + * Finally, we pick the point with the maximum such distance. + * This process will be repeated for ``sample_points`` - 1 times. + */ +template +__global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size, + const int64_t sample_points, const int64_t point_in_batch, + const int64_t dim, const IdType *start_idx, + FloatType *dist_data, IdType *ret_data) { + const int64_t thread_idx = threadIdx.x; + const int64_t batch_idx = blockIdx.x; + + const int64_t array_start = point_in_batch * batch_idx; + const int64_t ret_start = sample_points * batch_idx; + + __shared__ FloatType dist_max_ht[THREADS]; + __shared__ int64_t dist_argmax_ht[THREADS]; + + // start with random initialization + if (thread_idx == 0) { + ret_data[ret_start] = (IdType)(start_idx[batch_idx]); + } + + // sample the rest `sample_points - 1` points + for (auto i = 0; i < sample_points - 1; i++) { + __syncthreads(); + + // the last sampled point + int64_t sample_idx = (int64_t)(ret_data[ret_start + i]); + FloatType dist_max = (FloatType)(-1.); + int64_t dist_argmax = 0; + + // multi-thread distance calculation + for (auto j = thread_idx; j < point_in_batch; j += THREADS) { + FloatType one_dist = (FloatType)(0.); + for (auto d = 0; d < dim; d++) { + FloatType tmp = array_data[(array_start + j) * dim + d] - + array_data[(array_start + sample_idx) * dim + d]; + one_dist += tmp * tmp; + } + + if (i == 0 || dist_data[array_start + j] > one_dist) { + dist_data[array_start + j] = one_dist; + } + + if (dist_data[array_start + j] > dist_max) { + dist_argmax = j; + dist_max = dist_data[array_start + j]; + } + } + + dist_max_ht[thread_idx] = dist_max; + dist_argmax_ht[thread_idx] = dist_argmax; + + /* + * \brief Parallel Reduction + * + * Suppose the maximum is dist_max_ht[k], where 0 <= k < THREAD. + * After loop at j = 1, the maximum is propagated to [k-1]. + * After loop at j = 2, the maximum is propagated to the range [k-3] to [k]. + * After loop at j = 4, the maximum is propagated to the range [k-7] to [k]. + * After loop at any j < THREADS, we can see [k - 2*j + 1] to [k] are all covered by the maximum. + * The max value of j is at least floor(THREAD / 2), and it is sufficient to cover [0] with the maximum. + */ + + for (auto j = 1; j < THREADS; j *= 2) { + __syncthreads(); + if ((thread_idx + j) < THREADS && dist_max_ht[thread_idx] < dist_max_ht[thread_idx + j]) { + dist_max_ht[thread_idx] = dist_max_ht[thread_idx + j]; + dist_argmax_ht[thread_idx] = dist_argmax_ht[thread_idx + j]; + } + } + + if (thread_idx == 0) { + ret_data[ret_start + i + 1] = (IdType)(dist_argmax_ht[0]); + } + } +} + +template +void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result) { + auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); + + const FloatType* array_data = static_cast(array->data); + + const int64_t point_in_batch = array->shape[0] / batch_size; + const int64_t dim = array->shape[1]; + + // return value + IdType* ret_data = static_cast(result->data); + + // distance + FloatType* dist_data = static_cast(dist->data); + + // sample for each cloud in the batch + IdType* start_idx_data = static_cast(start_idx->data); + + fps_kernel<<stream>>>( + array_data, batch_size, sample_points, + point_in_batch, dim, start_idx_data, dist_data, ret_data); +} + +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); +template void FarthestPointSampler( + NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); + +} // namespace impl +} // namespace geometry +} // namespace dgl diff --git a/src/geometry/geometry.cc b/src/geometry/geometry.cc new file mode 100644 index 000000000000..5f24eba6005c --- /dev/null +++ b/src/geometry/geometry.cc @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file geometry/geometry.cc + * \brief DGL geometry utilities implementation + */ +#include +#include +#include "../c_api_common.h" +#include "./geometry_op.h" + +using namespace dgl::runtime; + +namespace dgl { +namespace geometry { + +void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result) { + + CHECK_EQ(array->ctx, result->ctx) << "Array and the result should be on the same device."; + CHECK_EQ(array->shape[0], dist->shape[0]) << "Shape of array and dist mismatch"; + CHECK_EQ(start_idx->shape[0], batch_size) << "Shape of start_idx and batch_size mismatch"; + CHECK_EQ(result->shape[0], batch_size * sample_points) << "Invalid shape of result"; + + ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", { + ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { + ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "FarthestPointSampler", { + impl::FarthestPointSampler( + array, batch_size, sample_points, dist, start_idx, result); + }); + }); + }); +} + +///////////////////////// C APIs ///////////////////////// + +DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler") +.set_body([] (DGLArgs args, DGLRetValue* rv) { + const NDArray data = args[0]; + const int64_t batch_size = args[1]; + const int64_t sample_points = args[2]; + NDArray dist = args[3]; + IdArray start_idx = args[4]; + IdArray result = args[5]; + + FarthestPointSampler(data, batch_size, sample_points, dist, start_idx, result); + }); + +} // namespace geometry +} // namespace dgl diff --git a/src/geometry/geometry_op.h b/src/geometry/geometry_op.h new file mode 100644 index 000000000000..d23ffdfcd63b --- /dev/null +++ b/src/geometry/geometry_op.h @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file geometry/geometry_op.h + * \brief Geometry operator templates + */ +#ifndef DGL_GEOMETRY_GEOMETRY_OP_H_ +#define DGL_GEOMETRY_GEOMETRY_OP_H_ + +#include + +namespace dgl { +namespace geometry { +namespace impl { + +template +void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, + NDArray dist, IdArray start_idx, IdArray result); + +} // namespace impl +} // namespace geometry +} // namespace dgl + +#endif // DGL_GEOMETRY_GEOMETRY_OP_H_ diff --git a/tests/mxnet/test_geometry.py b/tests/mxnet/test_geometry.py new file mode 100644 index 000000000000..68500c62ad70 --- /dev/null +++ b/tests/mxnet/test_geometry.py @@ -0,0 +1,22 @@ +import mxnet as mx +from dgl.geometry.mxnet import FarthestPointSampler +import backend as F + +import numpy as np + +def test_fps(): + N = 1000 + batch_size = 5 + sample_points = 10 + x = mx.nd.array(np.random.uniform(size=(batch_size, int(N/batch_size), 3))) + ctx = F.ctx() + if F.gpu_ctx(): + x = x.as_in_context(ctx) + fps = FarthestPointSampler(sample_points) + res = fps(x) + assert res.shape[0] == batch_size + assert res.shape[1] == sample_points + assert res.sum() > 0 + +if __name__ == '__main__': + test_fps() diff --git a/tests/pytorch/test_geometry.py b/tests/pytorch/test_geometry.py new file mode 100644 index 000000000000..11810a44172e --- /dev/null +++ b/tests/pytorch/test_geometry.py @@ -0,0 +1,21 @@ +import torch as th +from dgl.geometry.pytorch import FarthestPointSampler +import backend as F +import numpy as np + +def test_fps(): + N = 1000 + batch_size = 5 + sample_points = 10 + x = th.tensor(np.random.uniform(size=(batch_size, int(N/batch_size), 3))) + ctx = F.ctx() + if F.gpu_ctx(): + x = x.to(ctx) + fps = FarthestPointSampler(sample_points) + res = fps(x) + assert res.shape[0] == batch_size + assert res.shape[1] == sample_points + assert res.sum() > 0 + +if __name__ == '__main__': + test_fps()