Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
Signed-off-by: Krishna Murthy <[email protected]>
  • Loading branch information
Krishna Murthy committed Apr 17, 2020
1 parent 6189b96 commit a14357d
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 22 deletions.
17 changes: 13 additions & 4 deletions eval_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@
import yaml
from tqdm import tqdm

from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data,
models, get_embedding_function, run_one_iter_of_nerf)
from nerf import (
CfgNode,
get_ray_bundle,
load_blender_data,
load_llff_data,
models,
get_embedding_function,
run_one_iter_of_nerf,
)


def cast_to_image(tensor, dataset_type):
Expand Down Expand Up @@ -85,7 +92,7 @@ def main():
encode_position_fn = get_embedding_function(
num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
include_input=cfg.models.coarse.include_input_xyz,
log_sampling=cfg.models.coarse.log_sampling_xyz
log_sampling=cfg.models.coarse.log_sampling_xyz,
)

encode_direction_fn = None
Expand Down Expand Up @@ -174,7 +181,9 @@ def main():
times_per_image.append(time.time() - start)
if configargs.savedir:
savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
imageio.imwrite(
savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())
)
if configargs.save_disparity_image:
savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png")
imageio.imwrite(savefile, cast_to_disparity_image(disp))
Expand Down
4 changes: 2 additions & 2 deletions nerf/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(
use_viewdirs=True,
):
super(PaperNeRFModel, self).__init__()

include_input_xyz = 3 if include_input_xyz else 0
include_input_dir = 3 if include_input_dir else 0
self.dim_xyz = include_input_xyz + 2 * 3 * num_encoding_fn_xyz
Expand All @@ -161,7 +161,7 @@ def __init__(
self.relu = torch.nn.functional.relu

def forward(self, x):
xyz, dirs = x[..., :self.dim_xyz], x[..., self.dim_xyz:]
xyz, dirs = x[..., : self.dim_xyz], x[..., self.dim_xyz :]
for i in range(8):
if i == 4:
x = self.layers_xyz[i](torch.cat((xyz, x), -1))
Expand Down
16 changes: 11 additions & 5 deletions nerf/nerf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,20 @@ def positional_encoding(
encoding = [tensor] if include_input else []
frequency_bands = None
if log_sampling:
frequency_bands = 2. ** torch.linspace(
0., num_encoding_functions - 1, num_encoding_functions,
dtype=tensor.dtype, device=tensor.device,
frequency_bands = 2.0 ** torch.linspace(
0.0,
num_encoding_functions - 1,
num_encoding_functions,
dtype=tensor.dtype,
device=tensor.device,
)
else:
frequency_bands = torch.linspace(
2. ** 0., 2. ** (num_encoding_functions - 1), num_encoding_functions,
dtype=tensor.dtype, device=tensor.device
2.0 ** 0.0,
2.0 ** (num_encoding_functions - 1),
num_encoding_functions,
dtype=tensor.dtype,
device=tensor.device,
)

for freq in frequency_bands:
Expand Down
7 changes: 5 additions & 2 deletions nerf/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,16 @@ def run_one_iter_of_nerf(
for batch in batches
]
synthesized_images = list(zip(*pred))
synthesized_images = [torch.cat(image, dim=0) if image[0] is not None else (None) for image in synthesized_images]
synthesized_images = [
torch.cat(image, dim=0) if image[0] is not None else (None)
for image in synthesized_images
]
if mode == "validation":
synthesized_images = [
image.view(shape) if image is not None else None
for (image, shape) in zip(synthesized_images, restore_shapes)
]

# Returns rgb_coarse, disp_coarse, acc_coarse, rgb_fine, disp_fine, acc_fine
# (assuming both the coarse and fine networks are used).
if model_fine:
Expand Down
3 changes: 1 addition & 2 deletions tiny_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import torch
from tqdm import tqdm, trange

from nerf import (cumprod_exclusive, get_minibatches, get_ray_bundle,
positional_encoding)
from nerf import cumprod_exclusive, get_minibatches, get_ray_bundle, positional_encoding


def compute_query_points_from_rays(
Expand Down
14 changes: 7 additions & 7 deletions train_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data,
load_llff_data, meshgrid_xy, models, mse2psnr,
get_embedding_function, run_one_iter_of_nerf)
from nerf import (CfgNode, get_embedding_function, get_ray_bundle, img2mse,
load_blender_data, load_llff_data, meshgrid_xy, models,
mse2psnr, run_one_iter_of_nerf)


def main():
Expand Down Expand Up @@ -63,7 +63,7 @@ def main():
H, W = int(H), int(W)
hwf = [H, W, focal]
if cfg.nerf.train.white_background:
images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
elif cfg.dataset.type.lower() == "llff":
images, poses, bds, render_poses, i_test = load_llff_data(
cfg.dataset.basedir, factor=cfg.dataset.downsample_factor
Expand Down Expand Up @@ -104,7 +104,7 @@ def main():
include_input=cfg.models.coarse.include_input_xyz,
log_sampling=cfg.models.coarse.log_sampling_xyz,
)

encode_direction_fn = None
if cfg.models.coarse.use_viewdirs:
encode_direction_fn = get_embedding_function(
Expand Down Expand Up @@ -250,7 +250,7 @@ def main():
rgb_fine[..., :3], target_ray_values[..., :3]
)
# loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3])
loss = 0.
loss = 0.0
# if fine_loss is not None:
# loss = fine_loss
# else:
Expand Down Expand Up @@ -337,7 +337,7 @@ def main():
)
target_ray_values = img_target
coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3])
loss, fine_loss = 0., 0.
loss, fine_loss = 0.0, 0.0
if rgb_fine is not None:
fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3])
loss = fine_loss
Expand Down

0 comments on commit a14357d

Please sign in to comment.