[UPDATE] simplify backward
wrc042 committed Jan 10, 2023
1 parent a85d201 commit dd6501b
Showing 6 changed files with 59 additions and 179 deletions.
11 changes: 6 additions & 5 deletions tests/normal.py
@@ -4,11 +4,12 @@
 import torch
 from time import time
 
-os.environ["CUDA_VISIBLE_DIVICES"] = "1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 device = "cuda"
 # Ns
-num_sample = 100000
+num_sample = 1000000
 samples = torch.rand((num_sample, 3)).to(device).detach()
+samples = samples * 2 - 1
 
 all_pass = True

@@ -28,14 +29,14 @@

 # TorchSDF
 # (Ns)
-distances, normal, face_indexes, types = compute_sdf(x, face_verts)
+distances, normals, clst_points = compute_sdf(x, face_verts)
 gradient = torch.autograd.grad([distances.sum()], [x], create_graph=True,
                                retain_graph=True)[0]
 
-normal_direct = normal * 2 * distances.unsqueeze(1).sqrt()
+normal_direct = normals * 2 * distances.unsqueeze(1).sqrt()
 normal_from_grad = torch.autograd.grad([distances.sum()], [x], create_graph=True,
                                        retain_graph=True)[0]
-normal_fit = torch.allclose(normal_direct, normal_from_grad, atol=5e-7)
+normal_fit = torch.allclose(normal_direct, normal_from_grad, atol=8e-7)
 if normal_fit:
     print("\x1B[32mPass\x1B[0m")
 else:
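The check above relies on an identity for squared distances: if `distances` holds d(x) = ‖x − c‖² for the closest surface point c, then ∇d = 2(x − c) = 2√d · n̂, so the analytic normals returned by compute_sdf, scaled by 2√d, must match the autograd gradient. A minimal self-contained sketch of that identity (no TorchSDF dependency; the fixed closest point `c` is a stand-in for the nearest point on a mesh):

```python
import torch

# Stand-ins for a query point and its closest point on a surface.
c = torch.tensor([0.3, -0.2, 0.5])
x = torch.tensor([0.9, 0.1, -0.4], requires_grad=True)

dist_sq = ((x - c) ** 2).sum()            # squared distance, like compute_sdf's output
grad, = torch.autograd.grad(dist_sq, x)   # autograd: d/dx ||x - c||^2 = 2 (x - c)

normal = (x - c) / (x - c).norm()         # unit direction from surface to query point
direct = normal * 2 * dist_sq.sqrt()      # the tests/normal.py formula

print(torch.allclose(grad, direct, atol=1e-6))  # True
```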
7 changes: 4 additions & 3 deletions tests/value.py
@@ -5,11 +5,12 @@
 import torch
 from time import time
 
-os.environ["CUDA_VISIBLE_DIVICES"] = "1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 device = "cuda"
 # Ns
-num_sample = 1000000
+num_sample = 10000000
 samples = torch.rand((num_sample, 3)).to(device).detach()
+samples = samples * 2 - 1
 
 all_pass = True

@@ -45,7 +46,7 @@
 # (Ns)
 torch.cuda.synchronize()
 tmp = time()
-distances_ts, normal_ts, face_indexes_ts, types_ts = compute_sdf(x, face_verts_ts)
+distances_ts, normals_ts, clst_points_ts = compute_sdf(x, face_verts_ts)
 gradient_ts = torch.autograd.grad([distances_ts.sum()], [x], create_graph=True,
                                   retain_graph=True)[0]
 torch.cuda.synchronize()
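The timing code above brackets the call with `torch.cuda.synchronize()` because CUDA kernels launch asynchronously: without the second synchronize, the host clock would stop before the kernel finishes. A generic sketch of the pattern (the `work` callable is a placeholder, not part of the repo):

```python
import torch
from time import time

def time_cuda(work):
    """Wall-clock a CUDA workload despite asynchronous kernel launches."""
    torch.cuda.synchronize()   # drain previously queued work before starting the clock
    start = time()
    out = work()
    torch.cuda.synchronize()   # wait until the timed kernels actually finish
    return out, time() - start
```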
48 changes: 17 additions & 31 deletions torchsdf/csrc/unbatched_triangle_distance.cpp
@@ -25,16 +25,13 @@ void unbatched_triangle_distance_forward_cuda_impl(
     at::Tensor points,
     at::Tensor face_vertices,
     at::Tensor dist,
-    at::Tensor normal,
-    at::Tensor face_idx,
-    at::Tensor dist_type);
+    at::Tensor normals,
+    at::Tensor clst_points);
 
 void unbatched_triangle_distance_backward_cuda_impl(
     at::Tensor grad_dist,
     at::Tensor points,
-    at::Tensor face_vertices,
-    at::Tensor face_idx,
-    at::Tensor dist_type,
+    at::Tensor clst_points,
     at::Tensor grad_points);
 
 #endif // WITH_CUDA
@@ -44,29 +41,28 @@ void unbatched_triangle_distance_forward_cuda(
     at::Tensor points,
     at::Tensor face_vertices,
     at::Tensor dist,
-    at::Tensor normal,
-    at::Tensor face_idx,
-    at::Tensor dist_type) {
+    at::Tensor normals,
+    at::Tensor clst_points) {
   CHECK_CUDA(points);
   CHECK_CUDA(face_vertices);
   CHECK_CUDA(dist);
-  CHECK_CUDA(face_idx);
-  CHECK_CUDA(dist_type);
+  CHECK_CUDA(normals);
+  CHECK_CUDA(clst_points);
   CHECK_CONTIGUOUS(points);
   CHECK_CONTIGUOUS(face_vertices);
   CHECK_CONTIGUOUS(dist);
-  CHECK_CONTIGUOUS(face_idx);
-  CHECK_CONTIGUOUS(dist_type);
+  CHECK_CONTIGUOUS(normals);
+  CHECK_CONTIGUOUS(clst_points);
   const int num_points = points.size(0);
   const int num_faces = face_vertices.size(0);
   CHECK_SIZES(points, num_points, 3);
   CHECK_SIZES(face_vertices, num_faces, 3, 3);
   CHECK_SIZES(dist, num_points);
-  CHECK_SIZES(face_idx, num_points);
-  CHECK_SIZES(dist_type, num_points);
+  CHECK_SIZES(normals, num_points, 3);
+  CHECK_SIZES(clst_points, num_points, 3);
 #if WITH_CUDA
   unbatched_triangle_distance_forward_cuda_impl(
-      points, face_vertices, dist, normal, face_idx, dist_type);
+      points, face_vertices, dist, normals, clst_points);
 #else
   AT_ERROR("unbatched_triangle_distance not built with CUDA");
 #endif
@@ -75,36 +71,26 @@ void unbatched_triangle_distance_forward_cuda(
 void unbatched_triangle_distance_backward_cuda(
     at::Tensor grad_dist,
     at::Tensor points,
-    at::Tensor face_vertices,
-    at::Tensor face_idx,
-    at::Tensor dist_type,
+    at::Tensor clst_points,
     at::Tensor grad_points) {
   CHECK_CUDA(grad_dist);
   CHECK_CUDA(points);
-  CHECK_CUDA(face_vertices);
-  CHECK_CUDA(face_idx);
-  CHECK_CUDA(dist_type);
+  CHECK_CUDA(clst_points);
   CHECK_CUDA(grad_points);
   CHECK_CONTIGUOUS(grad_dist);
   CHECK_CONTIGUOUS(points);
-  CHECK_CONTIGUOUS(face_vertices);
-  CHECK_CONTIGUOUS(face_idx);
-  CHECK_CONTIGUOUS(dist_type);
+  CHECK_CONTIGUOUS(clst_points);
   CHECK_CONTIGUOUS(grad_points);
 
   const int num_points = points.size(0);
-  const int num_faces = face_vertices.size(0);
   CHECK_SIZES(grad_dist, num_points);
   CHECK_SIZES(points, num_points, 3);
-  CHECK_SIZES(face_vertices, num_faces, 3, 3);
-  CHECK_SIZES(face_idx, num_points);
-  CHECK_SIZES(dist_type, num_points);
+  CHECK_SIZES(clst_points, num_points, 3);
   CHECK_SIZES(grad_points, num_points, 3);
 
 #if WITH_CUDA
   unbatched_triangle_distance_backward_cuda_impl(
-      grad_dist, points, face_vertices, face_idx, dist_type,
-      grad_points);
+      grad_dist, points, clst_points, grad_points);
 #else
   AT_ERROR("unbatched_triangle_distance_backward not built with CUDA");
 #endif
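This is the payoff of returning closest points from the forward pass: once the closest point c is known, ∂‖p − c‖²/∂p = 2(p − c) regardless of whether c lies on a face interior, an edge, or a vertex, so the backward no longer needs `face_vertices`, `face_idx`, or `dist_type`. Treating c as constant is valid because it minimizes the distance (an envelope argument). The CUDA kernel itself is not in this diff; a PyTorch sketch of what the simplified `..._backward_cuda_impl` presumably computes, inferred from the new argument list:

```python
import torch

def backward_sketch(grad_dist, points, clst_points):
    # d_i = ||p_i - c_i||^2  =>  dd_i/dp_i = 2 (p_i - c_i),
    # chained with the incoming gradient dL/dd_i of shape (N,).
    return grad_dist.unsqueeze(1) * 2 * (points - clst_points)
```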
9 changes: 3 additions & 6 deletions torchsdf/csrc/unbatched_triangle_distance.h
@@ -24,16 +24,13 @@ void unbatched_triangle_distance_forward_cuda(
     at::Tensor points,
     at::Tensor face_vertices,
     at::Tensor dist,
-    at::Tensor normal,
-    at::Tensor face_idx,
-    at::Tensor dist_type);
+    at::Tensor normals,
+    at::Tensor clst_points);
 
 void unbatched_triangle_distance_backward_cuda(
     at::Tensor grad_dist,
     at::Tensor points,
-    at::Tensor face_vertices,
-    at::Tensor face_idx,
-    at::Tensor dist_type,
+    at::Tensor clst_points,
     at::Tensor grad_points);
 
 } // namespace kaolin
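For context, declarations like these are typically driven from a `torch.autograd.Function` whose forward saves `clst_points` for the backward. A runnable toy version of that wiring, using squared distance to the plane z = 0 in place of the compiled kernels (everything below is illustrative, not the package's actual binding code):

```python
import torch

class PlaneDistanceSketch(torch.autograd.Function):
    """Toy stand-in for compute_sdf: squared distance to the plane z = 0.
    Mirrors the new interface: forward returns (dist, normals, clst_points),
    and backward needs only the saved closest points."""

    @staticmethod
    def forward(ctx, points):
        clst_points = points.clone()
        clst_points[:, 2] = 0.0                # closest point on the plane z = 0
        diff = points - clst_points
        dist = (diff ** 2).sum(dim=1)          # squared distance
        normals = diff / diff.norm(dim=1, keepdim=True).clamp_min(1e-12)
        ctx.save_for_backward(points, clst_points)
        return dist, normals, clst_points

    @staticmethod
    def backward(ctx, grad_dist, grad_normals, grad_clst):
        points, clst_points = ctx.saved_tensors
        # The same rule the simplified CUDA backward applies:
        # d ||p - c||^2 / dp = 2 (p - c)
        return grad_dist.unsqueeze(1) * 2 * (points - clst_points)

p = torch.randn(4, 3, requires_grad=True)
d, n, c = PlaneDistanceSketch.apply(p)
g, = torch.autograd.grad(d.sum(), p)
print(torch.allclose(g, n * 2 * d.unsqueeze(1).sqrt()))  # True, as tests/normal.py asserts
```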
