Skip to content

Commit

Permalink
Better align with nerf
Browse files Browse the repository at this point in the history
  • Loading branch information
tancik committed May 10, 2022
1 parent 5fb3d14 commit 8c1e9d7
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion configs/graph/default_graph.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ resume_train:
# NOTE(ethan): currently setting to nothing b/c we're still designing the config format...
network:
_target_: mattport.nerf.graph.vanilla_nerf.NeRFGraph
near_plane: 1.0
near_plane: 2.0
far_plane: 6.0
num_coarse_samples: 64
num_importance_samples: 128
Expand Down
2 changes: 1 addition & 1 deletion mattport/nerf/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_dataset_inputs_dict(
dataset_type: str,
downscale_factor: int = 1,
alpha_color: Optional[Union[str, list, ListConfig]] = None,
splits: Tuple[str] = ("train", "val"),
splits: Tuple[str] = ("train", "val", "test"),
) -> Dict[str, DatasetInputs]:
"""Returns the dataset inputs, which will be used with an ImageDataset and RayGenerator.
# TODO: implement the `test` split, which will have depths and normals, etc.
Expand Down
5 changes: 3 additions & 2 deletions mattport/nerf/graph/vanilla_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(self, num_layers=8, layer_width=256, skip_connections: Tuple = (4,)
self.mlp_rgb = MLP(
in_dim=self.mlp_base.get_out_dim() + self.encoding_dir.get_out_dim(),
out_dim=self.layer_width // 2,
num_layers=1,
num_layers=2,
layer_width=self.layer_width // 2,
activation=nn.ReLU(),
)
self.field_output_rgb = RGBFieldHead(in_dim=self.mlp_rgb.get_out_dim())
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(
self,
intrinsics=None,
camera_to_world=None,
near_plane=1.0,
near_plane=2.0,
far_plane=6.0,
num_coarse_samples=64,
num_importance_samples=128,
Expand Down
2 changes: 1 addition & 1 deletion mattport/nerf/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def forward(
u = torch.rand(size=(*cdf.shape[:-1], num_samples), device=cdf.device)
else:
u = torch.linspace(0.0, 1.0, steps=num_samples, device=cdf.device)
u = torch.expand(size=(*cdf.shape[:-1], num_samples))
u = u.expand(size=(*cdf.shape[:-1], num_samples))

u = u.contiguous()
indicies = torch.searchsorted(cdf, u, right=True)
Expand Down
2 changes: 2 additions & 0 deletions mattport/nerf/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,10 @@ def train_iteration(self, batch: dict, step: int):
@profiler.time_function
def test_image(self, image_idx, step):
"""Test a specific image."""
self.graph.eval()
intrinsics = self.val_image_intrinsics[image_idx]
camera_to_world = self.val_image_camera_to_world[image_idx]
outputs = self.graph.get_outputs_for_camera(intrinsics, camera_to_world)
image = self.val_image_dataset[image_idx]["image"].to(self.device)
self.graph.log_test_image_outputs(image_idx, step, image, outputs)
self.graph.train()

0 comments on commit 8c1e9d7

Please sign in to comment.