Typecheck workspace (autonomousvision#231)
tancik authored Jul 30, 2022
1 parent db9d830 commit 6c45f72
Showing 19 changed files with 151 additions and 90 deletions.
7 changes: 4 additions & 3 deletions .vscode/settings.json
@@ -11,8 +11,8 @@
"[python]": {
"editor.defaultFormatter": "ms-python.python",
"editor.codeActionsOnSave": {
"source.organizeImports": true
},
"source.organizeImports": true
}
},
"editor.formatOnSave": true,
"python.envFile": "${workspaceFolder}/.env",
@@ -94,5 +94,6 @@
"ios": "cpp",
"__atomic": "cpp",
"__node_handle": "cpp"
}
},
"python.analysis.typeCheckingMode": "basic"
}
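The new "python.analysis.typeCheckingMode": "basic" setting turns on Pylance's basic static type checking for the whole workspace, which is what motivates the annotation fixes and # type: ignore comments in the rest of this commit. A minimal, hypothetical illustration of the kind of diagnostic basic mode reports (the function and the call below are made up, not part of the commit):

def scale(step: int, factor: float) -> float:
    return step * factor

scale("10", 0.5)  # reported as an error: a str argument cannot be assigned to parameter "step" of type "int"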
3 changes: 2 additions & 1 deletion nerfactory/cameras/rays.py
@@ -159,6 +159,7 @@ class RayBundle(TensorDataclass):
nears: Distance along ray to start sampling
fars: Rays Distance along ray to stop sampling
valid_mask: Rays that are valid
num_rays_per_chunk: Number of rays per chunk
"""

origins: TensorType["num_rays", 3]
@@ -168,7 +169,7 @@ class RayBundle(TensorDataclass):
nears: Optional[TensorType["num_rays", 1]] = None
fars: Optional[TensorType["num_rays", 1]] = None
valid_mask: Optional[TensorType["num_rays", 1, bool]] = None
num_rays_per_chunk: Optional[int] = None
num_rays_per_chunk: int = 128

def set_camera_indices(self, camera_index: int) -> None:
"""Sets all of the the camera indices to a specific camera index.
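With num_rays_per_chunk now defaulting to 128 instead of Optional[int] = None, code that slices a full-image RayBundle into chunks no longer needs a None check before using it. A rough sketch of that usage, built from names that appear in this diff (len(ray_bundle), get_row_major_sliced_ray_bundle, forward_after_ray_generator); the loop itself is illustrative and the real signatures may differ:

def render_in_chunks(graph, camera_ray_bundle):
    # Iterate the bundle in fixed-size chunks; num_rays_per_chunk is always an int now.
    outputs_per_chunk = []
    num_rays = len(camera_ray_bundle)
    for start_idx in range(0, num_rays, camera_ray_bundle.num_rays_per_chunk):
        end_idx = start_idx + camera_ray_bundle.num_rays_per_chunk
        chunk = camera_ray_bundle.get_row_major_sliced_ray_bundle(start_idx, end_idx)
        # With no batch given, forward_after_ray_generator returns a dict of outputs.
        outputs_per_chunk.append(graph.forward_after_ray_generator(chunk))
    return outputs_per_chunk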
4 changes: 2 additions & 2 deletions nerfactory/data/structs.py
@@ -41,8 +41,8 @@ class Semantics:
stuff_classes: List[str]
thing_filenames: List[str]
thing_classes: List[str]
stuff_colors: Optional[torch.Tensor] = None
thing_colors: Optional[torch.Tensor] = None
stuff_colors: torch.Tensor
thing_colors: torch.Tensor


@dataclass
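The Semantics change goes the other way from rays.py: stuff_colors and thing_colors switch from Optional tensors defaulting to None to required fields, so the checker and all callers can rely on them being present. A small illustrative contrast, not taken from the repository:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class WithOptionalColors:
    colors: Optional[torch.Tensor] = None  # every use site must handle None

@dataclass
class WithRequiredColors:
    colors: torch.Tensor  # constructor demands a tensor; no None checks downstream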
20 changes: 11 additions & 9 deletions nerfactory/engine/trainer.py
@@ -18,7 +18,7 @@
import functools
import logging
import os
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
@@ -69,7 +69,7 @@ def __init__(self, config: Config, local_rank: int = 0, world_size: int = 1):
self.dataloader_train: TrainDataloader
self.dataloader_eval: EvalDataloader
# model variables
self.graph = Graph
self.graph: Union[Graph, DDP]
self.optimizers: Optimizers
self.start_step = 0
# logging variables
@@ -90,6 +90,8 @@ def setup(self, test_mode=False):
self.dataset_inputs_train, self.dataloader_train = setup_dataset_train(self.config.data, device=self.device)
_, self.dataloader_eval = setup_dataset_eval(self.config.data, test_mode=test_mode, device=self.device)
self.graph = setup_graph(self.config.graph, self.dataset_inputs_train, device=self.device)
if not isinstance(self.graph, Graph):
raise ValueError("Graph was improperly initialized.")
self.optimizers = setup_optimizers(self.config.optimizers, self.graph.get_param_groups())

self._load_checkpoint()
@@ -179,9 +181,9 @@ def _save_checkpoint(self, output_dir: str, step: int) -> None:
os.makedirs(output_dir)
ckpt_path = os.path.join(output_dir, f"step-{step:09d}.ckpt")
if hasattr(self.graph, "module"):
model = self.graph.module.state_dict()
model = self.graph.module.state_dict() # type: ignore
else:
model = self.graph.state_dict()
model = self.graph.state_dict() # type: ignore
torch.save(
{
"step": step,
@@ -207,7 +209,7 @@ def train_iteration(self, ray_indices: TensorType["num_rays", 3], batch: dict, s
self.optimizers.zero_grad_all()
with torch.autocast(device_type=ray_indices.device.type, enabled=self.mixed_precision):
_, loss_dict, metrics_dict = self.graph.forward(ray_indices=ray_indices, batch=batch) # type: ignore
loss = sum(loss_dict.values())
loss = sum(loss_dict.values()) # type: ignore
self.grad_scaler.scale(loss).backward() # type: ignore
self.optimizers.optimizer_scaler_step_all(self.grad_scaler)
self.grad_scaler.update()
@@ -217,9 +219,9 @@ def train_iteration(self, ray_indices: TensorType["num_rays", 3], batch: dict, s
callback.after_step(step)

# Merging loss and metrics dict into a single output.
loss_dict["loss"] = loss
loss_dict.update(metrics_dict)
return loss_dict
loss_dict["loss"] = loss # type: ignore
loss_dict.update(metrics_dict) # type: ignore
return loss_dict # type: ignore

@profiler.time_function
def test_image(self, camera_ray_bundle: RayBundle, batch: dict, step: Optional[int] = None) -> float:
@@ -240,7 +242,7 @@ def test_image(self, camera_ray_bundle: RayBundle, batch: dict, step: Optional[i
outputs = self.graph.get_outputs_for_camera_ray_bundle(camera_ray_bundle) # type: ignore
psnr = self.graph.log_test_image_outputs(image_idx, step, batch, outputs) # type: ignore
self.graph.train() # type: ignore
return psnr
return psnr # type: ignore

def eval_with_dataloader(self, dataloader: EvalDataloader, step: Optional[int] = None) -> None:
"""Run evaluation with a given dataloader.
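Two trainer fixes are worth spelling out. self.graph = Graph previously assigned the class object itself as a placeholder; the replacement self.graph: Union[Graph, DDP] only declares the attribute's type, since setup_graph may hand back a DistributedDataParallel wrapper. The isinstance check added in setup() then narrows the union so later calls such as get_param_groups() typecheck. A self-contained sketch of the pattern, with a stand-in Graph class and an invented setup_graph (the repository's signatures differ):

from typing import Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Graph(nn.Module):  # stand-in for nerfactory.graphs.base.Graph
    def get_param_groups(self):
        return {"fields": list(self.parameters())}

def setup_graph(wrap_ddp: bool) -> Union[Graph, DDP]:
    graph = Graph()
    return DDP(graph) if wrap_ddp else graph

graph = setup_graph(wrap_ddp=False)
if not isinstance(graph, Graph):  # same narrowing the commit adds in setup()
    raise ValueError("Graph was improperly initialized.")
graph.get_param_groups()  # the checker now treats graph as a Graph, not a DDP wrapper

The hasattr(self.graph, "module") branch in _save_checkpoint handles the wrapped case at save time for the same reason: under DDP the underlying graph lives on the .module attribute.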
11 changes: 7 additions & 4 deletions nerfactory/fields/density_fields/density_grid.py
@@ -15,7 +15,7 @@
"""
Code to implement the density grid.
"""
from typing import Callable, List, NoReturn, Tuple
from typing import Callable, List, Tuple, Union

import torch
from torch import nn
@@ -25,7 +25,7 @@
from nerfactory.cameras.rays import RaySamples


def _create_grid_coords(resolution: int, device: torch.device = "cpu") -> TensorType["n_coords", 3]:
def _create_grid_coords(resolution: int, device: Union[torch.device, str] = "cpu") -> TensorType["n_coords", 3]:
"""Create 3D grid coordinates
Args:
@@ -77,6 +77,9 @@ class DensityGrid(nn.Module):
update_every_num_iters (int): How frequently to update the grid values. Defaults to 16.
"""

density_grid: TensorType["num_cascades", "resolution**3"]
mean_density: TensorType[1]

def __init__(
self,
center: float = 0.0,
@@ -155,7 +158,7 @@ def update_density_grid(
step: int,
density_threshold: float = 2, # 0.01 / (SQRT3 / 1024 * 3)
decay: float = 0.95,
) -> NoReturn:
) -> None:
"""Update the density grid in EMA way.
Args:
@@ -197,7 +200,7 @@ def update_density_grid(

# pack to bitfield
self.density_bitfield.data = nerfactory_cuda.packbits(
self.density_grid, min(self.mean_density.item(), density_threshold)
self.density_grid, min(self.mean_density.item(), density_threshold) # type: ignore
)

# TODO(ruilongli): max pooling? https://github.com/NVlabs/instant-ngp/blob/master/src/testbed_nerf.cu#L578
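Two typing details in DensityGrid: the new class-level annotations declare the types of density_grid and mean_density, attributes created inside __init__ that a static checker otherwise cannot see, and the return annotation NoReturn is corrected to None, since NoReturn is reserved for functions that never return normally rather than ones that return no value. A hedged sketch of the annotation pattern on a stripped-down module (constructor arguments and update logic are simplified, not the repository's):

import torch
from torch import nn

class TinyDensityGrid(nn.Module):
    # Declared at class level so static checkers know these attributes exist and their types.
    density_grid: torch.Tensor
    mean_density: torch.Tensor

    def __init__(self, num_cascades: int = 1, resolution: int = 128) -> None:
        super().__init__()
        self.register_buffer("density_grid", torch.zeros(num_cascades, resolution**3))
        self.register_buffer("mean_density", torch.zeros(1))

    def update_density_grid(self, decay: float = 0.95) -> None:  # returns normally, so None
        self.density_grid *= decay
        self.mean_density = self.density_grid.mean().reshape(1)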
7 changes: 0 additions & 7 deletions nerfactory/fields/instant_ngp_field.py
@@ -153,12 +153,5 @@ def __init__(
)
self.aabb = Parameter(aabb, requires_grad=False)

def get_density(self, ray_samples: RaySamples):
normalized_ray_samples = ray_samples
normalized_ray_samples.positions = SceneBounds.get_normalized_positions(
normalized_ray_samples.frustums.get_positions(), self.aabb
)
return super().get_density(normalized_ray_samples)


field_implementation_to_class = {"tcnn": TCNNInstantNGPField, "torch": TorchInstantNGPField}
4 changes: 2 additions & 2 deletions nerfactory/fields/nerf_field.py
@@ -84,7 +84,7 @@ def __init__(
self.field_output_density = DensityFieldHead(in_dim=self.mlp_base.get_out_dim())
self.field_heads = nn.ModuleList(field_heads)
for field_head in self.field_heads:
field_head.set_in_dim(self.mlp_head.get_out_dim())
field_head.set_in_dim(self.mlp_head.get_out_dim()) # type: ignore

def get_density(self, ray_samples: RaySamples):
if self.use_integrated_encoding:
@@ -107,6 +107,6 @@ def get_outputs(
outputs = {}
for field_head in self.field_heads:
encoded_dir = self.direction_encoding(ray_samples.frustums.directions)
mlp_out = self.mlp_head(torch.cat([encoded_dir, density_embedding], dim=-1))
mlp_out = self.mlp_head(torch.cat([encoded_dir, density_embedding], dim=-1)) # type: ignore
outputs[field_head.field_head_name] = field_head(mlp_out)
return outputs
8 changes: 6 additions & 2 deletions nerfactory/fields/nerfw_field.py
Original file line number Diff line number Diff line change
@@ -128,11 +128,15 @@ def get_outputs(
"""
outputs = {}
encoded_dir = self.direction_encoding(ray_samples.frustums.directions)
if ray_samples.camera_indices is None:
raise AttributeError("Camera indices are not provided.")
embedded_appearance = self.embedding_appearance(ray_samples.camera_indices.squeeze())
mlp_head_out = self.mlp_head(torch.cat([density_embedding, encoded_dir, embedded_appearance], dim=-1))
mlp_in = torch.cat([density_embedding, encoded_dir, embedded_appearance], dim=-1) # type: ignore
mlp_head_out = self.mlp_head(mlp_in)
outputs[self.field_head_rgb.field_head_name] = self.field_head_rgb(mlp_head_out) # static rgb
embedded_transient = self.embedding_transient(ray_samples.camera_indices.squeeze())
transient_mlp_out = self.mlp_transient(torch.cat([density_embedding, embedded_transient], dim=-1))
transient_mlp_in = torch.cat([density_embedding, embedded_transient], dim=-1) # type: ignore
transient_mlp_out = self.mlp_transient(transient_mlp_in)
outputs[self.field_head_transient_uncertainty.field_head_name] = self.field_head_transient_uncertainty(
transient_mlp_out
) # uncertainty
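The added camera_indices check does double duty: it fails fast with a clear error when indices are missing, and it narrows the Optional annotation to a concrete tensor type so the .squeeze() calls below it typecheck. A minimal standalone version of the guard, with an illustrative signature:

from typing import Optional

import torch

def appearance_indices(camera_indices: Optional[torch.Tensor]) -> torch.Tensor:
    if camera_indices is None:  # narrows Optional[Tensor] to Tensor for the checker
        raise AttributeError("Camera indices are not provided.")
    return camera_indices.squeeze()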
43 changes: 27 additions & 16 deletions nerfactory/graphs/base.py
@@ -17,7 +17,7 @@
"""
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
from omegaconf import DictConfig
@@ -53,7 +53,9 @@ def device(self):
return self.device_indicator_param.device

@abstractmethod
def forward(self, ray_indices: TensorType["num_rays", 3], batch: Union[str, Dict[str, torch.tensor]] = None):
def forward(
self, ray_indices: TensorType["num_rays", 3], batch: Union[str, Optional[Dict[str, torch.Tensor]]] = None
):
"""Process starting with ray indices. Turns them into rays, then performs volume rendering."""


@@ -75,14 +77,14 @@ class Graph(AbstractGraph):

def __init__(
self,
intrinsics: torch.Tensor = None,
camera_to_world: torch.Tensor = None,
loss_coefficients: DictConfig = None,
scene_bounds: SceneBounds = None,
intrinsics: torch.Tensor,
camera_to_world: torch.Tensor,
loss_coefficients: DictConfig,
scene_bounds: Optional[SceneBounds] = None,
enable_collider: bool = True,
collider_config: DictConfig = None,
collider_config: Optional[DictConfig] = None,
enable_density_field: bool = False,
density_field_config: DictConfig = None,
density_field_config: Optional[DictConfig] = None,
**kwargs,
) -> None:
super().__init__()
@@ -122,6 +124,10 @@ def populate_collider(self):
if self.enable_collider:
self.collider = instantiate_from_dict_config(self.collider_config, scene_bounds=self.scene_bounds)

@abstractmethod
def populate_fields(self):
"""Set the fields."""

@abstractmethod
def populate_misc_modules(self):
"""Initializes any additional modules that are part of the network."""
@@ -135,7 +141,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
"""

@abstractmethod
def get_outputs(self, ray_bundle: RayBundle) -> dict:
def get_outputs(self, ray_bundle: RayBundle) -> Dict[str, torch.Tensor]:
"""Takes in a Ray Bundle and returns a dictionary of outputs.
Args:
@@ -153,7 +159,9 @@ def process_outputs_as_images(self, outputs): # pylint:disable=no-self-use
v = torch.tile(v, (1, 1, 3))
outputs[k] = v

def forward_after_ray_generator(self, ray_bundle: RayBundle, batch: Union[str, Dict[str, torch.tensor]] = None):
def forward_after_ray_generator(
self, ray_bundle: RayBundle, batch: Optional[Union[str, Dict[str, torch.Tensor]]] = None
):
"""Run forward starting with a ray bundle."""
if self.collider is not None:
intersected_ray_bundle = self.collider(ray_bundle)
@@ -185,19 +193,21 @@ def forward_after_ray_generator(self, ray_bundle: RayBundle, batch: Union[str, D
loss_dict[loss_name] *= self.loss_coefficients[loss_name]
return outputs, loss_dict, metrics_dict

def forward(self, ray_indices: TensorType["num_rays", 3], batch: Union[str, Dict[str, torch.tensor]] = None):
def forward(
self, ray_indices: TensorType["num_rays", 3], batch: Optional[Union[str, Dict[str, torch.Tensor]]] = None
):
"""Run the forward starting with ray indices."""
ray_bundle = self.ray_generator.forward(ray_indices)
return self.forward_after_ray_generator(ray_bundle, batch=batch)

def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.tensor]:
def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
"""Compute and returns metrics."""
# pylint: disable=unused-argument
# pylint: disable=no-self-use
return {}

@abstractmethod
def get_loss_dict(self, outputs, batch, metrics_dict, loss_coefficients) -> Dict[str, torch.tensor]:
def get_loss_dict(self, outputs, batch, metrics_dict, loss_coefficients) -> Dict[str, torch.Tensor]:
"""Computes and returns the losses dict."""

@torch.no_grad()
@@ -213,10 +223,10 @@ def get_outputs_for_camera_ray_bundle(self, camera_ray_bundle: RayBundle):
end_idx = i + camera_ray_bundle.num_rays_per_chunk
ray_bundle = camera_ray_bundle.get_row_major_sliced_ray_bundle(start_idx, end_idx)
outputs = self.forward_after_ray_generator(ray_bundle)
for output_name, output in outputs.items():
for output_name, output in outputs.items(): # type: ignore
outputs_lists[output_name].append(output)
for output_name, outputs_list in outputs_lists.items():
outputs[output_name] = torch.cat(outputs_list).view(image_height, image_width, -1)
outputs[output_name] = torch.cat(outputs_list).view(image_height, image_width, -1) # type: ignore
return outputs

def get_outputs_for_camera(self, camera: Camera):
@@ -230,7 +240,8 @@ def log_test_image_outputs(self, image_idx, step, batch, outputs):

def load_graph(self, loaded_state: Dict[str, Any]) -> None:
"""Load the checkpoint from the given path"""
self.load_state_dict({key.replace("module.", ""): value for key, value in loaded_state["model"].items()})
state = {key.replace("module.", ""): value for key, value in loaded_state["model"].items()}
self.load_state_dict(state) # type: ignore


@profiler.time_function
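The load_graph change splits the dict comprehension onto its own line before calling load_state_dict; the comprehension strips the "module." prefix that DistributedDataParallel adds to parameter names, so a checkpoint saved from a wrapped graph loads into a plain one. A tiny runnable illustration of that key remapping (the state dict contents here are made up):

import torch

loaded_state = {"model": {"module.encoder.weight": torch.zeros(3, 3)}}
state = {key.replace("module.", ""): value for key, value in loaded_state["model"].items()}
assert list(state) == ["encoder.weight"]
# graph.load_state_dict(state)  # as in load_graph above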
10 changes: 8 additions & 2 deletions nerfactory/graphs/instant_ngp.py
@@ -55,13 +55,15 @@ def get_training_callbacks(self) -> List[Callback]:
Callback(
update_every_num_iters=self.density_field.update_every_num_iters,
func=self.density_field.update_density_grid,
density_eval_func=self.field.density_fn,
density_eval_func=self.field.density_fn, # type: ignore
)
]

def populate_fields(self):
"""Set the fields."""
# torch or tiny-cuda-nn version
if self.scene_bounds is None:
raise ValueError("scene_bounds must be set to use an NGPGraph")
self.field = field_implementation_to_class[self.field_implementation](self.scene_bounds.aabb)

def populate_misc_modules(self):
Expand All @@ -78,6 +80,8 @@ def populate_misc_modules(self):

def get_param_groups(self) -> Dict[str, List[Parameter]]:
param_groups = {}
if self.field is None:
raise ValueError("populate_fields() must be called before get_param_groups")
param_groups["fields"] = list(self.field.parameters())
return param_groups

@@ -89,6 +93,8 @@ def get_outputs(self, ray_bundle: RayBundle):
num_rays = len(ray_bundle)
device = ray_bundle.origins.device

if self.field is None:
raise ValueError("populate_fields() must be called before get_outputs")
ray_samples, packed_info, t_min, t_max = self.sampler(ray_bundle, self.field.aabb)

field_outputs = self.field.forward(ray_samples)
@@ -170,7 +176,7 @@ def log_test_image_outputs(self, image_idx, step, batch, outputs):
lpips = self.lpips(image, rgb)

writer.put_scalar(name=f"psnr/val_{image_idx}-fine", scalar=float(psnr), step=step)
writer.put_scalar(name=f"ssim/val_{image_idx}", scalar=float(ssim), step=step)
writer.put_scalar(name=f"ssim/val_{image_idx}", scalar=float(ssim), step=step) # type: ignore
writer.put_scalar(name=f"lpips/val_{image_idx}", scalar=float(lpips), step=step)

writer.put_scalar(name=writer.EventName.CURR_TEST_PSNR, scalar=float(psnr), step=step)