Skip to content

Commit

Permalink
Use torchmetrics (autonomousvision#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
tancik authored May 11, 2022
1 parent a9b0697 commit 2ebe34f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 100 deletions.
1 change: 0 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@
".vscode/*.py",
"external/**/*.py",
],
"python.analysis.typeCheckingMode": "basic",
}
2 changes: 1 addition & 1 deletion environment/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ mediapy
plotly==5.7.0
pylint==2.13.4
pytest
pytorch-msssim>=0.2.1
tensorboard==2.8.0
torch==1.11.0
torchmetrics[image]==0.8.2
torchtyping
torchvision==0.12.0
tqdm
73 changes: 38 additions & 35 deletions mattport/nerf/graph/vanilla_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import torch
from torch import nn
from torch.nn import Parameter
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from mattport.nerf import metrics
from mattport.nerf.field_modules.encoding import NeRFEncoding
from mattport.nerf.field_modules.field_heads import DensityFieldHead, FieldHeadNames, RGBFieldHead
from mattport.nerf.field_modules.mlp import MLP
Expand Down Expand Up @@ -123,6 +124,11 @@ def populate_modules(self):
# losses
self.rgb_loss = MSELoss()

# metrics
self.psnr = PeakSignalNoiseRatio(data_range=1.0)
self.ssim = StructuralSimilarityIndexMeasure()
self.lpips = LearnedPerceptualImagePatchSimilarity()

def get_param_groups(self) -> Dict[str, List[Parameter]]:
"""Obtain the parameter groups for the optimizers
Expand Down Expand Up @@ -164,30 +170,14 @@ def get_outputs(self, ray_bundle: RayBundle):
fine_renderer_accumulation = self.renderer_accumulation(fine_weights) # RendererOutputs
fine_renderer_depth = self.renderer_depth(fine_weights, pdf_ray_samples.ts)

# TODO refactor this into "vis" section. Doesn't need to be run during training.
coarse_renderer_depth = visualization.apply_depth_colormap(
coarse_renderer_depth.depth,
accumulation=coarse_renderer_accumulation.accumulation,
near_plane=self.near_plane,
far_plane=self.far_plane,
)
fine_renderer_depth = visualization.apply_depth_colormap(
fine_renderer_depth.depth,
accumulation=fine_renderer_accumulation.accumulation,
near_plane=self.near_plane,
far_plane=self.far_plane,
)
coarse_renderer_accumulation = visualization.apply_colormap(coarse_renderer_accumulation.accumulation)
fine_renderer_accumulation = visualization.apply_colormap(fine_renderer_accumulation.accumulation)

# outputs:
outputs = {
"rgb_coarse": coarse_renderer_outputs.rgb,
"rgb_fine": fine_renderer_outputs.rgb,
"accumulation_coarse": coarse_renderer_accumulation,
"accumulation_fine": fine_renderer_accumulation,
"depth_coarse": coarse_renderer_depth,
"depth_fine": fine_renderer_depth,
"accumulation_coarse": coarse_renderer_accumulation.accumulation,
"accumulation_fine": fine_renderer_accumulation.accumulation,
"depth_coarse": coarse_renderer_depth.depth,
"depth_fine": fine_renderer_depth.depth,
}
return outputs

Expand All @@ -203,43 +193,56 @@ def get_loss_dict(self, outputs, batch):
def log_test_image_outputs(self, image_idx, step, image, outputs):
rgb_coarse = outputs["rgb_coarse"]
rgb_fine = outputs["rgb_fine"]
accumulation_coarse = outputs["accumulation_coarse"]
accumulation_fine = outputs["accumulation_fine"]
depth_coarse = outputs["depth_coarse"]
depth_fine = outputs["depth_fine"]

combined_image = torch.cat([image, rgb_coarse, rgb_fine], dim=1)
writer.write_event(
{"name": f"image_idx_{image_idx}-rgb_coarse_fine", "x": combined_image, "step": step, "group": "val_img"}
)

accumulation_coarse = visualization.apply_colormap(outputs["accumulation_coarse"])
accumulation_fine = visualization.apply_colormap(outputs["accumulation_fine"])
combined_image = torch.cat([accumulation_coarse, accumulation_fine], dim=1)
writer.write_event(
{"name": f"image_idx_{image_idx}", "x": combined_image, "step": step, "group": "val_accumulation"}
)

depth_coarse = visualization.apply_depth_colormap(
outputs["depth_coarse"],
accumulation=outputs["accumulation_coarse"],
near_plane=self.near_plane,
far_plane=self.far_plane,
)
depth_fine = visualization.apply_depth_colormap(
outputs["depth_fine"],
accumulation=outputs["accumulation_fine"],
near_plane=self.near_plane,
far_plane=self.far_plane,
)
combined_image = torch.cat([depth_coarse, depth_fine], dim=1)
writer.write_event({"name": f"image_idx_{image_idx}", "x": combined_image, "step": step, "group": "val_depth"})

coarse_psnr = metrics.get_psnr(image, rgb_coarse)
writer.write_event(
{"name": f"image_idx_{image_idx}", "scalar": float(coarse_psnr), "step": step, "group": "val"}
)
# Switch images from [H, W, C] to [1, C, H, W] for metrics computations
image = torch.moveaxis(image, -1, 0)[None, ...]
rgb_coarse = torch.moveaxis(rgb_coarse, -1, 0)[None, ...]
rgb_fine = torch.moveaxis(rgb_fine, -1, 0)[None, ...]

coarse_ssim = metrics.get_ssim(image, rgb_coarse)
coarse_psnr = self.psnr(image, rgb_coarse)
writer.write_event(
{"name": f"image_idx_{image_idx}", "scalar": float(coarse_ssim), "step": step, "group": "ssim"}
{"name": f"val_{image_idx}-coarse", "scalar": float(coarse_psnr), "step": step, "group": "psnr"}
)

fine_psnr = metrics.get_psnr(image, rgb_fine)
fine_psnr = self.psnr(image, rgb_fine)
stats_tracker.update_stats(
{"name": stats_tracker.Stats.CURR_TEST_PSNR, "value": float(fine_psnr), "step": step}
)
writer.write_event(
{"name": f"image_idx_{image_idx}-fine_psnr", "scalar": float(fine_psnr), "step": step, "group": "val"}
{"name": f"val_idx_{image_idx}-fine", "scalar": float(fine_psnr), "step": step, "group": "psnr"}
)

fine_ssim = metrics.get_ssim(image, rgb_fine)
fine_ssim = self.ssim(image, rgb_fine)
writer.write_event({"name": f"val_idx_{image_idx}", "scalar": float(fine_ssim), "step": step, "group": "ssim"})

fine_lpips = self.lpips(image, rgb_fine)
writer.write_event(
{"name": f"image_idx_{image_idx}-fine_ssim", "scalar": float(fine_ssim), "step": step, "group": "ssim"}
{"name": f"val_idx_{image_idx}", "scalar": float(fine_lpips), "step": step, "group": "lpips"}
)
63 changes: 0 additions & 63 deletions mattport/nerf/metrics.py

This file was deleted.

0 comments on commit 2ebe34f

Please sign in to comment.