Skip to content

Commit

Permalink
No more activation inside, guardbands removed
Browse files Browse the repository at this point in the history
  • Loading branch information
bkerbl committed Jul 2, 2023
1 parent feecabd commit 3a07ac2
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 44 deletions.
2 changes: 1 addition & 1 deletion cuda_rasterizer/auxiliary.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ __forceinline__ __device__ bool in_frustum(int idx,
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
p_view = transformPoint4x3(p_orig, viewmatrix);

if (p_view.z <= 0.2f || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
{
if (prefiltered)
{
Expand Down
41 changes: 27 additions & 14 deletions cuda_rasterizer/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ __global__ void computeCov2DCUDA(int P,
const float3* means,
const int* radii,
const float* cov3Ds,
float h_x,
float h_y,
const float h_x, float h_y,
const float tan_fovx, float tan_fovy,
const float* view_matrix,
const float* dL_dconics,
float3* dL_dmeans,
Expand All @@ -153,11 +153,20 @@ __global__ void computeCov2DCUDA(int P,
float3 mean = means[idx];
float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
float3 t = transformPoint4x3(mean, view_matrix);
float t_inv_norm = 1.f / sqrt(t.x * t.x + t.y * t.y + t.z * t.z);

const float limx = 1.3f * tan_fovx;
const float limy = 1.3f * tan_fovy;
const float txtz = t.x / t.z;
const float tytz = t.y / t.z;
t.x = min(limx, max(-limx, txtz)) * t.z;
t.y = min(limy, max(-limy, tytz)) * t.z;

const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;

glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
t.x * t_inv_norm, t.y * t_inv_norm, t.z * t_inv_norm);
0, 0, 0);

glm::mat3 W = glm::mat3(
view_matrix[0], view_matrix[4], view_matrix[8],
Expand Down Expand Up @@ -239,8 +248,8 @@ __global__ void computeCov2DCUDA(int P,
float tz3 = tz2 * tz;

// Gradients of loss w.r.t. transformed Gaussian mean t
float dL_dtx = -h_x * tz2 * dL_dJ02;
float dL_dty = -h_y * tz2 * dL_dJ12;
float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;

// Account for transformation of mean to t
Expand All @@ -258,7 +267,7 @@ __global__ void computeCov2DCUDA(int P,
__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
{
// Recompute (intermediate) results for the 3D covariance computation.
glm::vec4 q = rot / glm::length(rot);
glm::vec4 q = rot;// / glm::length(rot);
float r = q.x;
float x = q.y;
float y = q.z;
Expand All @@ -272,7 +281,7 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl

glm::mat3 S = glm::mat3(1.0f);

glm::vec3 s = mod * exp(scale);
glm::vec3 s = mod * scale;
S[0][0] = s.x;
S[1][1] = s.y;
S[2][2] = s.z;
Expand All @@ -298,16 +307,16 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl
glm::mat3 Rt = glm::transpose(R);
glm::mat3 dL_dMt = glm::transpose(dL_dM);

dL_dMt[0] *= s.x;
dL_dMt[1] *= s.y;
dL_dMt[2] *= s.z;

// Gradients of loss w.r.t. scale
glm::vec3* dL_dscale = dL_dscales + idx;
dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);

dL_dMt[0] *= s.x;
dL_dMt[1] *= s.y;
dL_dMt[2] *= s.z;

// Gradients of loss w.r.t. normalized quaternion
glm::vec4 dL_dq;
dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
Expand All @@ -317,7 +326,7 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl

// Gradients of loss w.r.t. unnormalized quaternion
float4* dL_drot = (float4*)(dL_drots + idx);
*dL_drot = dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
*dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
}

// Backward pass of the preprocessing steps, except
Expand Down Expand Up @@ -377,7 +386,8 @@ __global__ void preprocessCUDA(

// Backward version of the rendering procedure.
template <uint32_t C>
__global__ void renderCUDA(
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
renderCUDA(
const uint2* __restrict__ ranges,
const uint32_t* __restrict__ point_list,
int W, int H,
Expand Down Expand Up @@ -548,6 +558,7 @@ void BACKWARD::preprocess(
const float* viewmatrix,
const float* projmatrix,
const float focal_x, float focal_y,
const float tan_fovx, float tan_fovy,
const glm::vec3* campos,
const float3* dL_dmean2D,
const float* dL_dconic,
Expand All @@ -569,6 +580,8 @@ void BACKWARD::preprocess(
cov3Ds,
focal_x,
focal_y,
tan_fovx,
tan_fovy,
viewmatrix,
dL_dconic,
(float3*)dL_dmean3D,
Expand Down
1 change: 1 addition & 0 deletions cuda_rasterizer/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace BACKWARD
const float* view,
const float* proj,
const float focal_x, float focal_y,
const float tan_fovx, float tan_fovy,
const glm::vec3* campos,
const float3* dL_dmean2D,
const float* dL_dconics,
Expand Down
41 changes: 18 additions & 23 deletions cuda_rasterizer/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,25 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const
}

// Forward version of 2D covariance matrix computation
__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, const float* cov3D, const float* viewmatrix)
__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
{
// The following models the steps outlined by equations 29
// and 31 in "EWA Splatting" (Zwicker et al., 2002).
// Additionally considers aspect / scaling of viewport.
// Transposes used to account for row-/column-major conventions.
float3 t = transformPoint4x3(mean, viewmatrix);

float t_inv_norm = 1.f / sqrt(t.x * t.x + t.y * t.y + t.z * t.z);
const float limx = 1.3f * tan_fovx;
const float limy = 1.3f * tan_fovy;
const float txtz = t.x / t.z;
const float tytz = t.y / t.z;
t.x = min(limx, max(-limx, txtz)) * t.z;
t.y = min(limy, max(-limy, tytz)) * t.z;

glm::mat3 J = glm::mat3(
focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
t.x * t_inv_norm, t.y * t_inv_norm, t.z * t_inv_norm);
0, 0, 0);

glm::mat3 W = glm::mat3(
viewmatrix[0], viewmatrix[4], viewmatrix[8],
Expand All @@ -98,17 +103,17 @@ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y,

// Forward method for converting scale and rotation properties of each
// Gaussian to a 3D covariance matrix in world space. Also takes care
// of quaternion normalization and scale activation via exp.
// of quaternion normalization.
__device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
{
// Create scaling matrix
glm::mat3 S = glm::mat3(1.0f);
S[0][0] = mod * exp(scale.x);
S[1][1] = mod * exp(scale.y);
S[2][2] = mod * exp(scale.z);
S[0][0] = mod * scale.x;
S[1][1] = mod * scale.y;
S[2][2] = mod * scale.z;

// Normalize quaternion to get valid rotation
glm::vec4 q = rot / glm::length(rot);
glm::vec4 q = rot;// / glm::length(rot);
float r = q.x;
float x = q.y;
float y = q.z;
Expand Down Expand Up @@ -172,7 +177,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
radii[idx] = 0;
tiles_touched[idx] = 0;

// Perform near and frustum culling with guardband, quit if outside.
// Perform near culling, quit if outside.
float3 p_view;
if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
return;
Expand All @@ -196,11 +201,8 @@ __global__ void preprocessCUDA(int P, int D, int M,
cov3D = cov3Ds + idx * 6;
}

// Compute max extent of Gaussian for fine-grained fustum culling
float max_dist2 = 9.f * max(cov3D[0], max(cov3D[3], cov3D[5]));

// Compute 2D screen-space covariance matrix
float3 cov = computeCov2D(p_orig, focal_x, focal_y, cov3D, viewmatrix);
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);

// Invert covariance (EWA algorithm)
float det = (cov.x * cov.z - cov.y * cov.y);
Expand All @@ -209,14 +211,6 @@ __global__ void preprocessCUDA(int P, int D, int M,
float det_inv = 1.f / det;
float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };

// Fine-grained frustum culling against ellipsoid
float z_at_point = p_view.z + sqrt(max_dist2);
float x_to_border = z_at_point * tan_fovx;
float y_to_border = z_at_point * tan_fovy;
float D2_point = p_view.x * p_view.x + p_view.y * p_view.y;
if (D2_point - (x_to_border * x_to_border + y_to_border * y_to_border) > max_dist2)
return;

// Compute extent in screen space (by finding eigenvalues of
// 2D covariance matrix). Use extent to compute a bounding rectangle
// of screen-space tiles that this Gaussian overlaps with. Quit if
Expand Down Expand Up @@ -254,7 +248,8 @@ __global__ void preprocessCUDA(int P, int D, int M,
// block, each thread treats one pixel. Alternates between fetching
// and rasterizing data.
template <uint32_t CHANNELS>
__global__ void renderCUDA(
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
renderCUDA(
const uint2* __restrict__ ranges,
const uint32_t* __restrict__ point_list,
int W, int H,
Expand Down Expand Up @@ -407,8 +402,8 @@ void FORWARD::preprocess(int P, int D, int M,
const float* projmatrix,
const glm::vec3* cam_pos,
const int W, int H,
const float tan_fovx, float tan_fovy,
const float focal_x, float focal_y,
const float tan_fovx, float tan_fovy,
int* radii,
float2* means2D,
float* depths,
Expand Down
2 changes: 1 addition & 1 deletion cuda_rasterizer/forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace FORWARD
const float* projmatrix,
const glm::vec3* cam_pos,
const int W, int H,
const float tan_fovx, float tan_fovy,
const float focal_x, float focal_y,
const float tan_fovx, float tan_fovy,
int* radii,
float2* points_xy_image,
float* depths,
Expand Down
3 changes: 2 additions & 1 deletion cuda_rasterizer/rasterizer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ int CudaRasterizer::Rasterizer::forward(
viewmatrix, projmatrix,
(glm::vec3*)cam_pos,
width, height,
tan_fovx, tan_fovy,
focal_x, focal_y,
tan_fovx, tan_fovy,
radii,
geomState.means2D,
geomState.depths,
Expand Down Expand Up @@ -408,6 +408,7 @@ void CudaRasterizer::Rasterizer::backward(
viewmatrix,
projmatrix,
focal_x, focal_y,
tan_fovx, tan_fovy,
(glm::vec3*)campos,
(float3*)dL_dmean2D,
dL_dconic,
Expand Down
7 changes: 3 additions & 4 deletions rasterize_points.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ RasterizeGaussiansCUDA(
}

const int P = means3D.size(0);
const int N = 1; // batch size hard-coded
const int H = image_height;
const int W = image_width;

auto int_opts = means3D.options().dtype(torch::kInt32);
auto float_opts = means3D.options().dtype(torch::kFloat32);

torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));

torch::Device device(torch::kCUDA);
Expand Down Expand Up @@ -126,8 +125,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& imageBuffer)
{
const int P = means3D.size(0);
const int H = dL_dout_color.size(2);
const int W = dL_dout_color.size(3);
const int H = dL_dout_color.size(1);
const int W = dL_dout_color.size(2);

int M = 0;
if(sh.size(0) != 0)
Expand Down

0 comments on commit 3a07ac2

Please sign in to comment.