neurips init: use my codebase to train on large-scale scenes successfully. (1) add many debug flags (2) fix many small bugs. We should keep these fixes in all following code.
Jinyang Li committed May 6, 2024
1 parent 1e14e40 commit d785baa
Showing 12 changed files with 427 additions and 116 deletions.
47 changes: 44 additions & 3 deletions arguments/__init__.py
@@ -63,9 +63,11 @@ def __init__(self, parser, sentinel=False):
        self.quiet = False
        self.checkpoint_iterations = []
        self.start_checkpoint = ""
        self.auto_start_checkpoint = False
        self.log_folder = "experiments/default_folder"
        self.log_interval = 50
        self.debug_why = False
        self.llffhold = 10
        super().__init__(parser, "Loading Parameters", sentinel)

    def extract(self, args):
@@ -108,6 +110,8 @@ def __init__(self, parser):
        self.feature_lr = 0.0025
        self.opacity_lr = 0.05
        self.scaling_lr = 0.005
        self.lr_scale_loss = 1.0
        self.lr_scale_pos_and_scale = 1.0
        self.rotation_lr = 0.001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
@@ -116,6 +120,8 @@
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        self.densify_grad_threshold = 0.0002
        self.densify_memory_limit = 17.5
        self.opacity_reset_until_iter = -1
        self.random_background = False
        self.min_opacity = 0.005
        self.lr_scale_mode = "linear" # can be "linear", "sqrt", or "accumu"
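For context, a minimal sketch of how these three batch-size scaling modes are commonly applied; this helper is hypothetical and the exact rule in the codebase may differ:

import math

def scale_lr(base_lr, bsz, mode="linear"):
    # Hypothetical helper: "linear" multiplies the learning rate by the
    # batch size, "sqrt" by its square root; "accumu" is assumed to keep
    # the LR unchanged and rely on gradient accumulation instead.
    if mode == "linear":
        return base_lr * bsz
    if mode == "sqrt":
        return base_lr * math.sqrt(bsz)
    if mode == "accumu":
        return base_lr
    raise ValueError("unknown lr_scale_mode: " + mode)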
@@ -160,6 +166,7 @@ def __init__(self, parser):
        self.async_load_gt_image = False
        self.multiprocesses_image_loading = False
        self.num_train_cameras = -1
        self.num_test_cameras = -1
        self.distributed_save = False

        super().__init__(parser, "Distribution Parameters")
@@ -196,6 +203,12 @@ def __init__(self, parser):
        self.nsys_profile = False # profile with nsys.
        self.drop_initial_3dgs_p = 0.0 # probability of dropping each initial 3D Gaussian.

        self.clear_floaters = False # clear floaters in the image.
        self.prune_based_on_opacity_interval = 4000
        self.sync_more = False
        self.log_memory_summary = False
        self.empty_cache_more = False

        super().__init__(parser, "Debug Parameters")

def get_combined_args(parser : ArgumentParser):
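The new clear_floaters and prune_based_on_opacity_interval flags point at periodic opacity-based pruning. A minimal sketch of such a pass, assuming the get_opacity/prune_points interface of the standard 3DGS GaussianModel rather than this commit's exact implementation:

import torch

def prune_floaters(gaussians, min_opacity=0.005):
    # Drop Gaussians whose opacity has fallen below the threshold;
    # low-opacity points are a common source of floaters. Intended to
    # run every prune_based_on_opacity_interval iterations.
    with torch.no_grad():
        prune_mask = (gaussians.get_opacity < min_opacity).squeeze()
        gaussians.prune_points(prune_mask)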
@@ -235,15 +248,20 @@ def print_all_args(args, log_file):
    utils.set_block_size(cuda_block_x, cuda_block_y, one_dim_block_size)
    log_file.write("cuda_block_x: {}; cuda_block_y: {}; one_dim_block_size: {};\n".format(cuda_block_x, cuda_block_y, one_dim_block_size))

def find_latest_checkpoint(log_folder):
    checkpoint_folder = os.path.join(log_folder, "checkpoints")
    if os.path.exists(checkpoint_folder):
        all_sub_folders = os.listdir(checkpoint_folder)
        if len(all_sub_folders) > 0:
            all_sub_folders.sort(key=lambda x: int(x), reverse=True)
            return os.path.join(checkpoint_folder, all_sub_folders[0])
    return ""

def init_args(args):

    # Check arguments
    assert not (args.benchmark_stats and args.performance_stats), "benchmark_stats and performance_stats can not be enabled at the same time."

    if len(args.save_iterations) > 0 and args.iterations not in args.save_iterations:
        args.save_iterations.append(args.iterations)

    if args.benchmark_stats:
        args.zhx_time = True
        args.zhx_python_time = True
@@ -273,6 +291,12 @@ def init_args(args):
        assert not args.save_i2jsend, "performance_stats mode does not support save_i2jsend."
        assert not args.stop_update_param, "performance_stats mode does not support stop_update_param."

    if args.opacity_reset_until_iter == -1:
        args.opacity_reset_until_iter = args.densify_until_iter + args.bsz

    if args.auto_start_checkpoint:
        args.start_checkpoint = find_latest_checkpoint(args.log_folder)

    if args.fixed_training_image != -1:
        args.test_iterations = [] # disable testing during training.
        args.disable_auto_densification = True
@@ -285,6 +309,7 @@ def init_args(args):
        args.image_distribution = False
        args.image_distribution_mode = "0"
        args.distributed_dataset_storage = False
        args.distributed_save = False

    if utils.MP_GROUP.size() == 1:
        args.image_distribution_mode = "0"
@@ -296,5 +321,21 @@

    # sort test_iterations
    args.test_iterations.sort()
    if len(args.test_iterations) > 0:
        while args.test_iterations[-1] < args.iterations:
            args.test_iterations.append(args.test_iterations[-1]+10000)

    args.save_iterations.sort()
    if len(args.save_iterations) > 0:
        while args.save_iterations[-1] < args.iterations:
            args.save_iterations.append(args.save_iterations[-1]+20000)
    if len(args.save_iterations) > 0 and args.iterations not in args.save_iterations:
        args.save_iterations.append(args.iterations)

    args.checkpoint_iterations.sort()
    if len(args.checkpoint_iterations) > 0:
        while args.checkpoint_iterations[-1] < args.iterations:
            args.checkpoint_iterations.append(args.checkpoint_iterations[-1]+50000)

    init_image_distribution_config(args)
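The three padding loops added above share one rule: a sorted, non-empty schedule is extended with a fixed stride (10k for test, 20k for save, 50k for checkpoints) until its last entry reaches the final iteration. A standalone sketch of that rule with a worked example:

def pad_schedule(schedule, iterations, stride):
    # Extend a sorted schedule with a fixed stride until its last entry
    # reaches (or passes) the final training iteration.
    schedule = sorted(schedule)
    while schedule and schedule[-1] < iterations:
        schedule.append(schedule[-1] + stride)
    return schedule

# pad_schedule([7000, 30000], 100000, 10000)
# -> [7000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000]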
30 changes: 20 additions & 10 deletions gaussian_renderer/__init__.py
@@ -374,16 +374,26 @@ def render(screenspace_pkg, strategy=None):
    # render
    if timers is not None:
        timers.start("forward_render_gaussians")
    rendered_image, n_render, n_consider, n_contrib = screenspace_pkg["rasterizer"].render_gaussians(
        means2D=screenspace_pkg["means2D_for_render"],
        conic_opacity=screenspace_pkg["conic_opacity_for_render"],
        rgb=screenspace_pkg["rgb_for_render"],
        depths=screenspace_pkg["depths_for_render"],
        radii=screenspace_pkg["radii_for_render"],
        compute_locally=compute_locally,
        extended_compute_locally=extended_compute_locally,
        cuda_args=screenspace_pkg["cuda_args"]
    )
    if screenspace_pkg["means2D_for_render"].shape[0] < 1000:
        # assert utils.get_args().image_distribution_mode == "3", "The image_distribution_mode should be 3."
        # rendered_image = torch.zeros((3, screenspace_pkg["rasterizer"].raster_settings.image_height, screenspace_pkg["rasterizer"].raster_settings.image_width), dtype=torch.float32, device="cuda", requires_grad=True)
        rendered_image = screenspace_pkg["means2D_for_render"].sum()+screenspace_pkg["conic_opacity_for_render"].sum()+screenspace_pkg["rgb_for_render"].sum()
        screenspace_pkg["cuda_args"]["stats_collector"]["forward_render_time"] = 0.0
        screenspace_pkg["cuda_args"]["stats_collector"]["backward_render_time"] = 0.0
        screenspace_pkg["cuda_args"]["stats_collector"]["forward_loss_time"] = 0.0
        screenspace_pkg["cuda_args"]["stats_collector"]["backward_loss_time"] = 0.0
        return rendered_image, compute_locally
    else:
        rendered_image, n_render, n_consider, n_contrib = screenspace_pkg["rasterizer"].render_gaussians(
            means2D=screenspace_pkg["means2D_for_render"],
            conic_opacity=screenspace_pkg["conic_opacity_for_render"],
            rgb=screenspace_pkg["rgb_for_render"],
            depths=screenspace_pkg["depths_for_render"],
            radii=screenspace_pkg["radii_for_render"],
            compute_locally=compute_locally,
            extended_compute_locally=extended_compute_locally,
            cuda_args=screenspace_pkg["cuda_args"]
        )
    if timers is not None:
        timers.stop("forward_render_gaussians")
    utils.check_memory_usage_logging("after forward_render_gaussians")
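Why the small-workload branch above returns a 0-dim sum instead of a real image: the sum keeps the output attached to the autograd graph, so the zeroed loss in loss_distribution.py below backpropagates exact zeros into the per-Gaussian tensors instead of skipping them (or invoking the rasterizer on a near-empty workload). A standalone illustration with hypothetical shapes:

import torch

means2D = torch.zeros((0, 2), requires_grad=True)  # no Gaussians on this rank
rgb = torch.zeros((0, 3), requires_grad=True)

fake_image = means2D.sum() + rgb.sum()  # 0-dim tensor, still in the graph
loss = fake_image * 0
loss.backward()
assert means2D.grad is not None and float(means2D.grad.abs().sum()) == 0.0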
3 changes: 3 additions & 0 deletions gaussian_renderer/loss_distribution.py
@@ -1338,4 +1338,7 @@ def replicated_loss_computation(image, viewpoint_cam, compute_locally, strategy,
    }

def loss_computation(image, viewpoint_cam, compute_locally, strategy, statistic_collector, image_distribution_mode):
    # HACK: a 0-dim image tensor means nothing was rendered on this rank; return zero losses so the gradients are also zero.
    if len(image.shape) == 0:
        return image*0, image*0
    return name2loss_implementation[image_distribution_mode](image, viewpoint_cam, compute_locally, strategy, statistic_collector)
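Returning image*0 twice presumably mirrors the pair of values the per-mode loss implementations return, so the caller's backward pass stays unchanged in the no-render case while contributing exactly zero loss and zero gradient.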
6 changes: 3 additions & 3 deletions gaussian_renderer/workload_division.py
@@ -473,9 +473,9 @@ def start_strategy(self):
    def finish_strategy(self):
        with torch.no_grad():
            self.update_heuristic()
            if utils.get_args().benchmark_stats:
                self.working_strategy.heuristic = None
                # Because the heuristic is of size (# of tiles,) and takes up lots of memory if we keep it for every iteration.
            # if utils.get_args().benchmark_stats:
            self.working_strategy.heuristic = None
            # Because the heuristic is of size (# of tiles,) and takes up lots of memory if we keep it for every iteration.
            self.add(self.working_iteration, self.working_strategy)

    def to_json(self):
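Rough numbers on why the heuristic is now dropped unconditionally rather than only under benchmark_stats: with one float32 entry per tile, a 3840x2160 render at (assumed) 16x16 tiles gives 240 * 135 = 32,400 entries, about 127 KB per strategy; kept for all of a 30,000-iteration run, that alone would grow to roughly 3.6 GiB.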
28 changes: 20 additions & 8 deletions scene/__init__.py
@@ -46,12 +46,19 @@ def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle
        utils.log_cpu_memory_usage("before loading images meta data")

        if os.path.exists(os.path.join(args.source_path, "sparse")):
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args.llffhold)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train_my.json")):
            # print("Found transforms_train.json file, assuming Blender data set!")
            # scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
            # else:
            scene_info = sceneLoadTypeCallbacks["City"](args.source_path,
                                                        args.random_background,
                                                        args.white_background,
                                                        args.eval,
                                                        ds=1,
                                                        llffhold=args.llffhold)
        else:
            assert False, "Could not recognize scene type!"
            raise ValueError("No valid dataset found in the source path")

        if not self.loaded_iter:
            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
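The llffhold argument now threaded through both loaders follows the LLFF hold-out convention: with args.eval set, every llffhold-th camera is reserved as a test view. A minimal sketch of that split (hypothetical helper, not this commit's loader code):

def split_train_test(cam_infos, llffhold=10):
    # Hold out every llffhold-th camera for evaluation; train on the rest.
    train = [c for i, c in enumerate(cam_infos) if i % llffhold != 0]
    test = [c for i, c in enumerate(cam_infos) if i % llffhold == 0]
    return train, test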
@@ -91,7 +98,11 @@ def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle
        if args.eval:
            for resolution_scale in [args.test_resolution_scale]:
                utils.print_rank_0("Decoding Test Cameras")
                self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
                if args.num_test_cameras > 0:
                    test_cameras = scene_info.test_cameras[:args.num_test_cameras]
                else:
                    test_cameras = scene_info.test_cameras
                self.test_cameras[resolution_scale] = cameraList_from_camInfos(test_cameras, resolution_scale, args)
                # output the number of cameras in the training set and image size to the log file
                log_file.write("Test Resolution Scale: {}\n".format(resolution_scale))
                log_file.write("Number of local test cameras: {}\n".format(len(self.test_cameras[resolution_scale])))
@@ -117,12 +128,12 @@ def save(self, iteration):

    def getTrainCameras(self, scale=1.0):
        if scale not in self.train_cameras:
            return None
            return []
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        if scale not in self.test_cameras:
            return None
            return []
        return self.test_cameras[scale]

    def log_scene_info_to_file(self, log_file, prefix_str=""):
@@ -171,6 +182,7 @@ def get_one_camera(self, batched_cameras_uid):
        return viewpoint_cam

    def get_batched_cameras(self, batch_size):
        assert batch_size <= self.camera_size, "Batch size is larger than the number of cameras in the scene."
        batched_cameras = []
        batched_cameras_uid = []
        for i in range(batch_size):
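Two small design notes on the scene changes: returning [] rather than None from getTrainCameras/getTestCameras lets callers take len() of or iterate over the result without a None check, and the new assert in get_batched_cameras fails fast when a batch would require more cameras than the scene provides.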