diff --git a/python/dgl/geometry/fps.py b/python/dgl/geometry/fps.py index 99d8af5f8206..79e4f4c44bd9 100644 --- a/python/dgl/geometry/fps.py +++ b/python/dgl/geometry/fps.py @@ -8,6 +8,7 @@ __all__ = ['farthest_point_sampler'] + def farthest_point_sampler(pos, npoints, start_idx=None): """Farthest Point Sampler without the need to compute all pairs of distance. @@ -49,12 +50,13 @@ def farthest_point_sampler(pos, npoints, start_idx=None): pos = pos.reshape(-1, C) dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx) if start_idx is None: - start_idx = F.randint(shape=(B, ), dtype=F.int64, ctx=ctx, low=0, high=N-1) + start_idx = F.randint(shape=(B, ), dtype=F.int64, + ctx=ctx, low=0, high=N-1) else: if start_idx >= N or start_idx < 0: raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format( N, start_idx)) - start_idx = F.full_1d((B, ), start_idx, dtype=F.int64, ctx=ctx) + start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx) result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx) _farthest_point_sampler(pos, B, npoints, dist, start_idx, result) return result.reshape(B, npoints) diff --git a/src/geometry/cuda/geometry_op_impl.cu b/src/geometry/cuda/geometry_op_impl.cu index 5bebe892d303..a4b877f14782 100644 --- a/src/geometry/cuda/geometry_op_impl.cu +++ b/src/geometry/cuda/geometry_op_impl.cu @@ -107,6 +107,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin // sample for each cloud in the batch IdType* start_idx_data = static_cast(start_idx->data); + CUDA_CALL(cudaSetDevice(array->ctx.device_id)); CUDA_KERNEL_CALL(fps_kernel, batch_size, THREADS, 0, thr_entry->stream, diff --git a/tests/pytorch/test_geometry.py b/tests/pytorch/test_geometry.py index e3a40e555989..efe02f986f2d 100644 --- a/tests/pytorch/test_geometry.py +++ b/tests/pytorch/test_geometry.py @@ -25,6 +25,18 @@ def test_fps(): assert res.sum() > 0 +def test_fps_start_idx(): + N = 1000 + batch_size = 5 + sample_points = 10 + x = th.tensor(np.random.uniform(size=(batch_size, int(N/batch_size), 3))) + ctx = F.ctx() + if F.gpu_ctx(): + x = x.to(ctx) + res = farthest_point_sampler(x, sample_points, start_idx=0) + assert th.any(res[:, 0] == 0) + + @pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree']) @pytest.mark.parametrize('dist', ['euclidean', 'cosine']) def test_knn_cpu(algorithm, dist): @@ -208,4 +220,5 @@ def test_edge_coarsening(idtype, g, weight, relabel): if __name__ == '__main__': test_fps() + test_fps_start_idx() test_knn()