Skip to content

Commit

Permalink
Improving helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
vishakhhegde committed Sep 27, 2022
1 parent f1aee5a commit d5975c0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 14 deletions.
41 changes: 31 additions & 10 deletions chap6/train_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@
EmissionAbsorptionRaymarcher,
ImplicitRenderer,
)

if torch.cuda.is_available():
device = torch.device("cuda:0")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")

from utils.plot_image_grid import image_grid
from utils.generate_cow_renders import generate_cow_renders

from utils.helper_functions import (generate_rotating_nerf,
huber,
show_full_render,
sample_images_at_mc_locs)
from nerf_model import NeuralRadianceField

if torch.cuda.is_available():
device = torch.device("cuda:0")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")

target_cameras, target_images, target_silhouettes = generate_cow_renders(num_views=40, azimuth_range=180)
print(f'Generated {len(target_images)} images/silhouettes/cameras.')

Expand Down Expand Up @@ -69,8 +68,8 @@

lr = 1e-3
optimizer = torch.optim.Adam(neural_radiance_field.parameters(), lr=lr)
batch_size = 6
n_iter = 3000
batch_size = 3
n_iter = 500

loss_history_color, loss_history_sil = [], []
for iteration in range(n_iter):
Expand Down Expand Up @@ -127,6 +126,28 @@
loss.backward()
optimizer.step()

# Visualize the full renders every 100 iterations.
if iteration % 100 == 0:
show_idx = torch.randperm(len(target_cameras))[:1]
fig = show_full_render(
neural_radiance_field,
FoVPerspectiveCameras(
R = target_cameras.R[show_idx],
T = target_cameras.T[show_idx],
znear = target_cameras.znear[show_idx],
zfar = target_cameras.zfar[show_idx],
aspect_ratio = target_cameras.aspect_ratio[show_idx],
fov = target_cameras.fov[show_idx],
device = device,
),
target_images[show_idx][0],
target_silhouettes[show_idx][0],
renderer_grid,
loss_history_color,
loss_history_sil,
)
fig.savefig(f'intermediate_{iteration}')

with torch.no_grad():
rotating_nerf_frames = generate_rotating_nerf(neural_radiance_field, target_cameras, renderer_grid, n_frames=3*5, device=device)

Expand Down
8 changes: 4 additions & 4 deletions chap6/utils/helper_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from pytorch3d.transforms import so3_exp_map
import matplotlib.pyplot as plt
from tqdm import tqdm
from pytorch3d.renderer import (
FoVPerspectiveCameras,
Expand Down Expand Up @@ -52,7 +53,8 @@ def sample_images_at_mc_locs(target_images, sampled_rays_xy):
def show_full_render(
neural_radiance_field, camera,
target_image, target_silhouette,
loss_history_color, loss_history_sil,
renderer_grid, loss_history_color,
loss_history_sil,
):
"""
This is a helper function for visualizing the
Expand Down Expand Up @@ -103,9 +105,7 @@ def show_full_render(
ax_.grid("off")
ax_.axis("off")
ax_.set_title(title_)
fig.canvas.draw(); fig.show()
display.clear_output(wait=True)
display.display(fig)
fig.canvas.draw()
return fig

def generate_rotating_nerf(neural_radiance_field, target_cameras, renderer_grid, n_frames = 50, device=torch.device("cpu")):
Expand Down

0 comments on commit d5975c0

Please sign in to comment.