
Commit

update
lqz2 committed Nov 20, 2023
1 parent 35c2f51 commit 2c45307
Showing 3 changed files with 204 additions and 135 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
# Initially taken from Github's Python gitignore file

ckpts
sam_pt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
175 changes: 91 additions & 84 deletions NeuS/exp_runner.py
@@ -18,10 +18,11 @@
import pdb
import math


def ranking_loss(error, penalize_ratio=0.7, type='mean'):
error, indices = torch.sort(error)
# only sum relatively small errors
s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])])
s_error = torch.index_select(error, 0, index=indices[: int(penalize_ratio * indices.shape[0])])
if type == 'mean':
return torch.mean(s_error)
elif type == 'sum':
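The intent of ranking_loss is a trimmed loss: sort the per-element errors and reduce only the smallest penalize_ratio fraction, so outlier residuals are ignored. A minimal sketch of that idea (the helper name trimmed_loss is hypothetical, and the 'sum' branch is assumed to return torch.sum(s_error)); note that a plain slice of the sorted tensor expresses the intent more directly than indexing the already-sorted tensor with the sort indices, as the function above does:

import torch

def trimmed_loss(error: torch.Tensor, keep_ratio: float = 0.7, reduction: str = 'mean') -> torch.Tensor:
    # Sort the per-element errors and keep only the smallest keep_ratio fraction,
    # so a few large outlier residuals cannot dominate the loss.
    sorted_error, _ = torch.sort(error)
    k = int(keep_ratio * sorted_error.shape[0])
    kept = sorted_error[:k]
    return kept.mean() if reduction == 'mean' else kept.sum()

e = torch.tensor([0.1, 0.2, 0.3, 10.0])
print(trimmed_loss(e))  # tensor(0.1500): the 10.0 outlier is dropped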
@@ -46,7 +47,10 @@ def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False,
os.makedirs(self.base_exp_dir, exist_ok=True)
self.dataset = Dataset(self.conf['dataset'])
self.dataloader = torch.utils.data.DataLoader(
self.dataset, batch_size=self.conf['train']['batch_size'], shuffle=True, num_workers=64,
self.dataset,
batch_size=self.conf['train']['batch_size'],
shuffle=True,
num_workers=64,
)
self.iter_step = 1

@@ -86,14 +90,13 @@ def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False,
params_to_train_slow += list(self.deviation_network.parameters())
# params_to_train += list(self.color_network.parameters())

self.optimizer = torch.optim.Adam([{'params': params_to_train_slow},
{'params': self.color_network.parameters(), 'lr': self.learning_rate * 2}], lr=self.learning_rate)
self.optimizer = torch.optim.Adam(
[{'params': params_to_train_slow}, {'params': self.color_network.parameters(), 'lr': self.learning_rate * 2}], lr=self.learning_rate
)

self.renderer = NeuSRenderer(self.nerf_outside,
self.sdf_network,
self.deviation_network,
self.color_network,
**self.conf['model.neus_renderer'])
self.renderer = NeuSRenderer(
self.nerf_outside, self.sdf_network, self.deviation_network, self.color_network, **self.conf['model.neus_renderer']
)

# Load checkpoint
latest_model_name = None
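The Adam setup in this hunk uses the standard PyTorch per-parameter-group pattern: the SDF, background NeRF and deviation parameters train at the base learning rate, while the color network's group overrides it with twice the base rate. A minimal standalone sketch of the same pattern (the two Linear modules and the base_lr value are placeholders, not the networks used here):

import torch

base_lr = 5e-4                    # placeholder; the runner reads its rate from the config
slow_net = torch.nn.Linear(3, 3)  # stands in for the SDF / deviation networks
fast_net = torch.nn.Linear(3, 3)  # stands in for the color network

optimizer = torch.optim.Adam(
    [
        {'params': slow_net.parameters()},                     # falls back to the lr below
        {'params': fast_net.parameters(), 'lr': base_lr * 2},  # per-group override
    ],
    lr=base_lr,
)
print([g['lr'] for g in optimizer.param_groups])  # [0.0005, 0.001]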
@@ -132,7 +135,14 @@ def train(self):
# data = self.dataset.gen_random_rays_at(img_idx, self.batch_size)
data = data.cuda()

rays_o, rays_d, true_rgb, mask, true_normal, cosines = data[:, :3], data[:, 3: 6], data[:, 6: 9], data[:, 9: 10], data[:, 10:13], data[:, 13:]
rays_o, rays_d, true_rgb, mask, true_normal, cosines = (
data[:, :3],
data[:, 3:6],
data[:, 6:9],
data[:, 9:10],
data[:, 10:13],
data[:, 13:],
)
# near, far = self.dataset.near_far_from_sphere(rays_o, rays_d)
near, far = self.dataset.get_near_far()
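The unpacking in this hunk implies each row of data packs one ray as concatenated columns: origin (0:3), direction (3:6), ground-truth RGB (6:9), mask (9:10), supervised normal (10:13), and the view-direction/normal cosine in the remaining column(s). A small sketch of assembling a batch in that layout, assuming a single cosine column (14 floats per ray in total):

import torch

n_rays = 512
rays_o   = torch.zeros(n_rays, 3)  # ray origins
rays_d   = torch.zeros(n_rays, 3)  # ray directions
true_rgb = torch.zeros(n_rays, 3)  # supervised colors
mask     = torch.ones(n_rays, 1)   # foreground mask
normal   = torch.zeros(n_rays, 3)  # supervised normals
cosines  = torch.zeros(n_rays, 1)  # dot(ray direction, normal)

data = torch.cat([rays_o, rays_d, true_rgb, mask, normal, cosines], dim=1)
assert data.shape == (n_rays, 14)  # matches the slicing in the training loop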

@@ -149,9 +159,9 @@ def train(self):
mask = ((mask > 0) & (cosines < -0.1)).to(torch.float32)

mask_sum = mask.sum() + 1e-5
render_out = self.renderer.render(rays_o, rays_d, near, far,
background_rgb=background_rgb,
cos_anneal_ratio=self.get_cos_anneal_ratio())
render_out = self.renderer.render(
rays_o, rays_d, near, far, background_rgb=background_rgb, cos_anneal_ratio=self.get_cos_anneal_ratio()
)

color_fine = render_out['color_fine']
s_val = render_out['s_val']
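The mask tightening at the top of this hunk keeps a ray only when it is foreground and its supervised normal clearly faces the camera: (mask > 0) & (cosines < -0.1) drops grazing and back-facing normals from supervision. A small check of the same filter, assuming cosines stores dot(ray direction, normal):

import torch

mask    = torch.tensor([[1.0], [1.0], [0.0]])
cosines = torch.tensor([[-0.9], [-0.05], [-0.8]])  # front-facing, grazing, front-facing

keep = ((mask > 0) & (cosines < -0.1)).to(torch.float32)
print(keep[:, 0])  # tensor([1., 0., 0.]): only the confidently front-facing foreground ray survives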
@@ -167,41 +177,45 @@ def train(self):
color_errors = (color_fine - true_rgb).abs().sum(dim=1)
color_fine_loss = ranking_loss(color_errors[mask[:, 0] > 0])

psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt())
psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb) ** 2 * mask).sum() / (mask_sum * 3.0)).sqrt())

eikonal_loss = gradient_error

# pdb.set_trace()
mask_errors = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask, reduction='none')
mask_loss = ranking_loss(mask_errors[:, 0], penalize_ratio=0.8)

def feasible(key): return (key in render_out) and (render_out[key] is not None)
def feasible(key):
return (key in render_out) and (render_out[key] is not None)

# calculate normal loss
n_samples = self.renderer.n_samples + self.renderer.n_importance
normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]
if feasible('inside_sphere'):
normals = normals * render_out['inside_sphere'][..., None]
normals = normals.sum(dim=1)

# pdb.set_trace()
normal_errors = 1 - F.cosine_similarity(normals, true_normal, dim=1)
# normal_error = normal_error * mask[:, 0]
# normal_loss = F.l1_loss(normal_error, torch.zeros_like(normal_error), reduction='sum') / mask_sum
normal_errors = normal_errors * torch.exp(cosines.abs()[:, 0]) / torch.exp(cosines.abs()).sum()
normal_loss = ranking_loss(normal_errors[mask[:, 0]> 0], penalize_ratio=0.9, type='sum')
normal_loss = ranking_loss(normal_errors[mask[:, 0] > 0], penalize_ratio=0.9, type='sum')

sparse_loss = render_out['sparse_loss']

loss = color_fine_loss * self.color_weight +\
eikonal_loss * self.igr_weight + sparse_loss * self.sparse_weight +\
mask_loss * self.mask_weight + normal_loss * self.normal_weight
loss = (
color_fine_loss * self.color_weight
+ eikonal_loss * self.igr_weight
+ sparse_loss * self.sparse_weight
+ mask_loss * self.mask_weight
+ normal_loss * self.normal_weight
)

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()


self.writer.add_scalar('Loss/loss', loss, self.iter_step)
self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step)
self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step)
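The PSNR line in this hunk computes 20 * log10(1 / sqrt(MSE)), with the mean squared error taken over masked pixels and the three color channels (hence the mask_sum * 3.0 denominator). A minimal numeric check of the same formula with made-up values:

import torch

color_fine = torch.tensor([[0.5, 0.5, 0.5]])
true_rgb   = torch.tensor([[0.4, 0.6, 0.5]])
mask       = torch.tensor([[1.0]])
mask_sum   = mask.sum() + 1e-5

mse  = (((color_fine - true_rgb) ** 2) * mask).sum() / (mask_sum * 3.0)
psnr = 20.0 * torch.log10(1.0 / mse.sqrt())
print(psnr)  # roughly 21.8 dB for an RMS error of about 0.082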
@@ -214,11 +228,17 @@ def feasible(key): return (key in render_out) and (render_out[key] is not None)
print(self.base_exp_dir)
print(
'iter:{:8>d} loss = {:4>f} color_ls = {:4>f} eik_ls = {:4>f} normal_ls = {:4>f} mask_ls = {:4>f} sparse_ls = {:4>f} lr={:5>f}'.format(
self.iter_step, loss, color_fine_loss, eikonal_loss, normal_loss,
mask_loss, sparse_loss, self.optimizer.param_groups[0]['lr']))
print('iter:{:8>d} s_val = {:4>f}'.format( self.iter_step, s_val.mean()))


self.iter_step,
loss,
color_fine_loss,
eikonal_loss,
normal_loss,
mask_loss,
sparse_loss,
self.optimizer.param_groups[0]['lr'],
)
)
print('iter:{:8>d} s_val = {:4>f}'.format(self.iter_step, s_val.mean()))

if self.iter_step % self.val_mesh_freq == 0:
self.validate_mesh(resolution=256)
@@ -265,7 +285,7 @@ def file_backup(self):
for dir_name in dir_lis:
cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
os.makedirs(cur_dir, exist_ok=True)
files = os.listdir(dir_name)
files = os.listdir(dir_name)
for f_name in files:
if f_name[-3:] == '.py':
copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
@@ -318,14 +338,12 @@ def validate_image(self, idx=-1, resolution_level=-1):
near, far = self.dataset.get_near_far()
background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None

render_out = self.renderer.render(rays_o_batch,
rays_d_batch,
near,
far,
cos_anneal_ratio=self.get_cos_anneal_ratio(),
background_rgb=background_rgb)
render_out = self.renderer.render(
rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb
)

def feasible(key): return (key in render_out) and (render_out[key] is not None)
def feasible(key):
return (key in render_out) and (render_out[key] is not None)

if feasible('color_fine'):
out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
@@ -336,16 +354,16 @@ def feasible(key): return (key in render_out) and (render_out[key] is not None)
normals = normals * render_out['inside_sphere'][..., None]
normals = normals.sum(dim=1).detach().cpu().numpy()
out_normal_fine.append(normals)

if feasible('weight_sum'):
out_mask.append(render_out['weight_sum'].detach().clip(0, 1).cpu().numpy())

del render_out

img_fine = None
if len(out_rgb_fine) > 0:
img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)

mask_map = None
if len(out_mask) > 0:
mask_map = (np.concatenate(out_mask, axis=0).reshape([H, W, -1]) * 256).clip(0, 255)
@@ -354,32 +372,31 @@ def feasible(key): return (key in render_out) and (render_out[key] is not None)
if len(out_normal_fine) > 0:
normal_img = np.concatenate(out_normal_fine, axis=0)
rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy())
normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None])
.reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)
normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]).reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)

os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True)
os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True)

for i in range(img_fine.shape[-1]):
if len(out_rgb_fine) > 0:
cv.imwrite(os.path.join(self.base_exp_dir,
'validations_fine',
'{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
np.concatenate([img_fine[..., i],
self.dataset.image_at(idx, resolution_level=resolution_level),
self.dataset.mask_at(idx, resolution_level=resolution_level)]))
cv.imwrite(
os.path.join(self.base_exp_dir, 'validations_fine', '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
np.concatenate(
[
img_fine[..., i],
self.dataset.image_at(idx, resolution_level=resolution_level),
self.dataset.mask_at(idx, resolution_level=resolution_level),
]
),
)
if len(out_normal_fine) > 0:
cv.imwrite(os.path.join(self.base_exp_dir,
'normals',
'{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
np.concatenate([normal_img[..., i],
self.dataset.normal_cam_at(idx, resolution_level=resolution_level)])[:, :, ::-1])
cv.imwrite(
os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
np.concatenate([normal_img[..., i], self.dataset.normal_cam_at(idx, resolution_level=resolution_level)])[:, :, ::-1],
)
if len(out_mask) > 0:
cv.imwrite(os.path.join(self.base_exp_dir,
'normals',
'{:0>8d}_{}_{}_mask.png'.format(self.iter_step, i, idx)),
mask_map[...,i])

cv.imwrite(os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_{}_{}_mask.png'.format(self.iter_step, i, idx)), mask_map[..., i])

def save_maps(self, idx, img_idx, resolution_level=1):
view_types = ['front', 'back', 'left', 'right']
print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))
@@ -398,14 +415,12 @@ def save_maps(self, idx, img_idx, resolution_level=1):
near, far = self.dataset.get_near_far()
background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None

render_out = self.renderer.render(rays_o_batch,
rays_d_batch,
near,
far,
cos_anneal_ratio=self.get_cos_anneal_ratio(),
background_rgb=background_rgb)
render_out = self.renderer.render(
rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb
)

def feasible(key): return (key in render_out) and (render_out[key] is not None)
def feasible(key):
return (key in render_out) and (render_out[key] is not None)

if feasible('color_fine'):
out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
@@ -416,16 +431,16 @@ def feasible(key): return (key in render_out) and (render_out[key] is not None)
normals = normals * render_out['inside_sphere'][..., None]
normals = normals.sum(dim=1).detach().cpu().numpy()
out_normal_fine.append(normals)

if feasible('weight_sum'):
out_mask.append(render_out['weight_sum'].detach().clip(0, 1).cpu().numpy())

del render_out

img_fine = None
if len(out_rgb_fine) > 0:
img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

mask_map = None
if len(out_mask) > 0:
mask_map = (np.concatenate(out_mask, axis=0).reshape([H, W, 1]) * 256).clip(0, 255)
@@ -439,11 +454,10 @@ def feasible(key): return (key in render_out) and (render_out[key] is not None)
os.makedirs(os.path.join(self.base_exp_dir, 'coarse_maps'), exist_ok=True)
img_rgba = np.concatenate([img_fine[:, :, ::-1], mask_map], axis=-1)
normal_rgba = np.concatenate([world_normal_img[:, :, ::-1], mask_map], axis=-1)

cv.imwrite(os.path.join(self.base_exp_dir, 'coarse_maps', "normals_mlp_%03d_%s.png" % (img_idx, view_types[idx])), img_rgba)
cv.imwrite(os.path.join(self.base_exp_dir, 'coarse_maps', "normals_grad_%03d_%s.png" % (img_idx, view_types[idx])), normal_rgba)


def render_novel_image(self, idx_0, idx_1, ratio, resolution_level):
"""
Interpolate view between two cameras.
@@ -459,12 +473,9 @@ def render_novel_image(self, idx_0, idx_1, ratio, resolution_level):
near, far = self.dataset.get_near_far()
background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None

render_out = self.renderer.render(rays_o_batch,
rays_d_batch,
near,
far,
cos_anneal_ratio=self.get_cos_anneal_ratio(),
background_rgb=background_rgb)
render_out = self.renderer.render(
rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb
)

out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())

@@ -477,15 +488,16 @@ def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)

vertices, triangles, vertex_colors =\
self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
vertices, triangles, vertex_colors = self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True)

if world_space:
vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]

mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}.ply'.format(self.iter_step)))
# mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}.ply'.format(self.iter_step)))
# export as glb
mesh.export(os.path.join(self.base_exp_dir, 'meshes', 'tmp.glb'))

logging.info('End')
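validate_mesh extracts a colored mesh from the SDF inside the object bounding box and, after this change, writes tmp.glb instead of a numbered .ply file. Trimesh infers the output format from the file extension, so the two export calls differ only in the filename. A minimal sketch with a dummy triangle (the world-space scale_mats transform is omitted):

import numpy as np
import trimesh

vertices  = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
triangles = np.array([[0, 1, 2]])
colors    = np.array([[255, 0, 0, 255], [0, 255, 0, 255], [0, 0, 255, 255]])

mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=colors)
mesh.export('mesh.ply')  # PLY output, as the commented-out code path did
mesh.export('tmp.glb')   # glTF binary, as the updated code path does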

@@ -494,20 +506,15 @@ def interpolate_view(self, img_idx_0, img_idx_1):
n_frames = 60
for i in range(n_frames):
print(i)
images.append(self.render_novel_image(img_idx_0,
img_idx_1,
np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5,
resolution_level=4))
images.append(self.render_novel_image(img_idx_0, img_idx_1, np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, resolution_level=4))
for i in range(n_frames):
images.append(images[n_frames - i - 1])

fourcc = cv.VideoWriter_fourcc(*'mp4v')
video_dir = os.path.join(self.base_exp_dir, 'render')
os.makedirs(video_dir, exist_ok=True)
h, w, _ = images[0].shape
writer = cv.VideoWriter(os.path.join(video_dir,
'{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)),
fourcc, 30, (w, h))
writer = cv.VideoWriter(os.path.join(video_dir, '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)), fourcc, 30, (w, h))

for image in images:
writer.write(image)
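The interpolation ratio np.sin(((i / n_frames) - 0.5) * pi) * 0.5 + 0.5 remaps the uniform frame index to an ease-in/ease-out curve between the two cameras: it still sweeps from 0 to 1, but moves slowly near the endpoints and fastest in the middle, and the second loop appends the frames in reverse so the clip plays forward and back. A quick check of the endpoints and midpoint:

import numpy as np

n_frames = 60
t = np.arange(n_frames) / n_frames
ratio = np.sin((t - 0.5) * np.pi) * 0.5 + 0.5
print(ratio[0], ratio[n_frames // 2], ratio[-1])  # 0.0, 0.5, then about 0.999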
