[UPDATE] add rerun support, use isam by default
Yue Pan committed Jul 16, 2024
1 parent 85d6260 commit bb609e9
Showing 8 changed files with 94 additions and 58 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -238,7 +238,12 @@ python3 pin_slam.py ./config/lidar_slam/run.yaml rosbag point_cloud_topic_name -
 python3 pin_slam.py ./config/lidar_slam/run.yaml rosbag -i /path/to/the/rosbag -vsmd
 ```
 
-The data loaders for [some specific datasets](https://github.com/PRBonn/PIN_SLAM/tree/main/dataset/dataloaders) are also available. For example, you can run on Replica RGB-D dataset without preprocessing the data by:
+The data loaders for [some specific datasets](https://github.com/PRBonn/PIN_SLAM/tree/main/dataset/dataloaders) are also available.
+```
+Available dataloaders: ['apollo', 'boreas', 'generic', 'helipr', 'kitti', 'kitti_raw', 'mcap', 'mulran', 'ncd', 'nclt', 'neuralrgbd', 'nuscenes', 'ouster', 'paris_luco', 'replica', 'rosbag', 'tum']
+```
+
+For example, you can run on Replica RGB-D dataset without preprocessing the data by:
 ```
 # Download data
 sh scripts/download_replica.sh
@@ -251,7 +256,7 @@ The SLAM results and logs will be output in the `output_root` folder set in the
 
 For evaluation, you may check [here](https://github.com/PRBonn/PIN_SLAM/blob/main/eval/README.md) for the results that can be obtained with this repository on a couple of popular datasets.
 
-The training logs can be monitored via Weights & Biases online if you set the flag `-w`. If it's your first time using Weights & Biases, you will be requested to register and log in to your wandb account. You can also set the flag `-l` to turn on the log printing in the terminal.
+The training logs can be monitored via Weights & Biases online if you set the flag `-w`. If it's your first time using [Weights & Biases](https://wandb.ai/site), you will be requested to register and log in to your wandb account. You can also set the flag `-l` to turn on log printing in the terminal and the flag `-r` to turn on visualization logging via [rerun](https://github.com/rerun-io/rerun).
 
 </details>
 
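As a quick illustration of how these switches are wired, here is a small, self-contained Python sketch (a hypothetical stand-in: the flag names match the repository, everything else is illustrative, and the real wiring appears in the pin_slam.py diff below):

```
import argparse

# Hypothetical mini-parser mirroring the logging flags documented above.
parser = argparse.ArgumentParser()
parser.add_argument('--log_on', '-l', action='store_true')    # terminal log printing
parser.add_argument('--wandb_on', '-w', action='store_true')  # Weights & Biases logging
parser.add_argument('--rerun_on', '-r', action='store_true')  # rerun visualization logging

args = parser.parse_args(['-l', '-r'])  # e.g. terminal logs plus the rerun viewer
print(args.log_on, args.wandb_on, args.rerun_on)  # True False True
```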
3 changes: 0 additions & 3 deletions dataset/dataloaders/rosbag.py
@@ -72,9 +72,6 @@ def __init__(self, data_dir: Path, topic: str, *_, **__):
         self.msgs = self.bag.messages(connections=connections)
         self.timestamps = []
 
-        # Visualization Options
-        self.use_global_visualizer = True
-
     def __del__(self):
         if hasattr(self, "bag"):
             self.bag.close()
29 changes: 20 additions & 9 deletions dataset/slam_dataset.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 from typing import List
 
+import datetime as dt
 import matplotlib.cm as cm
 import numpy as np
 import open3d as o3d
@@ -181,9 +182,18 @@ def read_frame_ros(self, msg):
 
         points, point_ts = point_cloud2.read_point_cloud(msg)
 
-        if point_ts is None:
+        if point_ts is not None:
+            min_timestamp = np.min(point_ts)
+            max_timestamp = np.max(point_ts)
+            if min_timestamp == max_timestamp:
+                point_ts = None
+            else:
+                # normalized to 0-1
+                point_ts = (point_ts - min_timestamp) / (max_timestamp - min_timestamp)
+
+        if point_ts is None and not self.config.silence:
             print(
-                "The point cloud message does not contain the time stamp field"
+                "The point cloud message does not contain the valid time stamp field"
             )
 
         self.cur_point_cloud_torch = torch.tensor(
@@ -193,7 +203,7 @@ def read_frame_ros(self, msg):
         if self.config.deskew:
             self.get_point_ts(point_ts)
 
-    # read frame with specific data loader (borrow from kiss-icp: https://github.com/PRBonn/kiss-icp)
+    # read frame with specific data loader (partially borrow from kiss-icp: https://github.com/PRBonn/kiss-icp)
     def read_frame_with_loader(self, frame_id):
 
         self.set_ref_pose(frame_id)
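The timestamp normalization added above can be exercised on its own. A minimal standalone sketch (not part of the commit) of the same [0, 1] rescaling with the fall-back to None for degenerate stamps:

```
import numpy as np

def normalize_point_ts(point_ts):
    # Rescale per-point timestamps to [0, 1] within one scan for deskewing;
    # constant stamps carry no motion information, so fall back to None.
    if point_ts is None:
        return None
    t_min, t_max = np.min(point_ts), np.max(point_ts)
    if t_min == t_max:
        return None
    return (point_ts - t_min) / (t_max - t_min)

print(normalize_point_ts(np.array([0.00, 0.05, 0.10])))  # [0.  0.5 1. ]
print(normalize_point_ts(np.array([1.0, 1.0, 1.0])))     # None
```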
@@ -206,12 +216,12 @@ def read_frame_with_loader(self, frame_id):
             points, point_ts = data
         else:
             points = data
 
         self.cur_point_cloud_torch = torch.tensor(points, device=self.device, dtype=self.dtype)
 
         if self.config.deskew:
             self.get_point_ts(point_ts)
 
     def read_frame(self, frame_id):
 
         self.set_ref_pose(frame_id)
@@ -254,7 +264,8 @@ def read_frame(self, frame_id):
         # print(self.cur_point_ts_torch)
 
     # point-wise timestamp is now only used for motion undistortion (deskewing)
     def get_point_ts(self, point_ts=None):
+        # point_ts is already the normalized timestamp in a scan frame # [0,1]
         if self.config.deskew:
             if point_ts is not None and min(point_ts) < 1.0: # not all 1
                 if not self.silence:
@@ -344,7 +355,7 @@ def preprocess_frame(self):
         # pose initial guess tensor
         self.cur_pose_guess_torch = torch.tensor(
             cur_pose_init_guess, dtype=torch.float64, device=self.device
         )
 
         if self.config.adaptive_range_on:
             pc_max_bound, _ = torch.max(self.cur_point_cloud_torch[:, :3], dim=0)
@@ -519,10 +530,10 @@ def update_odom_pose(self, cur_pose_torch: torch.tensor):
             self.write_results() # record before the failure point
             sys.exit("Lose track for a long time, system failed")
 
-    def update_poses_after_pgo(self, pgo_cur_pose, pgo_poses):
-        self.cur_pose_ref = pgo_cur_pose
-        self.last_pose_ref = pgo_cur_pose # update for next frame
+    def update_poses_after_pgo(self, pgo_poses):
         self.pgo_poses[:self.processed_frame+1] = pgo_poses # update pgo pose
+        self.cur_pose_ref = self.pgo_poses[self.processed_frame]
+        self.last_pose_ref = self.cur_pose_ref # update for next frame
 
     def update_o3d_map(self):
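The reworked `update_poses_after_pgo` receives only the optimized pose array and re-derives the current and last reference poses from it. A hypothetical numpy-only sketch of that contract (array shapes assumed, not taken from the repository):

```
import numpy as np

processed_frame = 2
pgo_poses = np.tile(np.eye(4), (10, 1, 1))  # assumed per-frame 4x4 pose history

def update_poses_after_pgo(optimized_poses):
    # Overwrite the corrected history up to the current frame, then
    # re-read the current pose instead of receiving it as an argument.
    pgo_poses[:processed_frame + 1] = optimized_poses
    cur_pose_ref = pgo_poses[processed_frame]
    last_pose_ref = cur_pose_ref  # reused as the next frame's last pose
    return cur_pose_ref, last_pose_ref

cur, last = update_poses_after_pgo(np.tile(np.eye(4), (processed_frame + 1, 1, 1)))
print(cur.shape, np.allclose(cur, last))  # (4, 4) True
```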
18 changes: 16 additions & 2 deletions pin_slam.py
@@ -7,6 +7,7 @@
 import os
 import sys
 
+import rerun as rr
 import numpy as np
 import open3d as o3d
 import torch
@@ -55,6 +56,7 @@
 parser.add_argument('--visualize', '-v', action='store_true', help='Turn on the visualizer')
 parser.add_argument('--cpu_only', '-c', action='store_true', help='Run only on CPU')
 parser.add_argument('--log_on', '-l', action='store_true', help='Turn on the logs printing')
+parser.add_argument('--rerun_on', '-r', action='store_true', help='Turn on the rerun logging')
 parser.add_argument('--wandb_on', '-w', action='store_true', help='Turn on the weight & bias logging')
 parser.add_argument('--save_map', '-s', action='store_true', help='Save the PIN map after SLAM')
 parser.add_argument('--save_mesh', '-m', action='store_true', help='Save the reconstructed mesh after SLAM')
@@ -80,6 +82,7 @@ def run_pin_slam(config_path=None, dataset_name=None, sequence_name=None, seed=N
     config.seed = args.seed
     config.silence = not args.log_on
     config.wandb_vis_on = args.wandb_on
+    config.rerun_vis_on = args.rerun_on
     config.o3d_vis_on = args.visualize
     config.save_map = args.save_map
     config.save_mesh = args.save_mesh
@@ -101,6 +104,9 @@ def run_pin_slam(config_path=None, dataset_name=None, sequence_name=None, seed=N
     if config.o3d_vis_on:
         o3d_vis = MapVisualizer(config)
 
+    if config.rerun_vis_on:
+        rr.init("pin_slam_rerun_viewer", spawn=True)
+
     # initialize the mlp decoder
     geo_mlp = Decoder(config, config.geo_mlp_hidden_dim, config.geo_mlp_level, 1)
     sem_mlp = Decoder(config, config.sem_mlp_hidden_dim, config.sem_mlp_level, config.sem_class_count + 1) if config.semantic_on else None
@@ -250,10 +256,10 @@ def run_pin_slam(config_path=None, dataset_name=None, sequence_name=None, seed=N
                 # update the neural points and poses
                 pose_diff_torch = torch.tensor(pgm.get_pose_diff(), device=config.device, dtype=config.dtype)
                 dataset.cur_pose_torch = torch.tensor(pgm.cur_pose, device=config.device, dtype=config.dtype)
-                neural_points.adjust_map(pose_diff_torch) # transform neural points (position and orientation) along with associated frame poses
+                neural_points.adjust_map(pose_diff_torch) # transform neural points (position and orientation) along with associated frame poses # time consuming part
                 neural_points.recreate_hash(dataset.cur_pose_torch[:3,3], None, (not config.pgo_merge_map), config.rehash_with_time, frame_id) # recreate hash from current time
                 mapper.transform_data_pool(pose_diff_torch) # transform global pool
-                dataset.update_poses_after_pgo(pgm.cur_pose, pgm.pgo_poses)
+                dataset.update_poses_after_pgo(pgm.pgo_poses)
                 pgm.last_loop_idx = frame_id
                 pgm.min_loop_idx = min(pgm.min_loop_idx, loop_id)
                 loop_reg_failed_count = 0
@@ -374,6 +380,14 @@ def run_pin_slam(config_path=None, dataset_name=None, sequence_name=None, seed=N
             loop_edges = pgm.loop_edges_vis if config.pgo_on else None
             o3d_vis.update_traj(dataset.cur_pose_ref, odom_poses, gt_poses, pgo_poses, loop_edges)
             o3d_vis.update(dataset.cur_frame_o3d, dataset.cur_pose_ref, cur_sdf_slice, cur_mesh, neural_pcd, pool_pcd)
 
+        if config.rerun_vis_on:
+            if neural_pcd is not None:
+                rr.log("world/neural_points", rr.Points3D(neural_pcd.points, colors=neural_pcd.colors, radii=0.05))
+            if dataset.cur_frame_o3d is not None:
+                rr.log("world/input_scan", rr.Points3D(dataset.cur_frame_o3d.points, colors=dataset.cur_frame_o3d.colors, radii=0.03))
+            if cur_mesh is not None:
+                rr.log("world/mesh_map", rr.Mesh3D(vertex_positions=cur_mesh.vertices, triangle_indices=cur_mesh.triangles, vertex_normals=cur_mesh.vertex_normals, vertex_colors=cur_mesh.vertex_colors))
+
         T8 = get_time()
 
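The rerun calls introduced above can be tried with stand-in data. A minimal sketch against the rerun-sdk 0.17 API pinned in requirements.txt (the app id and entity path match the commit; the random point cloud is illustrative):

```
import numpy as np
import rerun as rr

# Spawn the viewer the same way pin_slam.py does, then log a toy point cloud.
rr.init("pin_slam_rerun_viewer", spawn=True)

points = np.random.rand(1000, 3)  # stand-in for neural_pcd.points
colors = np.random.rand(1000, 3)  # per-point RGB in [0, 1]
rr.log("world/neural_points", rr.Points3D(points, colors=colors, radii=0.05))
```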
2 changes: 1 addition & 1 deletion pin_slam_ros.py
@@ -457,7 +457,7 @@ def detect_correct_loop(self):
         self.neural_points.adjust_map(pose_diff_torch)
         self.neural_points.recreate_hash(self.dataset.cur_pose_torch[:3,3], None, (not self.config.pgo_merge_map), self.config.rehash_with_time, cur_frame_id) # recreate hash from current time
         self.mapper.transform_data_pool(pose_diff_torch) # transform global pool
-        self.dataset.update_poses_after_pgo(self.pgm.cur_pose, self.pgm.pgo_poses)
+        self.dataset.update_poses_after_pgo(self.pgm.pgo_poses)
         self.pgm.last_loop_idx = cur_frame_id
         self.pgm.min_loop_idx = min(self.pgm.min_loop_idx, loop_id)
         self.loop_reg_failed_count = 0
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,14 +1,15 @@
 evo==1.28.0
+gnupg==2.3.1
 gtsam==4.2
 laspy==2.5.3
 natsort==8.1.0
 open3d==0.17.0
+pycryptodomex==3.20.0
 pypose==0.6.8
 pyquaternion==0.9.9
+rerun-sdk==0.17.0
 rich==12.5.1
 roma==1.5.0
 rospkg==1.5.1
 scikit-image==0.21.0
+wandb==0.17.0
-pycryptodomex==3.20.0
-gnupg==2.3.1
-wandb==0.17.0
3 changes: 2 additions & 1 deletion utils/config.py
@@ -241,7 +241,7 @@ def __init__(self):
         # pose graph optimization
         self.pgo_on: bool = False
         self.pgo_freq: int = 30 # frame interval for detecting loop closure and conducting pose graph optimization after a successful loop correction
-        self.pgo_with_lm: bool = True # use lm or dogleg optimizer
+        self.pgo_with_isam: bool = True # use isam incremental optimization or lm batch optimization
         self.pgo_max_iter: int = 50 # maximum number of iterations
         self.pgo_with_pose_prior: bool = False # use the pose prior or not during the pgo
         self.pgo_tran_std: float = 0.04 # m
@@ -253,6 +253,7 @@ def __init__(self):
 
         # eval
         self.wandb_vis_on: bool = False # monitor the training on weight and bias or not
+        self.rerun_vis_on: bool = False # visualize the process using rerun visualizer or not
         self.silence: bool = True # print log in the terminal or not
         self.o3d_vis_on: bool = False # visualize the mesh in-the-fly using o3d visualzier or not [press space to pasue/resume]
         self.o3d_vis_raw: bool = False # visualize the raw point cloud or the weight source point cloud
81 changes: 44 additions & 37 deletions utils/pgo.py
@@ -10,6 +10,7 @@
 from rich import print
 
 from utils.config import Config
+from utils.tools import get_time
 
 
 class PoseGraphManager:
@@ -43,10 +44,10 @@ def __init__(self, config: Config):
         self.robust_loop_cov = gtsam.noiseModel.Robust(mEst, self.loop_cov)
         self.robust_odom_cov = gtsam.noiseModel.Robust(mEst, self.odom_cov)
 
-        self.graph_factors = (
-            gtsam.NonlinearFactorGraph()
-        ) # edges # with pose and pose covariance
-        self.graph_initials = gtsam.Values() # initial guess # as pose
+        self.isam = gtsam.ISAM2()
+
+        self.graph_factors = gtsam.NonlinearFactorGraph() # edges # with pose and pose covariance
+        self.graph_initials = gtsam.Values() # initial guess of the nodes
 
         self.cur_pose = None
         self.curr_node_idx = None
@@ -168,49 +169,51 @@ def add_loop_factor(
                 )
             ) # NOTE: add robust kernel
 
-        cur_error = self.graph_factors.error(self.graph_initials)
-        valid_error_thre = (
-            self.last_error
-            + (cur_id - self.last_loop_idx) * self.config.pgo_error_thre_frame
-        )
-        if reject_outlier and cur_error > valid_error_thre:
-            if not self.silence:
-                print(
-                    "[bold yellow]A loop edge rejected due to too large error[/bold yellow]"
-                )
-            self.graph_factors.remove(self.graph_factors.size() - 1)
-            return False
+        if reject_outlier and not self.config.pgo_with_isam:
+            cur_error = self.graph_factors.error(self.graph_initials)
+            valid_error_thre = (
+                self.last_error
+                + (cur_id - self.last_loop_idx) * self.config.pgo_error_thre_frame
+            )
+            if reject_outlier and cur_error > valid_error_thre:
+                if not self.silence:
+                    print(
+                        "[bold yellow]A loop edge rejected due to too large error[/bold yellow]"
+                    )
+                self.graph_factors.remove(self.graph_factors.size() - 1)
+                return False
 
         return True
 
     def optimize_pose_graph(self):
 
-        if self.config.pgo_with_lm:
+        if self.config.pgo_with_isam:
+            self.isam.update(self.graph_factors, self.graph_initials)
+
+            T_0 = get_time()
+            self.graph_optimized = self.isam.calculateEstimate()
+            T_1 = get_time()
+
+        else:
             opt_param = gtsam.LevenbergMarquardtParams()
             opt_param.setMaxIterations(self.config.pgo_max_iter)
             opt = gtsam.LevenbergMarquardtOptimizer(
                 self.graph_factors, self.graph_initials, opt_param
             )
-        else: # pgo with dogleg
-            opt_param = gtsam.DoglegParams()
-            opt_param.setMaxIterations(self.config.pgo_max_iter)
-            opt = gtsam.DoglegOptimizer(
-                self.graph_factors, self.graph_initials, opt_param
-            )
-
-        error_before = self.graph_factors.error(self.graph_initials)
 
-        self.graph_optimized = opt.optimizeSafely()
+            T_0 = get_time()
+            self.graph_optimized = opt.optimizeSafely()
+            T_1 = get_time()
 
         # Calculate marginal covariances for all variables
         # marginals = gtsam.Marginals(self.graph_factors, self.graph_optimized)
         # try to even visualize the covariance
        # cov = get_node_cov(marginals, 50)
         # print(cov)
 
-        error_after = self.graph_factors.error(self.graph_optimized)
-        if not self.silence:
-            print("[bold red]PGO done[/bold red]")
-            print("error %f --> %f:" % (error_before, error_after))
+        error_before = self.graph_factors.error(self.graph_initials)
+        error_after = self.graph_factors.error(self.graph_optimized)
+        self.last_error = error_after
+        if not self.silence:
+            print("[bold red]PGO done[/bold red]")
+            print("error %f --> %f:" % (error_before, error_after))
+
+        # if not self.silence:
+        #     print("time for factor graph optimization (ms)", (T_1-T_0)*1e3)
 
         self.graph_initials = self.graph_optimized # update the initial guess

@@ -222,7 +225,11 @@ def optimize_pose_graph(self):
         self.cur_pose = self.pgo_poses[self.curr_node_idx]
 
         self.pgo_count += 1
-        self.last_error = error_after
+
+        if self.config.pgo_with_isam:
+            # reset
+            self.graph_factors = gtsam.NonlinearFactorGraph()
+            self.graph_initials.clear()
 
     # write the pose graph as the g2o format
     def write_g2o(self, out_file):
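The reason the new code path resets `graph_factors` and `graph_initials` after every update: `ISAM2.update` consumes only the new factors and new initial guesses, while the batch Levenberg-Marquardt optimizer is handed the whole problem each time. A minimal, self-contained GTSAM sketch (assuming the gtsam 4.2 Python API; the toy Pose3 graph is not from the repository):

```
import numpy as np
import gtsam

isam = gtsam.ISAM2()
graph = gtsam.NonlinearFactorGraph()
initials = gtsam.Values()

# Anchor node 0, then add one odometry edge to node 1 (identity motion as a toy).
prior_noise = gtsam.noiseModel.Diagonal.Sigmas(np.full(6, 0.01))
odom_noise = gtsam.noiseModel.Diagonal.Sigmas(np.full(6, 0.04))
graph.add(gtsam.PriorFactorPose3(0, gtsam.Pose3(), prior_noise))
initials.insert(0, gtsam.Pose3())
graph.add(gtsam.BetweenFactorPose3(0, 1, gtsam.Pose3(), odom_noise))
initials.insert(1, gtsam.Pose3())

# One incremental update consumes the new factors, then the estimate is read back.
isam.update(graph, initials)
estimate = isam.calculateEstimate()
print(estimate.atPose3(1).translation())

# From here on, only *new* factors/initials may be passed to the next update,
# hence the reset at the end of optimize_pose_graph above.
graph = gtsam.NonlinearFactorGraph()
initials.clear()
```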
