
Commit

Merge branch 'main' of github.com:mihdalal/pointnet2_ops into v3.2.0
mihdalal committed Apr 26, 2024
2 parents 62fa68b + b98726c commit b1381c5
Showing 8 changed files with 234 additions and 251 deletions.
28 changes: 28 additions & 0 deletions LICENSE
@@ -19,3 +19,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+
+----------------------------------------------------------------------------
+Much of the CUDA kernel code comes from https://github.com/sshaoshuai/Pointnet2.PyTorch
+Here is the original license for that code
+
+MIT License
+
+Copyright (c) 2019 Shaoshuai Shi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+

42 changes: 21 additions & 21 deletions pointnet2_ops/_ext-src/src/ball_query_gpu.cu
@@ -17,35 +17,35 @@ __global__ void query_ball_point_kernel(
     const scalar_t *__restrict__ new_xyz,
     const scalar_t *__restrict__ xyz,
     int *__restrict__ idx) {
-  int batch_index = blockIdx.x;
-  xyz += batch_index * n * 3;
-  new_xyz += batch_index * m * 3;
-  idx += m * nsample * batch_index;
+  int bs_idx = blockIdx.y;
+  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (bs_idx >= b || pt_idx >= m) return;
 
-  int index = threadIdx.x;
-  int stride = blockDim.x;
+  new_xyz += bs_idx * m * 3 + pt_idx * 3;
+  xyz += bs_idx * n * 3;
+  idx += bs_idx * m * nsample + pt_idx * nsample;
 
   float radius2 = radius * radius;
-  for (int j = index; j < m; j += stride) {
-    scalar_t new_x = new_xyz[j * 3 + 0];
-    scalar_t new_y = new_xyz[j * 3 + 1];
-    scalar_t new_z = new_xyz[j * 3 + 2];
-    for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
+  scalar_t new_x = new_xyz[0];
+  scalar_t new_y = new_xyz[1];
+  scalar_t new_z = new_xyz[2];
+
+  int cnt = 0;
+  for (int k = 0; k < n; ++k) {
     scalar_t x = xyz[k * 3 + 0];
     scalar_t y = xyz[k * 3 + 1];
     scalar_t z = xyz[k * 3 + 2];
-    scalar_t d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
-                  (new_z - z) * (new_z - z);
-    if (d2 < radius2) {
-      if (cnt == 0) {
-        for (int l = 0; l < nsample; ++l) {
-          idx[j * nsample + l] = k;
+    scalar_t d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+    if (d2 < radius2){
+      if (cnt == 0){
+        for (int l = 0; l < nsample; ++l) {
+          idx[l] = k;
         }
       }
-      idx[j * nsample + cnt] = k;
-      ++cnt;
+      idx[cnt] = k;
+      ++cnt;
+      if (cnt >= nsample) break;
     }
   }
-  }
 }

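The rewrite above replaces the old one-block-per-batch kernel, which strided a
single thread block over all query points, with one thread per (batch, query
point) pair and an early bounds return. A launch configuration consistent with
the new indexing is sketched below; the actual wrapper is collapsed out of
this diff, so THREADS_PER_BLOCK and DIVUP are assumed helper names, not code
from the commit.

// Minimal sketch only; the real wrapper is not shown in this diff.
// THREADS_PER_BLOCK and DIVUP are assumed names.
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) (((m) + (n) - 1) / (n))

// grid.x tiles the m query points (pt_idx); grid.y enumerates the b batches (bs_idx).
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
query_ball_point_kernel<scalar_t><<<blocks, threads>>>(
    b, n, m, radius, nsample, new_xyz, xyz, idx);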
65 changes: 31 additions & 34 deletions pointnet2_ops/_ext-src/src/group_points_gpu.cu
@@ -7,26 +7,28 @@
 // input: points(b, c, n) idx(b, npoints, nsample)
 // output: out(b, c, npoints, nsample)
 template<typename scalar_t>
-__global__ void group_points_kernel(int b, int c, int n, int npoints,
-                                    int nsample,
-                                    const scalar_t *__restrict__ points,
-                                    const int *__restrict__ idx,
-                                    scalar_t *__restrict__ out) {
-  int batch_index = blockIdx.x;
-  points += batch_index * n * c;
-  idx += batch_index * npoints * nsample;
-  out += batch_index * npoints * nsample * c;
+__global__ void group_points_kernel(
+    int b,
+    int c,
+    int n,
+    int npoints,
+    int nsample,
+    const scalar_t *__restrict__ points,
+    const int *__restrict__ idx,
+    scalar_t *__restrict__ out) {
+  int bs_idx = blockIdx.z;
+  int c_idx = blockIdx.y;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int pt_idx = index / nsample;
+  if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
 
+  int sample_idx = index % nsample;
 
-  const int index = threadIdx.y * blockDim.x + threadIdx.x;
-  const int stride = blockDim.y * blockDim.x;
-  for (int i = index; i < c * npoints; i += stride) {
-    const int l = i / npoints;
-    const int j = i % npoints;
-    for (int k = 0; k < nsample; ++k) {
-      int ii = idx[j * nsample + k];
-      out[(l * npoints + j) * nsample + k] = points[l * n + ii];
-    }
-  }
+  idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+  int in_idx = bs_idx * c * n + c_idx * n + idx[0];
+  int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+  out[out_idx] = points[in_idx];
 }

at::Tensor group_points_kernel_wrapper(
@@ -70,22 +72,17 @@ __global__ void group_points_grad_kernel(
     const scalar_t *__restrict__ grad_out,
     const int *__restrict__ idx,
     scalar_t *__restrict__ grad_points) {
-  int batch_index = blockIdx.x;
-  grad_out += batch_index * npoints * nsample * c;
-  idx += batch_index * npoints * nsample;
-  grad_points += batch_index * n * c;
+  int bs_idx = blockIdx.z;
+  int c_idx = blockIdx.y;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int pt_idx = index / nsample;
+  if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
 
+  int sample_idx = index % nsample;
+  grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+  idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
 
-  const int index = threadIdx.y * blockDim.x + threadIdx.x;
-  const int stride = blockDim.y * blockDim.x;
-  for (int i = index; i < c * npoints; i += stride) {
-    const int l = i / npoints;
-    const int j = i % npoints;
-    for (int k = 0; k < nsample; ++k) {
-      int ii = idx[j * nsample + k];
-      atomicAdd(grad_points + l * n + ii,
-                grad_out[(l * npoints + j) * nsample + k]);
-    }
-  }
+  atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
 }

at::Tensor group_points_grad_kernel_wrapper(
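Both group_points kernels now compute exactly one output element per thread:
the flat index along grid.x splits into pt_idx = index / nsample and
sample_idx = index % nsample, while grid.y and grid.z carry the channel and
batch. A matching launch, again only a sketch since the wrappers are collapsed
in this diff (same assumed DIVUP and THREADS_PER_BLOCK helpers as above):

// Sketch under the same assumptions as the ball_query example.
dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
group_points_kernel<scalar_t><<<blocks, threads>>>(
    b, c, n, npoints, nsample, points, idx, out);
// The gradient kernel uses the same geometry; it keeps atomicAdd because
// several (pt_idx, sample_idx) slots can gather from the same input point.
group_points_grad_kernel<scalar_t><<<blocks, threads>>>(
    b, c, n, npoints, nsample, grad_out, idx, grad_points);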
147 changes: 60 additions & 87 deletions pointnet2_ops/_ext-src/src/interpolate_gpu.cu
@@ -9,56 +9,49 @@
 // input: unknown(b, n, 3) known(b, m, 3)
 // output: dist2(b, n, 3), idx(b, n, 3)
 template<typename scalar_t>
-__global__ void three_nn_kernel(int b, int n, int m,
-                                const scalar_t *__restrict__ unknown,
-                                const scalar_t *__restrict__ known,
-                                scalar_t *__restrict__ dist2,
-                                int *__restrict__ idx) {
-  int batch_index = blockIdx.x;
-  unknown += batch_index * n * 3;
-  known += batch_index * m * 3;
-  dist2 += batch_index * n * 3;
-  idx += batch_index * n * 3;
-
-  int index = threadIdx.x;
-  int stride = blockDim.x;
-  for (int j = index; j < n; j += stride) {
-    scalar_t ux = unknown[j * 3 + 0];
-    scalar_t uy = unknown[j * 3 + 1];
-    scalar_t uz = unknown[j * 3 + 2];
-
-    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
-    int besti1 = 0, besti2 = 0, besti3 = 0;
-    for (int k = 0; k < m; ++k) {
+__global__ void three_nn_kernel(
+    int b,
+    int n,
+    int m,
+    const scalar_t *__restrict__ unknown,
+    const scalar_t *__restrict__ known,
+    scalar_t *__restrict__ dist2,
+    int *__restrict__ idx) {
+  int bs_idx = blockIdx.y;
+  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (bs_idx >= b || pt_idx >= n) return;
+
+  unknown += bs_idx * n * 3 + pt_idx * 3;
+  known += bs_idx * m * 3;
+  dist2 += bs_idx * n * 3 + pt_idx * 3;
+  idx += bs_idx * n * 3 + pt_idx * 3;
+
+  scalar_t ux = unknown[0];
+  scalar_t uy = unknown[1];
+  scalar_t uz = unknown[2];
+
+  double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+  int besti1 = 0, besti2 = 0, besti3 = 0;
+  for (int k = 0; k < m; ++k) {
     scalar_t x = known[k * 3 + 0];
     scalar_t y = known[k * 3 + 1];
     scalar_t z = known[k * 3 + 2];
     scalar_t d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
     if (d < best1) {
-      best3 = best2;
-      besti3 = besti2;
-      best2 = best1;
-      besti2 = besti1;
-      best1 = d;
-      besti1 = k;
-    } else if (d < best2) {
-      best3 = best2;
-      besti3 = besti2;
-      best2 = d;
-      besti2 = k;
-    } else if (d < best3) {
-      best3 = d;
-      besti3 = k;
+      best3 = best2; besti3 = besti2;
+      best2 = best1; besti2 = besti1;
+      best1 = d; besti1 = k;
+    }
+    else if (d < best2) {
+      best3 = best2; besti3 = besti2;
+      best2 = d; besti2 = k;
+    }
+    else if (d < best3) {
+      best3 = d; besti3 = k;
     }
   }
-  dist2[j * 3 + 0] = best1;
-  dist2[j * 3 + 1] = best2;
-  dist2[j * 3 + 2] = best3;
-
-  idx[j * 3 + 0] = besti1;
-  idx[j * 3 + 1] = besti2;
-  idx[j * 3 + 2] = besti3;
-  }
+  dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
+  idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
 }
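The selection logic in three_nn_kernel keeps a running top-3 by squared
distance (best1 <= best2 <= best3), shifting entries down one slot whenever a
closer point appears. A standalone C++ rendering of just that cascade with a
small worked input (illustration only, not part of the commit):

#include <cstdio>

int main() {
  const int m = 5;
  double dist[m] = {4.0, 1.0, 9.0, 0.25, 2.25};  // squared distances to 5 known points
  double best1 = 1e40, best2 = 1e40, best3 = 1e40;
  int besti1 = 0, besti2 = 0, besti3 = 0;
  for (int k = 0; k < m; ++k) {
    double d = dist[k];
    if (d < best1) {         // new closest: shift slots 1 -> 2 -> 3
      best3 = best2; besti3 = besti2;
      best2 = best1; besti2 = besti1;
      best1 = d; besti1 = k;
    } else if (d < best2) {  // new second closest: shift slot 2 -> 3
      best3 = best2; besti3 = besti2;
      best2 = d; besti2 = k;
    } else if (d < best3) {  // new third closest
      best3 = d; besti3 = k;
    }
  }
  // Prints: 0.25 (idx 3), 1.00 (idx 1), 2.25 (idx 4)
  printf("%.2f (idx %d), %.2f (idx %d), %.2f (idx %d)\n",
         best1, besti1, best2, besti2, best3, besti3);
  return 0;
}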

std::vector<at::Tensor> three_nn_kernel_wrapper(
@@ -103,30 +96,18 @@ __global__ void three_interpolate_kernel(
     const int *__restrict__ idx,
     const scalar_t *__restrict__ weight,
     scalar_t *__restrict__ out) {
-  int batch_index = blockIdx.x;
-  points += batch_index * m * c;
-
-  idx += batch_index * n * 3;
-  weight += batch_index * n * 3;
-
-  out += batch_index * n * c;
+  int bs_idx = blockIdx.z;
+  int c_idx = blockIdx.y;
+  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
 
-  const int index = threadIdx.y * blockDim.x + threadIdx.x;
-  const int stride = blockDim.y * blockDim.x;
-  for (int i = index; i < c * n; i += stride) {
-    const int l = i / n;
-    const int j = i % n;
-    scalar_t w1 = weight[j * 3 + 0];
-    scalar_t w2 = weight[j * 3 + 1];
-    scalar_t w3 = weight[j * 3 + 2];
+  if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
 
-    int i1 = idx[j * 3 + 0];
-    int i2 = idx[j * 3 + 1];
-    int i3 = idx[j * 3 + 2];
+  weight += bs_idx * n * 3 + pt_idx * 3;
+  points += bs_idx * c * m + c_idx * m;
+  idx += bs_idx * n * 3 + pt_idx * 3;
+  out += bs_idx * c * n + c_idx * n;
 
-    out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
-             points[l * m + i3] * w3;
-  }
+  out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
 }

at::Tensor three_interpolate_kernel_wrapper(
@@ -172,29 +153,21 @@ __global__ void three_interpolate_grad_kernel(
     const int *__restrict__ idx,
     const scalar_t *__restrict__ weight,
     scalar_t *__restrict__ grad_points) {
-  int batch_index = blockIdx.x;
-  grad_out += batch_index * n * c;
-  idx += batch_index * n * 3;
-  weight += batch_index * n * 3;
-  grad_points += batch_index * m * c;
-
-  const int index = threadIdx.y * blockDim.x + threadIdx.x;
-  const int stride = blockDim.y * blockDim.x;
-  for (int i = index; i < c * n; i += stride) {
-    const int l = i / n;
-    const int j = i % n;
-    scalar_t w1 = weight[j * 3 + 0];
-    scalar_t w2 = weight[j * 3 + 1];
-    scalar_t w3 = weight[j * 3 + 2];
-
-    int i1 = idx[j * 3 + 0];
-    int i2 = idx[j * 3 + 1];
-    int i3 = idx[j * 3 + 2];
-
-    atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
-    atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
-    atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
-  }
+  int bs_idx = blockIdx.z;
+  int c_idx = blockIdx.y;
+  int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+  grad_out += bs_idx * c * n + c_idx * n + pt_idx;
+  weight += bs_idx * n * 3 + pt_idx * 3;
+  grad_points += bs_idx * c * m + c_idx * m;
+  idx += bs_idx * n * 3 + pt_idx * 3;
+
+  atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
+  atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
+  atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
 }

at::Tensor three_interpolate_grad_kernel_wrapper(
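In the rewritten interpolate kernels each thread handles one
(batch, channel, point) output. The forward pass is a fixed three-point
weighted sum, and the backward pass scatters the same three weighted
contributions, which is why it needs atomicAdd: two different unknown points
can share a nearest neighbor. In index form:

  out[bs][c][j]          = w1 * points[bs][c][i1] + w2 * points[bs][c][i2] + w3 * points[bs][c][i3]
  grad_points[bs][c][ik] += wk * grad_out[bs][c][j]   for k = 1, 2, 3

Here (i1, i2, i3) are the indices produced by three_nn for unknown point j;
the weights are typically normalized inverse distances, but that computation
happens outside these kernels and is not part of this diff.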
(Diffs for the remaining four changed files were not loaded in this view.)
