diff --git a/eval_nerf.py b/eval_nerf.py index f2ab125..ce1738f 100644 --- a/eval_nerf.py +++ b/eval_nerf.py @@ -10,10 +10,10 @@ from tqdm import tqdm from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data, - models, positional_encoding, run_one_iter_of_nerf) + models, get_embedding_function, run_one_iter_of_nerf) -def cast_to_image(tensor): +def cast_to_image(tensor, dataset_type): # Input tensor is (H, W, 3). Convert to (3, H, W). tensor = tensor.permute(2, 0, 1) # Convert to PIL Image and then np.array (output shape: (H, W, 3)) @@ -68,25 +68,25 @@ def main(): hwf = [int(H), int(W), focal] render_poses = torch.from_numpy(render_poses) - def encode_position_fn(x): - return positional_encoding( - x, - num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz, - include_input=cfg.models.coarse.include_input_xyz, - ) - - def encode_direction_fn(x): - return positional_encoding( - x, - num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir, - include_input=cfg.models.coarse.include_input_dir, - ) - # Device on which to run. device = "cpu" if torch.cuda.is_available(): device = "cuda" + encode_position_fn = get_embedding_function( + num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz, + include_input=cfg.models.coarse.include_input_xyz, + log_sampling=cfg.models.coarse.log_sampling_xyz + ) + + encode_direction_fn = None + if cfg.models.coarse.use_viewdirs: + encode_direction_fn = get_embedding_function( + num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir, + include_input=cfg.models.coarse.include_input_dir, + log_sampling=cfg.models.coarse.log_sampling_dir, + ) + # Initialize a coarse resolution model. model_coarse = getattr(models, cfg.models.coarse.type)( num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz, @@ -119,6 +119,12 @@ def encode_direction_fn(x): "The checkpoint has a fine-level model, but it could " "not be loaded (possibly due to a mismatched config file." ) + if "height" in checkpoint.keys(): + hwf[0] = checkpoint["height"] + if "width" in checkpoint.keys(): + hwf[1] = checkpoint["width"] + if "focal_length" in checkpoint.keys(): + hwf[2] = checkpoint["focal_length"] model_coarse.eval() if model_fine: @@ -135,6 +141,7 @@ def encode_direction_fn(x): start = time.time() rgb = None, None with torch.no_grad(): + pose = pose[:3, :4] ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose) rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf( hwf[0], @@ -153,7 +160,7 @@ def encode_direction_fn(x): times_per_image.append(time.time() - start) if configargs.savedir: savefile = os.path.join(configargs.savedir, f"{i:04d}.png") - imageio.imwrite(savefile, cast_to_image(rgb[..., :3])) + imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())) tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}") diff --git a/nerf/train_utils.py b/nerf/train_utils.py index e92ae15..1ebb03b 100644 --- a/nerf/train_utils.py +++ b/nerf/train_utils.py @@ -25,10 +25,6 @@ def run_network(network_fn, pts, ray_batch, chunksize, embed_fn, embeddirs_fn): return radiance_field -def identity_encoding(x): - return x - - def predict_and_render_radiance( ray_batch, model_coarse, @@ -39,11 +35,6 @@ def predict_and_render_radiance( encode_direction_fn=None, ): # TESTED - if encode_position_fn is None: - encode_position_fn = identity_encoding - if encode_direction_fn is None: - encode_direction_fn = identity_encoding - num_rays = ray_batch.shape[0] ro, rd = ray_batch[..., :3], ray_batch[..., 3:6] bounds = ray_batch[..., 6:8].view((-1, 1, 2)) @@ -98,7 +89,6 @@ def predict_and_render_radiance( white_background=getattr(options.nerf, mode).white_background, ) - # TODO: Implement importance sampling, and finer network. rgb_fine, disp_fine, acc_fine = None, None, None if getattr(options.nerf, mode).num_fine > 0: # rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map @@ -150,13 +140,6 @@ def run_one_iter_of_nerf( encode_position_fn=None, encode_direction_fn=None, ): - if encode_position_fn is None: - encode_position_fn = identity_encoding - if encode_direction_fn is None: - encode_direction_fn = identity_encoding - - # ray_origins = batch_rays[0] - # ray_directions = batch_rays[1] viewdirs = None if options.nerf.use_viewdirs: # Provide ray directions as input @@ -197,11 +180,20 @@ def run_one_iter_of_nerf( for batch in batches ] synthesized_images = list(zip(*pred)) - synthesized_images = [torch.cat(image, dim=0) for image in synthesized_images] + synthesized_images = [torch.cat(image, dim=0) if image[0] is not None else (None) for image in synthesized_images] if mode == "validation": synthesized_images = [ - image.view(shape) + image.view(shape) if image is not None else None for (image, shape) in zip(synthesized_images, restore_shapes) ] - # Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine. + + # Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine + # (assuming both the coarse and fine networks are used). + if model_fine: + return tuple(synthesized_images) + else: + # If the fine network is not used, rgb_fine, disp_fine, acc_fine are + # set to None. + return tuple(synthesized_images + [None, None, None]) + return tuple(synthesized_images)