Skip to content

Commit

Permalink
b5
Browse files Browse the repository at this point in the history
  • Loading branch information
liu-yangs committed May 21, 2024
1 parent 1ac08b1 commit ea712f4
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 129 deletions.
31 changes: 15 additions & 16 deletions backward_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
colors = np.zeros([gs_num, 3])
us = np.zeros([gs_num, 2])
pcs = np.zeros([gs_num, 3])
cinv2ds = np.zeros([gs_num, 3])
cov3ds = np.zeros([gs_num, 6])
cov2ds = np.zeros([gs_num, 3])
dpc_dpws = np.zeros([gs_num, 3, 3])
Expand All @@ -549,6 +550,7 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
dcov2d_dpcs = np.zeros([gs_num, 3, 3])
dcolor_dshs = np.zeros([gs_num, 3, gs['sh'].shape[1]])
dcolor_dpws = np.zeros([gs_num, 3, 3])
dcinv2d_dcov2ds = np.zeros([gs_num, 3, 3])
for i in range(gs_num):
# step1. Transform pw to camera frame,
# and project it to iamge.
Expand All @@ -564,7 +566,7 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
print("%s check du%d_dpc%d" %
(check(du_dpc_numerical, du_dpcs[i]), i, i))

# step2. Calcuate the 3d Gaussian.
# step2. Calcuate the 3d covariance.
cov3ds[i], dcov3d_drots[i], dcov3d_dscales[i] = compute_cov_3d(
gs['rot'][i], gs['scale'][i], True)
dcov3d_dq_numerical = numerical_derivative(
Expand All @@ -576,6 +578,7 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
print("%s check dcov3d%d_ds%d" % (check(
dcov3d_ds_numerical, dcov3d_dscales[i]), i, i))

# step3. Project the 3D Gaussian to 2d image as a 2d covariance.
cov2ds[i], dcov2d_dcov3ds[i], dcov2d_dpcs[i] = compute_cov_2d(
cov3ds[i], pcs[i], Rcw, fx, fy, True)
dcov2d_dcov3d_numerical = numerical_derivative(
Expand All @@ -588,7 +591,7 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
print("%s check dcov2d%d_dpc%d" % (check(
dcov2d_dpc_numerical, dcov2d_dpcs[i]), i, i))

# step3. Project the 3D Gaussian to 2d image as a 2d Gaussian.
# step4. Compute color.
colors[i], dcolor_dshs[i], dcolor_dpws[i] = sh2color(
gs['sh'][i], pws[i], twc, True)
dcolor_dsh_numerical = numerical_derivative(
Expand All @@ -600,6 +603,12 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
print("%s check dcolor%d_dsh%d" % (check(
dcolor_dpw_numerical, dcolor_dpws[i]), i, i))

# step5.1 Compute inverse covariance.
cinv2ds[i], dcinv2d_dcov2ds[i] = calc_cinv2d(cov2ds[i], True)
dcinv2d_dcov2d_numerical = numerical_derivative(calc_cinv2d, [cov2ds[i]], 0)
print("%s check dcinv2d%d_dcov2d%d" % (check(
dcinv2d_dcov2d_numerical, dcinv2d_dcov2ds[i]), i, i))

# ---------------------------------
idx = np.argsort(pcs[:, 2])
idxb = np.argsort(idx)
Expand All @@ -609,26 +618,16 @@ def backward(rots, scales, shs, alphas, pws, Rcw, tcw, fx, fy, cx, cy, image_gt,
us = us[idx].reshape(-1)
x = np.array([16, 8])

calc_gamma(alphas, cov2ds, colors, us, np.array([16, 8]))

cov2d0 = cov2ds[:3]
cinv2d0, dcov2d_dcinv2d = calc_cinv2d(cov2d0, True)
dcov2d_dcinv2d_numerial = numerical_derivative(calc_cinv2d, [cov2d0], 0)
print("%s check dcov2d_dcinv2d" % check(
dcov2d_dcinv2d_numerial, dcov2d_dcinv2d))

cinv2ds = calc_cinv2d(cov2ds)
alpha0, u0 = alphas[:1], us[:2]
calc_alpha_prime(alpha0, cinv2d0, u0, x)

dalphaprime_dalpha_numerial = numerical_derivative(
calc_alpha_prime, [alpha0, cinv2d0, u0, x], 0)
calc_alpha_prime, [alpha0, cinv2ds[0], u0, x], 0)
dalphaprime_dcinv2d_numerial = numerical_derivative(
calc_alpha_prime, [alpha0, cinv2d0, u0, x], 1)
calc_alpha_prime, [alpha0, cinv2ds[0], u0, x], 1)
dalphaprime_du_numerial = numerical_derivative(
calc_alpha_prime, [alpha0, cinv2d0, u0, x], 2)
calc_alpha_prime, [alpha0, cinv2ds[0], u0, x], 2)
alpha_prime, dalphaprime_dalpha, dalphaprime_dcinv2d, dalphaprime_du = calc_alpha_prime(
alpha0, cinv2d0, u0, x, True)
alpha0, cinv2ds[0], u0, x, True)

print("%s check dalphaprime_dalpha" % check(
dalphaprime_dalpha_numerial, dalphaprime_dalpha))
Expand Down
16 changes: 9 additions & 7 deletions forward_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,27 @@

camera = Camera(id=0, width=width, height=height, K=K, Rcw=Rcw, tcw=tcw)

pw = gs['pos']
pws = gs['pos']

# step1. Transform pw to camera frame,
# and project it to iamge.
u, pc = project(pw, camera.Rcw, camera.tcw, camera.K)
us, pcs = project(pws, camera.Rcw, camera.tcw, camera.K)

depth = pc[:, 2]
depths = pc[:, 2]

# step2. Calcuate the 3d Gaussian.
cov3d = compute_cov_3d(gs['scale'], gs['rot'])
cov3ds = compute_cov_3d(gs['scale'], gs['rot'])

# step3. Project the 3D Gaussian to 2d image as a 2d Gaussian.
cov2d = compute_cov_2d(pc, camera.K, cov3d, camera.Rcw)
cov2ds = compute_cov_2d(pcs, camera.K, cov3ds, camera.Rcw)

# step4. get color info
color = sh2color(gs['sh'], pw, camera.cam_center)
colors = sh2color(gs['sh'], pws, camera.cam_center)

# step5. Blend the 2d Gaussian to image
image = splat(camera.height, camera.width, u, cov2d, gs['alpha'], depth, color)
cinv2ds, areas = inverse_cov2d(cov2d)

image = splat(camera.height, camera.width, us, cinv2ds, gs['alpha'], depths, colors, areas)
plt.imshow(image)
# from PIL import Image
# pil_img = Image.fromarray((np.clip(image, 0, 1)*255).astype(np.uint8))
Expand Down
70 changes: 12 additions & 58 deletions forward_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,6 @@
from read_ply import *


def sh2color_gpu(sh, pw, twc):
pw = torch.from_numpy(pw).type(torch.float32).to('cuda')
sh = torch.from_numpy(sh).type(torch.float32).to('cuda')
twc = torch.from_numpy(twc).type(torch.float32).to('cuda')
color = pg.sh2Color(sh, pw, twc)[0]
return color.to('cpu').numpy()


def project_gpu(pw, Rcw, tcw, K):
pw = torch.from_numpy(pw).type(torch.float32).to('cuda')
Rcw = torch.from_numpy(Rcw).type(torch.float32).to('cuda')
tcw = torch.from_numpy(tcw).type(torch.float32).to('cuda')
u, pc = pg.project(pw, Rcw, tcw, K[0, 0], K[1, 1], K[0, 2], K[1, 2])
return u.to('cpu').numpy(), pc.to('cpu').numpy()


def compute_cov_3d_gpu(scale, rot):
scale = torch.from_numpy(scale).type(torch.float32).to('cuda')
rot = torch.from_numpy(rot).type(torch.float32).to('cuda')
res = pg.computeCov3D(rot, scale)
return res[0].to('cpu').numpy()


def compute_cov_2d_gpu(pc, K, cov3d, Rcw):
pc = torch.from_numpy(pc).type(torch.float32).to('cuda')
cov3d = torch.from_numpy(cov3d).type(torch.float32).to('cuda')
Rcw = torch.from_numpy(Rcw).type(torch.float32).to('cuda')
focal_x = K[0, 0]
focal_y = K[1, 1]
res = pg.computeCov2D(cov3d, pc, Rcw, focal_x, focal_y)
return res[0].to('cpu').numpy()


def splat_gpu(height, width, u, cov2d, alpha, depth, color):
u = torch.from_numpy(u).type(torch.float32).to('cuda')
cov2d = torch.from_numpy(cov2d).type(torch.float32).to('cuda')
alpha = torch.from_numpy(alpha).type(torch.float32).to('cuda')
depth = torch.from_numpy(depth).type(torch.float32).to('cuda')
color = torch.from_numpy(color).type(torch.float32).to('cuda')
res = pg.forward(height, width, u, cov2d, alpha, depth, color)
res_cpu = []
for r in res:
res_cpu.append(r.to('cpu').numpy())
res_cpu[0] = res_cpu[0].transpose(1, 2, 0)
return res_cpu


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -111,31 +64,32 @@ def splat_gpu(height, width, u, cov2d, alpha, depth, color):
center_x = width / 2
center_y = height / 2

pw = torch.from_numpy(gs['pos']).type(torch.float32).to('cuda')
rot = torch.from_numpy(gs['rot']).type(torch.float32).to('cuda')
scale = torch.from_numpy(gs['scale']).type(torch.float32).to('cuda')
alpha = torch.from_numpy(gs['alpha']).type(torch.float32).to('cuda')
sh = torch.from_numpy(gs['sh']).type(torch.float32).to('cuda')
pws = torch.from_numpy(gs['pos']).type(torch.float32).to('cuda')
rots = torch.from_numpy(gs['rot']).type(torch.float32).to('cuda')
scales = torch.from_numpy(gs['scale']).type(torch.float32).to('cuda')
alphas = torch.from_numpy(gs['alpha']).type(torch.float32).to('cuda')
shs = torch.from_numpy(gs['sh']).type(torch.float32).to('cuda')
Rcw = torch.from_numpy(Rcw).type(torch.float32).to('cuda')
tcw = torch.from_numpy(tcw).type(torch.float32).to('cuda')
twc = torch.linalg.inv(Rcw)@(-tcw)

# step1. Transform pw to camera frame,
# and project it to iamge.
u, pc = pg.project(pw, Rcw, tcw, focal_x, focal_y, center_x, center_y, False)
depth = pc[:, 2]
us, pcs = pg.project(pws, Rcw, tcw, focal_x, focal_y, center_x, center_y, False)
depths = pcs[:, 2]

# step2. Calcuate the 3d Gaussian.
cov3d = pg.computeCov3D(rot, scale, False)[0]
cov3ds = pg.computeCov3D(rots, scales, False)[0]

# step3. Calcuate the 2d Gaussian.
cov2d = pg.computeCov2D(cov3d, pc, Rcw, focal_x, focal_y, False)[0]
cov2ds = pg.computeCov2D(cov3ds, pcs, Rcw, focal_x, focal_y, False)[0]

# step4. get color info
color = pg.sh2Color(sh, pw, twc, False)[0]
colors = pg.sh2Color(shs, pws, twc, False)[0]

# step5. Blend the 2d Gaussian to image
image = pg.forward(height, width, u, cov2d, alpha, depth, color)[0]
cinv2ds, areas = pg.inverseCov2D(cov2ds, False)
image = pg.splat(height, width, us, cinv2ds, alphas, depths, colors, areas)[0]
image = image.to('cpu').numpy()

plt.imshow(image.transpose([1, 2, 0]))
Expand Down
17 changes: 9 additions & 8 deletions gausplat.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,18 @@ def project(pw, Rcw, tcw, K):
return u, pc


def splat(height, width, us, cov2d, alpha, depth, color):
image = np.zeros([height, width, 3])
image_T = np.ones([height, width])

def inverse_cov2d(cov2d):
# forward.md 5.3
# compute inverse of cov2d
det_inv = 1. / (cov2d[:, 0] * cov2d[:, 2] - cov2d[:, 1] * cov2d[:, 1] + 0.000001)
cov2d_inv = np.array([cov2d[:, 2] * det_inv, -cov2d[:, 1] * det_inv, cov2d[:, 0] * det_inv]).T

# Determine the drawing area of 2d Gaussian.
cinv2d = np.array([cov2d[:, 2] * det_inv, -cov2d[:, 1] * det_inv, cov2d[:, 0] * det_inv]).T
areas = 3 * np.sqrt(np.vstack([cov2d[:, 0], cov2d[:, 2]])).T
return cinv2d, areas


def splat(height, width, us, cinv2d, alpha, depth, color, areas):
image = np.zeros([height, width, 3])
image_T = np.ones([height, width])

start = time.time()

Expand Down Expand Up @@ -226,7 +227,7 @@ def splat(height, width, us, cov2d, alpha, depth, color):
if ((x1 - x0) * (y1 - y0) == 0):
continue

cinv = cov2d_inv[i]
cinv = cinv2d[i]
opa = alpha[i]
patch_color = color[i]

Expand Down
23 changes: 14 additions & 9 deletions pygausplat/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
#include <iostream>
#include <vector>

std::vector<torch::Tensor> forward(
const int H,
const int W,
const torch::Tensor u,
const torch::Tensor cov2d,
const torch::Tensor alpha,
const torch::Tensor depth,
const torch::Tensor color);
std::vector<torch::Tensor> splat(
const int height,
const int width,
const torch::Tensor us,
const torch::Tensor cinv2ds,
const torch::Tensor alphas,
const torch::Tensor depths,
const torch::Tensor colors,
const torch::Tensor areas);

std::vector<torch::Tensor> inverseCov2D(const torch::Tensor cov2ds,
const bool calc_J);

std::vector<torch::Tensor> computeCov3D(const torch::Tensor rots,
const torch::Tensor scales,
Expand Down Expand Up @@ -58,8 +62,9 @@ std::vector<torch::Tensor> sh2Color(const torch::Tensor shs,

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &forward, "create 2d image");
m.def("splat", &splat, "create 2d image");
m.def("backward", &backward, "compute jacobians");
m.def("inverseCov2D", &inverseCov2D, "inverse 2D covariances");
m.def("computeCov3D", &computeCov3D, "compute 3D covariances");
m.def("computeCov2D", &computeCov2D, "compute 2D covariances");
m.def("project", &project, "project point to image");
Expand Down
45 changes: 31 additions & 14 deletions pygausplat/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@



std::vector<torch::Tensor> forward(
std::vector<torch::Tensor> splat(
const int height,
const int width,
const torch::Tensor us,
const torch::Tensor cov2d,
const torch::Tensor cinv2ds,
const torch::Tensor alphas,
const torch::Tensor depths,
const torch::Tensor colors)
const torch::Tensor colors,
const torch::Tensor areas)
{
auto float_opts = us.options().dtype(torch::kFloat32);
auto int_opts = us.options().dtype(torch::kInt32);
Expand All @@ -38,22 +39,13 @@ std::vector<torch::Tensor> forward(
thrust::device_vector<uint4> gs_rects(gs_num);
thrust::device_vector<uint> patch_num_per_gs(gs_num);
thrust::device_vector<uint> patch_offset_per_gs(gs_num);
thrust::device_vector<float> cinv2d(gs_num * 3);
thrust::device_vector<float> areas(gs_num * 2);

inverseCov2D<<<DIV_ROUND_UP(gs_num, BLOCK_SIZE), BLOCK_SIZE>>>(
gs_num,
cov2d.contiguous().data_ptr<float>(),
thrust::raw_pointer_cast(cinv2d.data()),
thrust::raw_pointer_cast(areas.data()));
cudaDeviceSynchronize();

getRect<<<DIV_ROUND_UP(gs_num, BLOCK_SIZE), BLOCK_SIZE>>>(
width,
height,
gs_num,
us.contiguous().data_ptr<float>(),
thrust::raw_pointer_cast(areas.data()),
areas.contiguous().data_ptr<float>(),
depths.contiguous().data_ptr<float>(),
grid,
thrust::raw_pointer_cast(gs_rects.data()),
Expand Down Expand Up @@ -95,7 +87,7 @@ std::vector<torch::Tensor> forward(
patch_range_per_tile.contiguous().data_ptr<int>(),
thrust::raw_pointer_cast(gs_id_per_patch.data()),
us.contiguous().data_ptr<float>(),
thrust::raw_pointer_cast(cinv2d.data()),
cinv2ds.contiguous().data_ptr<float>(),
alphas.contiguous().data_ptr<float>(),
colors.contiguous().data_ptr<float>(),
image.contiguous().data_ptr<float>(),
Expand Down Expand Up @@ -289,3 +281,28 @@ std::vector<torch::Tensor> sh2Color(const torch::Tensor shs,
return {colors};
}
}

std::vector<torch::Tensor> inverseCov2D(const torch::Tensor cov2ds,
const bool calc_J)
{
auto float_opts = cov2ds.options().dtype(torch::kFloat32);
auto int_opts = cov2ds.options().dtype(torch::kInt32);
int gs_num = cov2ds.sizes()[0];
torch::Tensor cinv2ds = torch::full({gs_num, 3}, 0.0, float_opts);
torch::Tensor areas = torch::full({gs_num, 2}, 0.0, float_opts);

if (calc_J)
{
return {cinv2ds, areas};
}
else
{
inverseCov2D<<<DIV_ROUND_UP(gs_num, BLOCK_SIZE), BLOCK_SIZE>>>(
gs_num,
cov2ds.contiguous().data_ptr<float>(),
cinv2ds.contiguous().data_ptr<float>(),
areas.contiguous().data_ptr<float>());
cudaDeviceSynchronize();
return {cinv2ds, areas};
}
}
Loading

0 comments on commit ea712f4

Please sign in to comment.