Skip to content

Commit

Permalink
Add disparity maps
Browse files Browse the repository at this point in the history
  • Loading branch information
tancik committed May 5, 2022
1 parent 0b82c11 commit 2268a91
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mattport/nerf/graph/vanilla_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mattport.nerf.field_modules.ray_generator import RayGenerator
from mattport.nerf.graph.base import Graph
from mattport.nerf.loss import MSELoss
from mattport.nerf.renderers import AccumulationRenderer, RGBRenderer
from mattport.nerf.renderers import AccumulationRenderer, DisparityRenderer, RGBRenderer
from mattport.nerf.sampler import PDFSampler, UniformSampler
from mattport.structures import colors
from mattport.structures.rays import RaySamples
Expand Down Expand Up @@ -113,6 +113,7 @@ def populate_modules(self):
# renderers
self.renderer_rgb = RGBRenderer(background_color=colors.WHITE)
self.renderer_accumulation = AccumulationRenderer()
self.renderer_disparity = DisparityRenderer()

# losses
self.rgb_loss = MSELoss()
Expand Down Expand Up @@ -146,6 +147,7 @@ def forward(self, ray_indices: TensorType["num_rays", 3]):
weights=coarse_weights,
) # RendererOutputs
coarse_renderer_accumulation = self.renderer_accumulation(coarse_weights) # RendererOutputs
coarse_renderer_disparity = self.renderer_disparity(coarse_weights, uniform_ray_samples.ts)

# fine network:
pdf_ray_samples = self.sampler_pdf(uniform_ray_samples, coarse_weights) # RaySamples
Expand All @@ -158,13 +160,16 @@ def forward(self, ray_indices: TensorType["num_rays", 3]):
weights=fine_weights,
) # RendererOutputs
fine_renderer_accumulation = self.renderer_accumulation(fine_weights) # RendererOutputs
fine_renderer_disparity = self.renderer_disparity(fine_weights, pdf_ray_samples.ts)

# outputs:
outputs = {
"rgb_coarse": coarse_renderer_outputs.rgb,
"rgb_fine": fine_renderer_outputs.rgb,
"accumulation_coarse": coarse_renderer_accumulation.accumulation,
"accumulation_fine": fine_renderer_accumulation.accumulation,
"disparity_coarse": coarse_renderer_disparity.disparity,
"disparity_fine": fine_renderer_disparity.disparity,
}
return outputs

Expand Down
35 changes: 35 additions & 0 deletions mattport/nerf/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class RendererOutputs:
rgb: TensorType["num_rays", 3] = None
density: TensorType["num_rays", 1] = None
accumulation: TensorType["num_rays", 1] = None
disparity: TensorType["num_rays", 1] = None


class RGBRenderer(nn.Module):
Expand Down Expand Up @@ -79,3 +80,37 @@ def forward(

renderer_outputs = RendererOutputs(accumulation=accumulation)
return renderer_outputs


class DisparityRenderer(nn.Module):
"""Calcualte depth along ray."""

def __init__(self, method: str = "expected") -> None:
"""
Args:
method (str, optional): Depth calculation method. Defaults to 'expected'.
"""
super().__init__()
if method not in {"expected"}:
raise ValueError(f"{method} is an invalid depth calculation method")
self.method = method

def forward(self, weights: TensorType[..., "num_samples"], ts: TensorType[..., "num_samples"]) -> RendererOutputs:
"""Composite samples along ray and calculate disparities.
Args:
weights (TensorType[..., "num_samples"]): Weights for each sample
ts (TensorType[..., "num_samples"]): Sample locations along rays
Returns:
RendererOutputs: Outputs with disparity values.
"""

if self.method == "expected":
depth = torch.sum(weights * ts, dim=-1)
eps = 1e-10
disparity = 1.0 / torch.max(eps * torch.ones_like(depth), depth / torch.sum(weights, -1))

return RendererOutputs(disparity=disparity)

raise NotImplementedError(f"Method {self.method} not implemented")
9 changes: 9 additions & 0 deletions mattport/nerf/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,24 +247,33 @@ def test_image(self, image_idx, step):
rgb_fine = []
accumulation_coarse = []
accumulation_fine = []
disparity_coarse = []
disparity_fine = []
for i in range(0, num_rays, chunk_size):
ray_indices = all_ray_indices[i : i + chunk_size].to(self.device)
graph_outputs = self.graph(ray_indices)
rgb_coarse.append(graph_outputs["rgb_coarse"])
rgb_fine.append(graph_outputs["rgb_fine"])
accumulation_coarse.append(graph_outputs["accumulation_coarse"])
accumulation_fine.append(graph_outputs["accumulation_fine"])
disparity_coarse.append(graph_outputs["disparity_coarse"])
disparity_fine.append(graph_outputs["disparity_fine"])
rgb_coarse = torch.cat(rgb_coarse).view(image_height, image_width, 3).detach().cpu()
rgb_fine = torch.cat(rgb_fine).view(image_height, image_width, 3).detach().cpu()
accumulation_coarse = torch.cat(accumulation_coarse).view(image_height, image_width, 1).detach().cpu()
accumulation_fine = torch.cat(accumulation_fine).view(image_height, image_width, 1).detach().cpu()
disparity_coarse = torch.cat(disparity_coarse).view(image_height, image_width, 1).detach().cpu()
disparity_fine = torch.cat(disparity_fine).view(image_height, image_width, 1).detach().cpu()

combined_image = torch.cat([image, rgb_coarse, rgb_fine], dim=1)
self.writer.write_image(f"image_idx_{image_idx}-rgb_coarse_fine", combined_image, step, group="val_img")

combined_image = torch.cat([accumulation_coarse, accumulation_fine], dim=1)
self.writer.write_image(f"image_idx_{image_idx}", combined_image, step, group="val_accumulation")

combined_image = torch.cat([disparity_coarse, disparity_fine], dim=1)
self.writer.write_image(f"image_idx_{image_idx}", combined_image, step, group="val_disparity")

coarse_psnr = get_psnr(image, rgb_coarse)
self.writer.write_scalar(f"image_idx_{image_idx}", float(coarse_psnr), step, group="val")

Expand Down

0 comments on commit 2268a91

Please sign in to comment.