Skip to content

Commit

Permalink
reorganization
Browse files Browse the repository at this point in the history
  • Loading branch information
scomup committed May 22, 2024
1 parent eb382b0 commit 11a3549
Show file tree
Hide file tree
Showing 25 changed files with 547 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ for Real-Time Radiance Field Rendering](https://repo-sam.inria.fr/fungraph/3d-ga

```bash
pip3 install -r requirements.txt
pip3 install pygausplat/.
pip3 install gsplatcu/.
```

## Forward process (render image)
Expand Down
2 changes: 1 addition & 1 deletion backward_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn as nn
import numpy as np
from sh_coef import *
from gsplat.sh_coef import *


def upper_triangular(mat):
Expand Down
18 changes: 9 additions & 9 deletions backward_gpu.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import matplotlib.pyplot as plt
import torch
import pygausplat as pg
import gsplatcu as gsc
import numpy as np
from sh_coef import *
from gsplat.sh_coef import *
from backward_cpu import *


Expand Down Expand Up @@ -105,41 +105,41 @@
tcw_gpu = torch.from_numpy(tcw).type(torch.float32).to('cuda')
twc_gpu = torch.from_numpy(twc).type(torch.float32).to('cuda')

us_gpu, pcs_gpu, du_dpcs_gpu = pg.project(pws_gpu, Rcw_gpu, tcw_gpu, fx, fy, cx, cy, True)
us_gpu, pcs_gpu, du_dpcs_gpu = gsc.project(pws_gpu, Rcw_gpu, tcw_gpu, fx, fy, cx, cy, True)
print("%s test us_gpu" % check(us_gpu.cpu().numpy(), us))
print("%s test pcs_gpu" % check(pcs_gpu.cpu().numpy(), pcs))
print("%s test du_dpcs_gpu" % check(du_dpcs_gpu.cpu().numpy(), du_dpcs))

cov3ds_gpu, dcov3d_drots_gpu, dcov3d_dscales_gpu = pg.computeCov3D(rots_gpu, scales_gpu, True)
cov3ds_gpu, dcov3d_drots_gpu, dcov3d_dscales_gpu = gsc.computeCov3D(rots_gpu, scales_gpu, True)
print("%s test cov3ds_gpu" % check(cov3ds_gpu.cpu().numpy(), cov3ds))
print("%s test dcov3d_drots_gpu" % check(dcov3d_drots_gpu.cpu().numpy(), dcov3d_drots))
print("%s test dcov3d_dscales_gpu" % check(dcov3d_dscales_gpu.cpu().numpy(), dcov3d_dscales))

cov2ds_gpu, dcov2d_dcov3ds_gpu, dcov2d_dpcs_gpu = pg.computeCov2D(cov3ds_gpu, pcs_gpu, Rcw_gpu, fx, fy, True)
cov2ds_gpu, dcov2d_dcov3ds_gpu, dcov2d_dpcs_gpu = gsc.computeCov2D(cov3ds_gpu, pcs_gpu, Rcw_gpu, fx, fy, True)
print("%s test cov2ds_gpu" % check(cov2ds_gpu.cpu().numpy(), cov2ds))
print("%s test dcov2d_dcov3ds_gpu" % check(dcov2d_dcov3ds_gpu.cpu().numpy(), dcov2d_dcov3ds))
print("%s test dcov2d_dpcs_gpu" % check(dcov2d_dpcs_gpu.cpu().numpy(), dcov2d_dpcs))

colors_gpu, dcolor_dshs_gpu, dcolor_dpws_gpu = pg.sh2Color(shs_gpu, pws_gpu, twc_gpu, True)
colors_gpu, dcolor_dshs_gpu, dcolor_dpws_gpu = gsc.sh2Color(shs_gpu, pws_gpu, twc_gpu, True)
print("%s test colors_gpu" % check(colors_gpu.cpu().numpy(), colors))
print("%s test dcolor_dshs_gpu" % check(dcolor_dshs_gpu.cpu().numpy(), dcolor_dshs))
print("%s test dcolor_dshs_gpu" % check(dcolor_dpws_gpu.cpu().numpy(), dcolor_dpws))

cinv2ds_gpu, areas_gpu, dcinv2d_dcov2ds_gpu = pg.inverseCov2D(cov2ds_gpu, True)
cinv2ds_gpu, areas_gpu, dcinv2d_dcov2ds_gpu = gsc.inverseCov2D(cov2ds_gpu, True)
print("%s test cinv2d_gpu" % check(cinv2ds_gpu.cpu().numpy(), cinv2ds))
print("%s test dcinv2d_dcov2ds_gpu" % check(dcinv2d_dcov2ds_gpu.cpu().numpy(), dcinv2d_dcov2ds))

depths_gpu = torch.from_numpy(np.array([1, 2, 3, 4])).type(torch.float32).to('cuda')
image_gpu, contrib_gpu, final_tau_gpu, patch_range_per_tile_gpu, gsid_per_patch_gpu =\
pg.splat(height, width, us_gpu, cinv2ds_gpu, alphas_gpu, depths_gpu, colors_gpu, areas_gpu)
gsc.splat(height, width, us_gpu, cinv2ds_gpu, alphas_gpu, depths_gpu, colors_gpu, areas_gpu)
print("%s test image_gpu" %
check(image_gpu.cpu().numpy(), image.transpose([2, 0, 1])))

_, dloss_dgammas = get_loss(image, image_gt)
dloss_dgammas_gpu = torch.from_numpy(dloss_dgammas).type(torch.float32).to('cuda')

dloss_dus_gpu, dloss_dcinv2ds_gpu, dloss_dalphas_gpu, dloss_dcolors_gpu =\
pg.splatB(height, width, us_gpu, cinv2ds_gpu, alphas_gpu, depths_gpu, colors_gpu,
gsc.splatB(height, width, us_gpu, cinv2ds_gpu, alphas_gpu, depths_gpu, colors_gpu,
contrib_gpu, final_tau_gpu, patch_range_per_tile_gpu, gsid_per_patch_gpu, dloss_dgammas_gpu)

dloss_dalphas = dloss_dalphas.reshape([gs_num, 1, 1])
Expand Down
3 changes: 2 additions & 1 deletion forward_cpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from gausplat import *
from gsplat.gausplat import *
from gsplat.read_ply import *


if __name__ == "__main__":
Expand Down
16 changes: 8 additions & 8 deletions forward_gpu.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import pygausplat as pg
import gsplatcu as gsc
import numpy as np
import matplotlib.pyplot as plt
from read_ply import *
from gsplat.read_ply import *


if __name__ == "__main__":
Expand Down Expand Up @@ -75,21 +75,21 @@

# step1. Transform pw to camera frame,
# and project it to iamge.
us, pcs = pg.project(pws, Rcw, tcw, focal_x, focal_y, center_x, center_y, False)
us, pcs = gsc.project(pws, Rcw, tcw, focal_x, focal_y, center_x, center_y, False)
depths = pcs[:, 2]

# step2. Calcuate the 3d Gaussian.
cov3ds = pg.computeCov3D(rots, scales, False)[0]
cov3ds = gsc.computeCov3D(rots, scales, False)[0]

# step3. Calcuate the 2d Gaussian.
cov2ds = pg.computeCov2D(cov3ds, pcs, Rcw, focal_x, focal_y, False)[0]
cov2ds = gsc.computeCov2D(cov3ds, pcs, Rcw, focal_x, focal_y, False)[0]

# step4. get color info
colors = pg.sh2Color(shs, pws, twc, False)[0]
colors = gsc.sh2Color(shs, pws, twc, False)[0]

# step5. Blend the 2d Gaussian to image
cinv2ds, areas = pg.inverseCov2D(cov2ds, False)
image = pg.splat(height, width, us, cinv2ds, alphas, depths, colors, areas)[0]
cinv2ds, areas = gsc.inverseCov2D(cov2ds, False)
image = gsc.splat(height, width, us, cinv2ds, alphas, depths, colors, areas)[0]
image = image.to('cpu').numpy()

plt.imshow(image.transpose([1, 2, 0]))
Expand Down
2 changes: 1 addition & 1 deletion gaussian_viewer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

import numpy as np
from read_ply import *
from gsplat.read_ply import *

import sys
import os
Expand Down
252 changes: 252 additions & 0 deletions gsplat/gausplat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import matplotlib.pyplot as plt
import numpy as np
import time
from gsplat.sh_coef import *


class Camera:
def __init__(self, id, width, height, K, Rcw, tcw):
self.id = id
self.width = width
self.height = height
self.K = K
self.Rcw = Rcw
self.tcw = tcw
self.cam_center = -np.linalg.inv(Rcw) @ tcw
self.focal_x = K[0, 0]
self.focal_y = K[1, 1]


def projection_matrix(focal_x, focal_y, width, height, z_near=0.1, z_far=100):
P = np.zeros([4, 4])
P[0, 0] = 2 * focal_x / width
P[1, 1] = 2 * focal_x / width
P[2, 2] = -(z_near + z_far) / (z_far - z_near)
P[2, 3] = -(2 * z_near * z_far) / (z_far - z_near)
P[3, 0] = -1
return P


def upper_triangular(mat):
s = mat[0].shape[0]
n = 0
if (s == 2):
n = 3
elif(s == 3):
n = 6
else:
raise NotImplementedError("no supported mat")
upper = np.zeros([mat.shape[0], n])
n = 0
for i in range(s):
for j in range(i, s):
upper[:, n] = mat[:, i, j]
n = n + 1
return upper


def symmetric_matrix(upper):
n = upper.shape[1]
if (n == 6):
s = 3
elif(n == 3):
s = 2
else:
raise NotImplementedError("no supported mat")

mat = np.zeros([upper.shape[0], s, s])

n = 0
for i in range(s):
for j in range(i, s):
mat[:, i, j] = upper[:, n]
if (i != j):
mat[:, j, i] = upper[:, n]
n = n + 1
return mat


def sh2color(sh, pw, twc):
sh_dim = sh.shape[1]
color = SH_C0_0 * sh[:, 0:3] + 0.5
if (sh_dim <= 3):
return color
ray_dir = pw - twc
ray_dir /= np.linalg.norm(ray_dir, axis=1)[:, np.newaxis]
x = ray_dir[:, 0][:, np.newaxis]
y = ray_dir[:, 1][:, np.newaxis]
z = ray_dir[:, 2][:, np.newaxis]

color = color + \
SH_C1_0 * y * sh[:, 3:6] + \
SH_C1_1 * z * sh[:, 6:9] + \
SH_C1_2 * x * sh[:, 9:12]

if (sh_dim <= 12):
return color
xx = x * x
yy = y * y
zz = z * z
xy = x * y
yz = y * z
xz = x * z
color = color + \
SH_C2_0 * xy * sh[:, 12:15] + \
SH_C2_1 * yz * sh[:, 15:18] + \
SH_C2_2 * (2.0 * zz - xx - yy) * sh[:, 18:21] + \
SH_C2_3 * xz * sh[:, 21:24] + \
SH_C2_4 * (xx - yy) * sh[:, 24:27]

if (sh_dim <= 27):
return color

color = color + \
SH_C3_0 * y * (3.0 * xx - yy) * sh[:, 27:30] + \
SH_C3_1 * xy * z * sh[:, 30:33] + \
SH_C3_2 * y * (4.0 * zz - xx - yy) * sh[:, 33:36] + \
SH_C3_3 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy) * sh[:, 36:39] + \
SH_C3_4 * x * (4.0 * zz - xx - yy) * sh[:, 39:42] + \
SH_C3_5 * z * (xx - yy) * sh[:, 42:45] + \
SH_C3_6 * x * (xx - 3.0 * yy) * sh[:, 45:48]

return color


def compute_cov_3d(scale, rot):
# Create scaling matrix
S = np.zeros([scale.shape[0], 3, 3])
S[:, 0, 0] = scale[:, 0]
S[:, 1, 1] = scale[:, 1]
S[:, 2, 2] = scale[:, 2]

# Normalize quaternion to get valid rotation
q = rot
w = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]

# Compute rotation matrix from quaternion
R = np.array([
[1.0 - 2*(y**2 + z**2), 2*(x*y - z*w), 2*(x * z + y * w)],
[2*(x*y + z*w), 1.0 - 2*(x**2 + z**2), 2*(y*z - x*w)],
[2*(x*z - y*w), 2*(y*z + x*w), 1.0 - 2*(x**2 + y**2)]
]).transpose(2, 0, 1)
M = R @ S

# Compute 3D world covariance matrix Sigma
Sigma = M @ M.transpose(0, 2, 1)
cov3d = upper_triangular(Sigma)

return cov3d


def compute_cov_2d(pc, K, cov3d, Rcw):
x = pc[:, 0]
y = pc[:, 1]
z = pc[:, 2]
focal_x = K[0, 0]
focal_y = K[1, 1]
# c_x = K[0, 2]
# c_y = K[1, 2]
# u_ndc = (u - np.array([c_x, c_y]))/np.array([2*c_x, 2*c_y])
# out_idx = np.where(np.max(u_ndc, axis=1) > 1.3)[0]
J = np.zeros([pc.shape[0], 3, 3])
z2 = z * z
J[:, 0, 0] = focal_x / z
J[:, 0, 2] = -(focal_x * x) / z2
J[:, 1, 1] = focal_y / z
J[:, 1, 2] = -(focal_y * y) / z2

T = J @ Rcw

Sigma = symmetric_matrix(cov3d)

Sigma_prime = T @ Sigma @ T.transpose(0, 2, 1)
Sigma_prime[:, 0, 0] += 0.3
Sigma_prime[:, 1, 1] += 0.3

cov2d = upper_triangular(Sigma_prime[:, :2, :2])
# cov2d[out_idx] = 0
return cov2d


def project(pw, Rcw, tcw, K):
# project the mean of 2d gaussian to image.
# forward.md (1.1) (1.2)
pc = (Rcw @ pw.T).T + tcw
depth = pc[:, 2]
pc_proj = (K @ pc.T).T
pc_proj /= pc_proj[:, 2][:, np.newaxis]
u = pc_proj[:, :2]
return u, pc


def inverse_cov2d(cov2d):
# forward.md 5.3
# compute inverse of cov2d
det_inv = 1. / (cov2d[:, 0] * cov2d[:, 2] - cov2d[:, 1] * cov2d[:, 1] + 0.000001)
cinv2d = np.array([cov2d[:, 2] * det_inv, -cov2d[:, 1] * det_inv, cov2d[:, 0] * det_inv]).T
areas = 3 * np.sqrt(np.vstack([cov2d[:, 0], cov2d[:, 2]])).T
return cinv2d, areas


def splat(height, width, us, cinv2d, alpha, depth, color, areas, im=None):
image = np.zeros([height, width, 3])
image_T = np.ones([height, width])

start = time.time()

# sort by depth
sort_idx = np.argsort(depth)

idx_map = np.array((np.meshgrid(np.arange(0, width), np.arange(0, height))))
win_size = np.array([width, height])

for j, i in enumerate(sort_idx):
if (j % 10000 == 0):
print("processing... %3.f%%" % (j / float(us.shape[0]) * 100.))
if (im is not None):
im.set_data(image)
plt.pause(0.1)

if (depth[i] < 0.2 or depth[i] > 100):
continue

u = us[i]
if (np.any(np.abs(u / win_size) > 1.3)):
continue

r = areas[i]
x0 = int(np.maximum(np.minimum(u[0] - r[0], width), 0))
x1 = int(np.maximum(np.minimum(u[0] + r[0], width), 0))
y0 = int(np.maximum(np.minimum(u[1] - r[1], height), 0))
y1 = int(np.maximum(np.minimum(u[1] + r[1], height), 0))

if ((x1 - x0) * (y1 - y0) == 0):
continue

cinv = cinv2d[i]
opa = alpha[i]
patch_color = color[i]

d = u[:, np.newaxis, np.newaxis] - idx_map[:, y0:y1, x0:x1]
# mahalanobis distance
maha_dist = cinv[0] * d[0] * d[0] + cinv[2] * d[1] * d[1] + 2 * cinv[1] * d[0] * d[1]
patch_alpha = np.exp(-0.5 * maha_dist) * opa
patch_alpha[patch_alpha > 0.99] = 0.99

# draw inverse gaussian
# th = 0.7
# patch_alpha = np.exp(-0.5 * maha_dist) * opa
# patch_alpha[patch_alpha <= th] = 0
# patch_alpha[patch_alpha > th] = (1 - patch_alpha[patch_alpha > th])

T = image_T[y0:y1, x0:x1]
image[y0:y1, x0:x1, :] += (patch_alpha * T)[:, :, np.newaxis] * patch_color
image_T[y0:y1, x0:x1] = T * (1 - patch_alpha)
end = time.time()
time_diff = end - start
print("add patch time %f\n" % time_diff)

return image
Loading

0 comments on commit 11a3549

Please sign in to comment.