Skip to content

Commit

Permalink
Add colormaps to logs
Browse files Browse the repository at this point in the history
  • Loading branch information
tancik committed May 7, 2022
1 parent c164fd5 commit cf9c557
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
},
"editor.formatOnSave": true,
"python.formatting.provider": "black",
"python.formatting.blackArgs": [
"--line-length=120"
],
"python.linting.pylintEnabled": true,
"python.linting.flake8Enabled": false,
"python.linting.enabled": true,
Expand Down
1 change: 1 addition & 0 deletions environment/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ imageio==2.16.1
ipywidgets>=7.6
jupyterlab
kaleido
matplotlib
mediapy
plotly==5.7.0
pylint==2.13.4
Expand Down
9 changes: 7 additions & 2 deletions mattport/nerf/graph/vanilla_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mattport.nerf.sampler import PDFSampler, UniformSampler
from mattport.structures import colors
from mattport.structures.rays import RaySamples
from mattport.utils import visualization


class NeRFField(nn.Module):
Expand Down Expand Up @@ -162,12 +163,16 @@ def forward(self, ray_indices: TensorType["num_rays", 3]):
fine_renderer_accumulation = self.renderer_accumulation(fine_weights) # RendererOutputs
fine_renderer_disparity = self.renderer_disparity(fine_weights, pdf_ray_samples.ts)

# TODO refactor this into "vis" section. Doesn't need to be run during training.
coarse_renderer_accumulation = visualization.apply_colormap(coarse_renderer_accumulation.accumulation)
fine_renderer_accumulation = visualization.apply_colormap(fine_renderer_accumulation.accumulation)

# outputs:
outputs = {
"rgb_coarse": coarse_renderer_outputs.rgb,
"rgb_fine": fine_renderer_outputs.rgb,
"accumulation_coarse": coarse_renderer_accumulation.accumulation,
"accumulation_fine": fine_renderer_accumulation.accumulation,
"accumulation_coarse": coarse_renderer_accumulation,
"accumulation_fine": fine_renderer_accumulation,
"disparity_coarse": coarse_renderer_disparity.disparity,
"disparity_fine": fine_renderer_disparity.disparity,
}
Expand Down
4 changes: 2 additions & 2 deletions mattport/nerf/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ def test_image(self, image_idx, step):
disparity_fine.append(graph_outputs["disparity_fine"])
rgb_coarse = torch.cat(rgb_coarse).view(image_height, image_width, 3).detach().cpu()
rgb_fine = torch.cat(rgb_fine).view(image_height, image_width, 3).detach().cpu()
accumulation_coarse = torch.cat(accumulation_coarse).view(image_height, image_width, 1).detach().cpu()
accumulation_fine = torch.cat(accumulation_fine).view(image_height, image_width, 1).detach().cpu()
accumulation_coarse = torch.cat(accumulation_coarse).view(image_height, image_width, 3).detach().cpu()
accumulation_fine = torch.cat(accumulation_fine).view(image_height, image_width, 3).detach().cpu()
disparity_coarse = torch.cat(disparity_coarse).view(image_height, image_width, 1).detach().cpu()
disparity_fine = torch.cat(disparity_fine).view(image_height, image_width, 1).detach().cpu()

Expand Down
24 changes: 24 additions & 0 deletions mattport/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
""" Helper functions for visualizing outputs """

import torch
from matplotlib import cm
from torchtyping import TensorType


def apply_colormap(image: TensorType[..., 1], cmap="viridis") -> TensorType[..., 3]:
"""Convert single channel to a color image.
Args:
image (TensorType[..., 1]): Single channel image.
cmap (str, optional): Colormap for image. Defaults to 'turbo'.
Returns:
TensorType[..., 3]: Colored image
"""

colormap = cm.get_cmap(cmap)
colormap = torch.tensor(colormap.colors).to(image.device)

image = (image * 255).long()

return colormap[image]
4 changes: 2 additions & 2 deletions mattport/utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ def __init__(self, is_main_thread: bool, save_dir: str):
self.tb_writer = SummaryWriter(log_dir=self.save_dir)

def write_image(
self, name: str, x: TensorType["H", "W", 3], step: int, group: str = None, prefix: str = None
self, name: str, x: TensorType["H", "W", "C"], step: int, group: str = None, prefix: str = None
) -> None:
"""_summary_
Args:
name (str): data identifier
x (TensorType["H", "W", 3]): rendered image to write
x (TensorType["H", "W", "C"]): rendered image to write
"""
x = to8b(x)
tensorboard_name = get_tensorboard_name(name, group=group, prefix=prefix)
Expand Down

0 comments on commit cf9c557

Please sign in to comment.