Skip to content

Commit

Permalink
add saving images to eval script
Browse files Browse the repository at this point in the history
  • Loading branch information
niujinshuchong committed Mar 11, 2023
1 parent 079ee49 commit af1ae1a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
2 changes: 1 addition & 1 deletion nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,5 +358,5 @@ def eval_iteration(self, step):

# all eval images
if step_check(step, self.config.trainer.steps_per_eval_all_images):
metrics_dict = self.pipeline.get_average_eval_image_metrics(step=step)
metrics_dict, _ = self.pipeline.get_average_eval_image_metrics(step=step)
writer.put_dict(name="Eval Images Metrics Dict (all images)", scalar_dict=metrics_dict, step=step)
6 changes: 4 additions & 2 deletions nerfstudio/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None):
"""
self.eval()
metrics_dict_list = []
images_dict_list = []
num_images = len(self.datamanager.fixed_indices_eval_dataloader)
with Progress(
TextColumn("[progress.description]{task.description}"),
Expand All @@ -347,13 +348,14 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None):
height, width = camera_ray_bundle.shape
num_rays = height * width
outputs = self.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle)
metrics_dict, _ = self.model.get_image_metrics_and_images(outputs, batch)
metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
assert "num_rays_per_sec" not in metrics_dict
metrics_dict["num_rays_per_sec"] = num_rays / (time() - inner_start)
fps_str = "fps"
assert fps_str not in metrics_dict
metrics_dict[fps_str] = metrics_dict["num_rays_per_sec"] / (height * width)
metrics_dict_list.append(metrics_dict)
images_dict_list.append(images_dict)
progress.advance(task)
# average the metrics list
metrics_dict = {}
Expand All @@ -362,7 +364,7 @@ def get_average_eval_image_metrics(self, step: Optional[int] = None):
torch.mean(torch.tensor([metrics_dict[key] for metrics_dict in metrics_dict_list]))
)
self.train()
return metrics_dict
return metrics_dict, images_dict_list

@profiler.time_function
def get_visibility_mask(
Expand Down
21 changes: 20 additions & 1 deletion scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@
from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
import torch
import tyro
from rich.console import Console

from nerfstudio.utils.eval_utils import eval_setup

CONSOLE = Console(width=120)

# speedup for when input size to model doesn't change (much)
torch.backends.cudnn.benchmark = True # type: ignore
torch.set_float32_matmul_precision("high")


@dataclass
class ComputePSNR:
Expand All @@ -24,13 +31,17 @@ class ComputePSNR:
load_config: Path
# Name of the output file.
output_path: Path = Path("output.json")
# Name of the output images dir.
output_images_path: Path = Path("output_images/")

def main(self) -> None:
"""Main function."""
config, pipeline, checkpoint_path = eval_setup(self.load_config)
assert self.output_path.suffix == ".json"
metrics_dict = pipeline.get_average_eval_image_metrics()
metrics_dict, images_dict_list = pipeline.get_average_eval_image_metrics()
self.output_path.parent.mkdir(parents=True, exist_ok=True)
self.output_images_path.mkdir(parents=True, exist_ok=True)

# Get the output and define the names to save to
benchmark_info = {
"experiment_name": config.experiment_name,
Expand All @@ -42,6 +53,14 @@ def main(self) -> None:
self.output_path.write_text(json.dumps(benchmark_info, indent=2), "utf8")
CONSOLE.print(f"Saved results to: {self.output_path}")

for idx, images_dict in enumerate(images_dict_list):
for k, v in images_dict.items():
cv2.imwrite(
str(self.output_images_path / Path(f"{k}_{idx}.png")),
(v.cpu().numpy() * 255.0).astype(np.uint8)[..., ::-1],
)
CONSOLE.print(f"Saved rendering results to: {self.output_images_path}")


def entrypoint():
"""Entrypoint for use with pyproject scripts."""
Expand Down

0 comments on commit af1ae1a

Please sign in to comment.