Commit

Merge branch 'master' of github.com:krrish94/nerf-pytorch
krrish94 committed Apr 16, 2020
2 parents 6aa1b85 + 771ae31 commit 4e5b571
Showing 2 changed files with 36 additions and 37 deletions.
41 changes: 24 additions & 17 deletions eval_nerf.py
```diff
@@ -10,10 +10,10 @@
 from tqdm import tqdm
 
 from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data,
-                  models, positional_encoding, run_one_iter_of_nerf)
+                  models, get_embedding_function, run_one_iter_of_nerf)
 
 
-def cast_to_image(tensor):
+def cast_to_image(tensor, dataset_type):
     # Input tensor is (H, W, 3). Convert to (3, H, W).
     tensor = tensor.permute(2, 0, 1)
     # Convert to PIL Image and then np.array (output shape: (H, W, 3))
```
```diff
@@ -68,25 +68,25 @@ def main():
     hwf = [int(H), int(W), focal]
     render_poses = torch.from_numpy(render_poses)
 
-    def encode_position_fn(x):
-        return positional_encoding(
-            x,
-            num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
-            include_input=cfg.models.coarse.include_input_xyz,
-        )
-
-    def encode_direction_fn(x):
-        return positional_encoding(
-            x,
-            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
-            include_input=cfg.models.coarse.include_input_dir,
-        )
-
     # Device on which to run.
     device = "cpu"
     if torch.cuda.is_available():
         device = "cuda"
 
+    encode_position_fn = get_embedding_function(
+        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
+        include_input=cfg.models.coarse.include_input_xyz,
+        log_sampling=cfg.models.coarse.log_sampling_xyz
+    )
+
+    encode_direction_fn = None
+    if cfg.models.coarse.use_viewdirs:
+        encode_direction_fn = get_embedding_function(
+            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
+            include_input=cfg.models.coarse.include_input_dir,
+            log_sampling=cfg.models.coarse.log_sampling_dir,
+        )
+
     # Initialize a coarse resolution model.
     model_coarse = getattr(models, cfg.models.coarse.type)(
         num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
```
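For context, a minimal sketch of what `get_embedding_function` is assumed to do here: bind the encoding hyperparameters once and return a closure, replacing the hand-written `encode_position_fn` / `encode_direction_fn` wrappers that this hunk deletes. The repository's actual implementation likely wraps its own `positional_encoding` helper; the code below is only an illustration.

```python
import torch

# Illustrative stand-in for get_embedding_function (an assumption, not the
# repository's implementation): returns a closure that applies sin/cos
# positional encoding with the configured number of frequency bands.
def get_embedding_function_sketch(num_encoding_functions=6, include_input=True,
                                  log_sampling=True):
    def embed(x):
        # Frequency bands 2^0 ... 2^(N-1), spaced logarithmically or linearly.
        if log_sampling:
            freqs = 2.0 ** torch.linspace(0.0, num_encoding_functions - 1,
                                          num_encoding_functions)
        else:
            freqs = torch.linspace(1.0, 2.0 ** (num_encoding_functions - 1),
                                   num_encoding_functions)
        features = [x] if include_input else []
        for freq in freqs:
            features.append(torch.sin(x * freq))
            features.append(torch.cos(x * freq))
        return torch.cat(features, dim=-1)

    return embed
```

Under this sketch, `get_embedding_function_sketch(num_encoding_functions=10)(xyz)` maps a `(..., 3)` tensor of positions to a `(..., 63)` feature tensor.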
```diff
@@ -119,6 +119,12 @@ def encode_direction_fn(x):
                 "The checkpoint has a fine-level model, but it could "
                 "not be loaded (possibly due to a mismatched config file."
             )
+    if "height" in checkpoint.keys():
+        hwf[0] = checkpoint["height"]
+    if "width" in checkpoint.keys():
+        hwf[1] = checkpoint["width"]
+    if "focal_length" in checkpoint.keys():
+        hwf[2] = checkpoint["focal_length"]
 
     model_coarse.eval()
     if model_fine:
```
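The three new keys let the eval script recover the render resolution and focal length directly from the checkpoint instead of relying on the dataset on disk. A hedged sketch of the training-side counterpart, assuming the training script simply adds these fields to whatever it already saves; the other checkpoint contents and values below are placeholders, not repository code.

```python
import torch

# Hypothetical training-side save; only the "height"/"width"/"focal_length"
# key names are taken from the diff above, the rest is a placeholder.
checkpoint_dict = {
    # ... model / optimizer state dicts saved by the training loop ...
    "height": 800,          # example values for an 800 x 800 synthetic render
    "width": 800,
    "focal_length": 1111.1,
}
torch.save(checkpoint_dict, "checkpoint00000.ckpt")
```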
```diff
@@ -135,6 +141,7 @@ def encode_direction_fn(x):
         start = time.time()
         rgb = None, None
         with torch.no_grad():
+            pose = pose[:3, :4]
             ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
             rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
                 hwf[0],
```
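A toy illustration of the `pose = pose[:3, :4]` line added above: the render poses are 4x4 homogeneous camera-to-world matrices, and only the top 3x4 rotation-plus-translation block is needed to build the ray bundle. This is a standalone sketch, not repository code.

```python
import torch

# Stand-in 4x4 camera-to-world transform; a real render pose would come from
# the loaded dataset rather than an identity matrix.
pose_4x4 = torch.eye(4)
pose_3x4 = pose_4x4[:3, :4]   # drop the constant [0, 0, 0, 1] bottom row
print(pose_3x4.shape)         # torch.Size([3, 4])
```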
```diff
@@ -153,7 +160,7 @@ def encode_direction_fn(x):
         times_per_image.append(time.time() - start)
         if configargs.savedir:
             savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
-            imageio.imwrite(savefile, cast_to_image(rgb[..., :3]))
+            imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
         tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")
```


32 changes: 12 additions & 20 deletions nerf/train_utils.py
```diff
@@ -25,10 +25,6 @@ def run_network(network_fn, pts, ray_batch, chunksize, embed_fn, embeddirs_fn):
     return radiance_field
 
 
-def identity_encoding(x):
-    return x
-
-
 def predict_and_render_radiance(
     ray_batch,
     model_coarse,
```
```diff
@@ -39,11 +35,6 @@ def predict_and_render_radiance(
     encode_direction_fn=None,
 ):
     # TESTED
-    if encode_position_fn is None:
-        encode_position_fn = identity_encoding
-    if encode_direction_fn is None:
-        encode_direction_fn = identity_encoding
-
     num_rays = ray_batch.shape[0]
     ro, rd = ray_batch[..., :3], ray_batch[..., 3:6]
     bounds = ray_batch[..., 6:8].view((-1, 1, 2))
```
```diff
@@ -98,7 +89,6 @@ def predict_and_render_radiance(
         white_background=getattr(options.nerf, mode).white_background,
     )
 
-    # TODO: Implement importance sampling, and finer network.
     rgb_fine, disp_fine, acc_fine = None, None, None
     if getattr(options.nerf, mode).num_fine > 0:
         # rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
```
```diff
@@ -150,13 +140,6 @@ def run_one_iter_of_nerf(
     encode_position_fn=None,
     encode_direction_fn=None,
 ):
-    if encode_position_fn is None:
-        encode_position_fn = identity_encoding
-    if encode_direction_fn is None:
-        encode_direction_fn = identity_encoding
-
-    # ray_origins = batch_rays[0]
-    # ray_directions = batch_rays[1]
     viewdirs = None
     if options.nerf.use_viewdirs:
         # Provide ray directions as input
```
```diff
@@ -197,11 +180,20 @@ def run_one_iter_of_nerf(
         for batch in batches
     ]
     synthesized_images = list(zip(*pred))
-    synthesized_images = [torch.cat(image, dim=0) for image in synthesized_images]
+    synthesized_images = [torch.cat(image, dim=0) if image[0] is not None else (None) for image in synthesized_images]
     if mode == "validation":
         synthesized_images = [
-            image.view(shape)
+            image.view(shape) if image is not None else None
             for (image, shape) in zip(synthesized_images, restore_shapes)
         ]
-    # Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine.
+
+        # Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine
+        # (assuming both the coarse and fine networks are used).
+        if model_fine:
+            return tuple(synthesized_images)
+        else:
+            # If the fine network is not used, rgb_fine, disp_fine, acc_fine are
+            # set to None.
+            return tuple(synthesized_images + [None, None, None])
+
     return tuple(synthesized_images)
```
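With this change, `run_one_iter_of_nerf` always hands back a 6-tuple in validation mode, padding the fine-network slots with `None` when no fine model is configured, so callers such as `eval_nerf.py` can unpack it unconditionally. A small usage sketch with placeholder tensors, not repository code:

```python
import torch

# Simulated coarse-only validation output: the last three slots come back as None.
outputs = (torch.rand(8, 3), torch.rand(8), torch.rand(8), None, None, None)
rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine = outputs

# Prefer the fine prediction when a fine network was used, else fall back.
rgb = rgb_fine if rgb_fine is not None else rgb_coarse
```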
