Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into main
Browse files Browse the repository at this point in the history
# Conflicts:
#	dataset/multi_pat_dataset.py
#	exps/exp_xyz2density.py
#	params/xyz2density_train.ini
#	run.sh
  • Loading branch information
CodePointer committed Oct 17, 2022
2 parents 1abf1c8 + a0ad012 commit 7602b9a
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 88 deletions.
30 changes: 30 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: main.py",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"console": "integratedTerminal",
"args": [
"--config",
"${workspaceFolder}/params/xyz2density_train_lnx.ini",
],
"env": {
"CUDA_VISIBLE_DEVICES": "0",
},
},
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.pythonPath": "/home/qiao/anaconda3/bin/python"
}
7 changes: 6 additions & 1 deletion args.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ def get_args():
help='How many iters for res writing. 0 for no result saving.',
default=0,
type=int)

parser.add_argument('--alpha_stone',
help='alpha value & shrink epoch',
default='0,1.0',
type=str)
parser.add_argument('--patch_rad',
help='Sampled patch radiace for training. Patch side length = 2 * patch_rad + 1',
default=0,
type=int)
parser.add_argument('--pat_set',
help='Pattern number set for training.',
default='',
Expand Down Expand Up @@ -140,7 +145,7 @@ def post_process(args):
# 4. For distributed when you have multiple GPU cards. Only works for linux.
#
torch.cuda.set_device(args.local_rank)
if os.name == 'nt':
if True: # TODO: os.name == 'nt':
args.data_parallel = False
else:
init_process_group('nccl', init_method='env://')
Expand Down
76 changes: 48 additions & 28 deletions dataset/multi_pat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# - Coding Part - #
class MultiPatDataset(torch.utils.data.Dataset):
"""Load image & pattern for one depth map"""
def __init__(self, scene_folder, pat_idx_set, sample_num=None, calib_para=None, device=None):
def __init__(self, scene_folder, pat_idx_set, sample_num, calib_para, device, rad=0):
self.scene_folder = scene_folder
self.pat_folder = scene_folder / 'pat'
self.img_folder = scene_folder / 'img'
Expand All @@ -36,30 +36,36 @@ def __init__(self, scene_folder, pat_idx_set, sample_num=None, calib_para=None,

self.device = device

# Get coord
self.rays_o = torch.zeros(size=[self.sample_num, 3], device=device)

# Get coord candidate
self.mask_occ = torch.ones([1, *self.img_size], dtype=torch.bool)
if (scene_folder / 'mask' / 'mask_occ.png').exists():
self.mask_occ = plb.imload(scene_folder / 'mask' / 'mask_occ.png').to(torch.bool)
# self.mask_occ = torch.ones([1, *self.img_size], dtype=torch.bool)
# if (scene_folder / 'mask' / 'mask_occ.png').exists():
# self.mask_occ = plb.imload(scene_folder / 'mask' / 'mask_occ.png').to(torch.bool)
# pixel_coord = np.stack(np.meshgrid(np.arange(self.img_size[1]), np.arange(self.img_size[0])), axis=0)
# pixel_coord = plb.a2t(pixel_coord, permute=False).reshape(2, -1)
# self.valid_coord = pixel_coord[:, self.mask_occ.reshape(-1)].to(torch.long) # [2, N]

pch_len = 2 * rad + 1
pixel_coord = np.stack(np.meshgrid(np.arange(self.img_size[1]), np.arange(self.img_size[0])), axis=0)
pixel_coord = plb.a2t(pixel_coord, permute=False).reshape(2, -1)
self.valid_coord = pixel_coord[:, self.mask_occ.reshape(-1)].to(torch.long) # [2, N]
pixel_coord = plb.a2t(pixel_coord, permute=False).unsqueeze(0) # [1, 2, H, W]
pixel_unfold = torch.nn.functional.unfold(pixel_coord.float(), (pch_len, pch_len), padding=rad) # [1, 2 * pch_len**2, H * W]
self.mask_occ = torch.ones([1, *self.img_size], dtype=torch.float)
if (scene_folder / 'mask' / 'mask_occ.png').exists():
self.mask_occ = plb.imload(scene_folder / 'mask' / 'mask_occ.png')
mask_unfold = torch.nn.functional.unfold(self.mask_occ.unsqueeze(0), (pch_len, pch_len), padding=rad) # [1, 2 * pch_len**2, H * W]
mask_bool = self.mask_occ.to(torch.bool).reshape(-1) # [H * W]
self.valid_patch = pixel_unfold[0, :, mask_bool].to(torch.long).reshape(2, pch_len, pch_len, -1).long() # [2, pch_len, pch_len, L]
self.valid_mask = mask_unfold[0, :, mask_bool].reshape(1, pch_len, pch_len, -1) # [1, pch_len, pch_len, L]

# img_hei, img_wid = self.img_set.shape[-2:]
# pixel_coords = np.stack(np.mgrid[:img_hei, :img_wid], axis=-1)[None, ...].astype(np.float32)
# pixel_coords[0, :, :, 0] = pixel_coords[0, :, :, 0] / img_hei
# pixel_coords[0, :, :, 1] = pixel_coords[0, :, :, 1] / img_wid
# pixel_coords -= 0.5
# self.coords = torch.from_numpy(pixel_coords).view(-1, 2)
# self.coords = self.coords.to(device)
# Get coord
self.rays_o = torch.zeros(size=[pch_len ** 2 * self.sample_num, 3], device=device)

self.img_set = self.img_set.to(device)
self.pat_set = self.pat_set.to(device)
# self.depth = self.depth.to(device)
self.mask_occ = self.mask_occ.to(device)
self.valid_coord = self.valid_coord.to(device)
# self.valid_coord = self.valid_coord.to(device)
self.valid_patch = self.valid_patch.to(device)
self.valid_mask = self.valid_mask.to(device)

def __len__(self):
return 1
Expand Down Expand Up @@ -126,20 +132,34 @@ def __getitem__(self, idx):
# Generate random rays from one camera
# img_hei, img_wid = self.img_set.shape[-2:]
rand_args = dict(size=[self.sample_num], device=self.device)
selected_idx = torch.randint(low=0, high=self.valid_coord.shape[1], **rand_args)
selected_coord = self.valid_coord[:, selected_idx]
pixels_x = selected_coord[0]
pixels_y = selected_coord[1]

color = self.img_set[:, pixels_y, pixels_x] # [N, C]
color = color.permute(1, 0)
rays_v = self.pixel2ray(pixels_x, pixels_y)
# # 上一个非patch的
# selected_idx = torch.randint(low=0, high=self.valid_coord.shape[1], **rand_args)
# selected_coord = self.valid_coord[:, selected_idx]

selected_idx = torch.randint(low=0, high=self.valid_patch.shape[-1], **rand_args)
selected_coord = self.valid_patch[:, :, :, selected_idx] # [2, pch_len, pch_len, N]
selected_mask = self.valid_mask[:, :, :, selected_idx] # [1, pch_len, pch_len, N]

pixels_x = selected_coord[0].reshape(-1)
pixels_y = selected_coord[1].reshape(-1)
color = self.img_set[:, pixels_y, pixels_x].permute(1, 0) # [pch_len**2 * N, C]

fx, fy, dx, dy = self.intrinsics
p = torch.stack([
(pixels_x - dx) / fx,
(pixels_y - dy) / fy,
torch.ones_like(pixels_y)
], dim=1) # [pch_len**2 * N, 3]
rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # [pch_len**2 * N, 3]

ret = {
'idx': torch.Tensor([idx]),
'rays_o': self.rays_o, # [N, 3]
'rays_v': rays_v,
'color': color,
'pat': self.pat_set, # [C, Hp, Wp]
'rays_o': self.rays_o, # [N, 3]
'rays_v': rays_v, # [pch_len**2 * N, 3]
'mask': selected_mask.reshape(-1, 1), # [pch_len**2 * N, 1]
'color': color, # [pch_len**2 * N, C]
'pat': self.pat_set, # [C, Hp, Wp]
}

# ret = {
Expand Down
47 changes: 31 additions & 16 deletions exps/exp_xyz2density.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pointerlib as plb
from dataset.multi_pat_dataset import MultiPatDataset

from loss.loss_func import SuperviseDistLoss
from loss.loss_func import SuperviseDistLoss, NeighborGradientLoss
from networks.layers import WarpFromXyz
from networks.neus import NeuSLRenderer, DensityNetwork, ReflectNetwork

Expand Down Expand Up @@ -48,6 +48,8 @@ def __init__(self, args):
self.alpha_set.append([int(epoch_idx), float(value)])
self.alpha = self.alpha_set[0][1]

self.reg_ratio = 0.0

def init_dataset(self):
"""
Requires:
Expand All @@ -64,7 +66,8 @@ def init_dataset(self):
pat_idx_set=pat_idx_set,
sample_num=self.sample_num,
calib_para=config['Calibration'],
device=self.device
device=self.device,
rad=self.args.patch_rad
)
self.train_dataset = self.pat_dataset

Expand Down Expand Up @@ -137,6 +140,10 @@ def init_losses(self):
"""
self.super_loss = SuperviseDistLoss(dist='l1')
self.loss_funcs['color_l1'] = self.super_loss

if self.args.patch_rad > 0:
self.loss_funcs['gradient'] = NeighborGradientLoss(rad=self.args.patch_rad, dist='l2')

self.logging(f'--loss types: {self.loss_funcs.keys()}')
pass

Expand All @@ -160,18 +167,24 @@ def net_forward(self, data):
The output will be passed to :loss_forward().
"""
render_out = self.renderer.render_density(data['rays_v'], reflect=data['reflect'], alpha=self.alpha)
return render_out['color']
return render_out['color'], render_out['depth']

def loss_forward(self, net_out, data):
"""
How loss functions process the output from network and input data.
The output will be used with err.backward().
"""
color_fine = net_out
color_fine, depth_res = net_out
total_loss = torch.zeros(1).to(self.device)
total_loss += self.loss_record(
'color_l1', pred=color_fine, target=data['color']
)

if 'gradient' in self.loss_funcs:
total_loss += self.loss_record(
'gradient', depth=depth_res, mask=data['mask']
) * self.reg_ratio

self.avg_meters['Total'].update(total_loss, self.N)
return total_loss

Expand All @@ -182,6 +195,8 @@ def callback_after_train(self, epoch):
if alpha_pair[0] > epoch:
break
self.alpha = alpha_pair[1]
if epoch > 1000:
self.reg_ratio = 0.01

def callback_save_res(self, data, net_out, dataset, res_writer):
"""
Expand Down Expand Up @@ -229,8 +244,7 @@ def callback_epoch_report(self, epoch, tag, stopwatch, res_writer=None):
# Save
if self.args.save_stone > 0 and epoch % self.args.save_stone == 0:
res = self.visualize_output(resolution_level=1, require_item=[
'img_list', 'wrp_viz', 'depth_viz', 'depth_map', 'point_cloud',
'query_density', 'query_z', 'query_weights'
'wrp_viz', 'depth_viz', 'depth_map', 'point_cloud'
]) # TODO: 可视化所有的query部分,用于绘制结果图。
save_folder = self.res_dir / 'output' / f'e_{epoch:05}'
save_folder.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -279,10 +293,10 @@ def require_contain(*keys):
img_hei, img_wid = img_size
total_ray = rays_o.shape[0]
idx = torch.arange(0, total_ray, dtype=torch.long)
idx = idx[mask_val]
rays_o = rays_o[mask_val]
rays_d = rays_d[mask_val]
reflect = reflect[mask_val]
idx = idx[mask_val > 0.0]
rays_o = rays_o[mask_val > 0.0]
rays_d = rays_d[mask_val > 0.0]
reflect = reflect[mask_val > 0.0]

rays_o_set = rays_o.split(self.sample_num)
rays_d_set = rays_d.split(self.sample_num)
Expand All @@ -302,12 +316,13 @@ def require_contain(*keys):
out_rgb_fine.append(color_fine.detach().cpu())

if require_contain('depth_map', 'depth_viz', 'point_cloud', 'mesh'):
weights = render_out['weights']
mid_z_vals = render_out['pts'][:, :, -1]
max_idx = torch.argmax(weights, dim=1) # [N]
mid_z = mid_z_vals[torch.arange(max_idx.shape[0]), max_idx]
# mid_z = render_out['z_val']
out_depth.append(mid_z.detach().cpu())
# weights = render_out['weights']
# mid_z_vals = render_out['pts'][:, :, -1]
# max_idx = torch.argmax(weights, dim=1) # [N]
# mid_z = mid_z_vals[torch.arange(max_idx.shape[0]), max_idx]
# out_depth.append(mid_z.detach().cpu())
depth_val = render_out['depth'].reshape(-1)
out_depth.append(depth_val.detach().cpu())

if require_contain('query_z'):
out_z.append(render_out['z_vals'].detach().cpu())
Expand Down
39 changes: 39 additions & 0 deletions loss/loss_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,42 @@ def forward(self, pred, target, mask=None):
mask = torch.ones_like(pred)
val = (err_map * mask).sum() / (mask.sum() + 1e-8)
return val, err_map


class NeighborGradientLoss(BaseLoss):
    """Depth-smoothness regularizer over sampled pixel patches.

    Penalizes first-order depth differences between horizontally and
    vertically adjacent pixels inside each (2*rad+1)-sided patch, counting
    a difference only where both participating pixels are unmasked.
    Expects the flattened patch layout produced by the patch-based sampler.
    """

    def __init__(self, rad, name='NeighborGradientLoss', dist='l2'):
        """
        rad:  patch radius; patch side length is 2 * rad + 1. Must be > 0.
        name: loss name forwarded to BaseLoss.
        dist: per-element distance — 'l1', 'l2', or 'smoothl1'.
        Raises NotImplementedError for any other `dist`.
        """
        super().__init__(name)
        assert rad > 0
        self.rad = rad
        # Per-element criterion; masked reduction happens in forward().
        criterion_map = {
            'l1': torch.nn.L1Loss,
            'l2': torch.nn.MSELoss,
            'smoothl1': torch.nn.SmoothL1Loss,
        }
        if dist not in criterion_map:
            raise NotImplementedError(f'Unknown loss type: {dist}')
        self.crit = criterion_map[dist](reduction='none')

    def forward(self, depth, mask):
        """
        depth: [pch_len**2 * N, 1]
        mask: [pch_len**2 * N, 1]
        Returns a scalar: masked mean of per-neighbor gradient errors.
        """
        side = 2 * self.rad + 1
        # Recover the patch grid: [side, side, N, 1].
        depth_grid = depth.reshape(side, side, -1, 1)
        mask_grid = mask.reshape(side, side, -1, 1)

        base = depth_grid[:-1, :-1]
        base_mask = mask_grid[:-1, :-1]

        # Horizontal (x) and vertical (y) neighbor differences; each term is
        # valid only where both pixels of the pair are unmasked.
        diff_x = depth_grid[:-1, 1:] - base
        diff_y = depth_grid[1:, :-1] - base
        weight_x = mask_grid[:-1, 1:] * base_mask
        weight_y = mask_grid[1:, :-1] * base_mask

        err_x = self.crit(diff_x, torch.zeros_like(diff_x))
        err_y = self.crit(diff_y, torch.zeros_like(diff_y))

        # Average over valid pairs only; epsilon guards an all-masked batch.
        numerator = (err_x * weight_x + err_y * weight_y).sum()
        denominator = (weight_x + weight_y).sum() + 1e-8
        return numerator / denominator
4 changes: 4 additions & 0 deletions networks/neus.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,10 +965,14 @@ def render_density(self, rays_d, reflect, alpha):
weights = self.density2weights(z_vals=z_vals, density=density)
color = (sampled_color * weights[:, :, None]).sum(dim=1)

# Compute depth
depth_val = (pts[:, :, -1:] * weights[:, :, None]).sum(dim=1)

return {
'pts': pts,
'pt_color': sampled_color,
'color': color,
'depth': depth_val,
'density': density,
'z_vals': z_vals,
# 'z_val': z_val,
Expand Down
18 changes: 9 additions & 9 deletions params/xyz2density_train.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ out_dir = C:/SLDataSet/20221005realsyn-out
;test_dir =
model_dir =
exp_type = train
run_tag =
;debug_mode = True
run_tag = Discard
debug_mode = True

batch_num = 512
num_workers = 0
epoch_start = 0
epoch_end = 4001
epoch_end = 32001
; remove_history = True
lr = 1e-3
lr_step = 0
report_stone = 10
img_stone = 100
model_stone = 100
save_stone = 100
report_stone = 100
img_stone = 200
model_stone = 200
save_stone = 1000

;alpha_stone = 0-1.1,1000-0.5,2000-0.2,3000-0.1
alpha_stone = 0-1.1
alpha_stone = 0-2.1,10000-1.1,20000-0.5,30000-0.2
patch_rad = 0
pat_set = 0,1,2,3,4,5,6,40,41

;num_workers = 0
Expand Down
Loading

0 comments on commit 7602b9a

Please sign in to comment.