Skip to content

Commit

Permalink
add demo
Browse files Browse the repository at this point in the history
  • Loading branch information
qinzheng93 committed May 10, 2022
1 parent 4ce888d commit 2232c06
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 45 deletions.
1 change: 0 additions & 1 deletion .gitattributes

This file was deleted.

29 changes: 20 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,18 +187,18 @@ Note that the learning rate is multiplied by the number of GPUs by default as th

We evaluate GeoTransformer on the standard 3DMatch/3DLoMatch benchmarks as in [PREDATOR](https://arxiv.org/abs/2011.13005).

| Benchmark | FMR | IR | RR |
|:----------|:----:|:----:|:-----:|
| 3DMatch | 98.2 | 70.9 | 92.5 |
| 3DLoMatch | 87.1 | 43.5 | 74.2 |
| Benchmark | FMR | IR | RR |
| :-------- | :---: | :---: | :---: |
| 3DMatch | 98.2 | 70.9 | 92.5 |
| 3DLoMatch | 87.1 | 43.5 | 74.2 |

### Kitti odometry

We evaluate GeoTransformer on the standard Kitti benchmark as in [PREDATOR](https://arxiv.org/abs/2011.13005).

| Benchmark | RRE | RTE | RR |
|:----------|:-----:|:---:|:----:|
| Kitti | 0.230 | 6.2 | 99.8 |
| Benchmark | RRE | RTE | RR |
| :-------- | :---: | :---: | :---: |
| Kitti | 0.230 | 6.2 | 99.8 |

### ModelNet

Expand All @@ -210,10 +210,22 @@ We evaluate GeoTransformer on ModelNet with two settings:
We remove symmetric classes and use the data augmentation in [RPMNet](https://arxiv.org/abs/2003.13479) which is more difficult than [PRNet](https://arxiv.org/abs/1910.12240).

| Benchmark | RRE | RTE | RMSE |
|:---------------|:-----:|:-----:|:-----:|
| :------------- | :---: | :---: | :---: |
| seen (45-deg) | 1.577 | 0.018 | 0.017 |
| seen (180-deg) | 6.830 | 0.044 | 0.042 |

## Testing on your own data

To test on your own data, the recommended way is to implement a `Dataset` as in `geotransformer.dataset.registration.threedmatch.dataset.py`. Each item in the dataset is a `dict` contains at least 5 keys: `ref_points`, `src_points`, `ref_feats`, `src_feats` and `transform`.

We also provide a demo script to quickly test our pre-trained model on your own data in `experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/demo.py`. Use the following command to run the demo:

```bash
CUDA_VISIBLE_DEVICES=0 python demo.py --src_file=../../data/demo/src.npy --ref_file=../../data/demo/ref.npy --gt_file=../../data/demo/gt.npy --weights=../../weights/geotransformer-3dmatch.pth.tar
```

Change the arguments `src_file`, `ref_file` and `gt` to your own data, where `src_file` and `ref_file` are numpy files containing a `np.ndarray` in shape of Nx3, and `gt_file` is a numpy file containing a 4x4 transformation matrix. Note that you should scale your data to match the voxel size in 3DMatch (2.5cm).

## Citation

```bibtex
Expand All @@ -235,4 +247,3 @@ We remove symmetric classes and use the data augmentation in [RPMNet](https://ar
- [CoFiNet](https://github.com/haoyu94/Coarse-to-fine-correspondences)
- [huggingface-transformer](https://github.com/huggingface/transformers)
- [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork)

Binary file added data/demo/gt.npy
Binary file not shown.
Binary file added data/demo/ref.npy
Binary file not shown.
Binary file added data/demo/src.npy
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import argparse

import torch
import numpy as np

from geotransformer.utils.data import registration_collate_fn_stack_mode
from geotransformer.utils.torch import to_cuda, release_cuda
from geotransformer.utils.open3d import make_open3d_point_cloud, get_color, draw_geometries
from geotransformer.utils.registration import compute_registration_error

from config import make_cfg
from model import create_model


def make_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--src_file", required=True, help="src point cloud numpy file")
parser.add_argument("--ref_file", required=True, help="src point cloud numpy file")
parser.add_argument("--gt_file", required=True, help="ground-truth transformation file")
parser.add_argument("--weights", required=True, help="model weights file")
return parser


def load_data(args):
src_points = np.load(args.src_file)
ref_points = np.load(args.ref_file)
src_feats = np.ones_like(src_points[:, :1])
ref_feats = np.ones_like(ref_points[:, :1])

data_dict = {
"ref_points": ref_points.astype(np.float32),
"src_points": src_points.astype(np.float32),
"ref_feats": ref_feats.astype(np.float32),
"src_feats": src_feats.astype(np.float32),
}

if args.gt_file is not None:
transform = np.load(args.gt_file)
data_dict["transform"] = transform.astype(np.float32)

return data_dict


def main():
parser = make_parser()
args = parser.parse_args()

cfg = make_cfg()

# prepare data
data_dict = load_data(args)
neighbor_limits = [38, 36, 36, 38] # default setting in 3DMatch
data_dict = registration_collate_fn_stack_mode(
[data_dict], cfg.backbone.num_stages, cfg.backbone.init_voxel_size, cfg.backbone.init_radius, neighbor_limits
)

# prepare model
model = create_model(cfg).cuda()
state_dict = torch.load(args.weights)
model.load_state_dict(state_dict["model"])

# prediction
data_dict = to_cuda(data_dict)
output_dict = model(data_dict)
data_dict = release_cuda(data_dict)
output_dict = release_cuda(output_dict)

# get results
ref_points = output_dict["ref_points"]
src_points = output_dict["src_points"]
estimated_transform = output_dict["estimated_transform"]
transform = data_dict["transform"]

# visualization
ref_pcd = make_open3d_point_cloud(ref_points)
ref_pcd.estimate_normals()
ref_pcd.paint_uniform_color(get_color("custom_yellow"))
src_pcd = make_open3d_point_cloud(src_points)
src_pcd.estimate_normals()
src_pcd.paint_uniform_color(get_color("custom_blue"))
draw_geometries(ref_pcd, src_pcd)
src_pcd = src_pcd.transform(estimated_transform)
draw_geometries(ref_pcd, src_pcd)

# compute error
rre, rte = compute_registration_error(transform, estimated_transform)
print(f"RRE(deg): {rre:.3f}, RTE(m): {rte:.3f}")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def train_valid_data_loader(cfg, distributed):
train_dataset = ModelNetPairDataset(
cfg.data.dataset_root,
'train',
"train",
num_points=cfg.data.num_points,
voxel_size=cfg.data.voxel_size,
rotation_magnitude=cfg.data.rotation_magnitude,
Expand Down Expand Up @@ -47,7 +47,7 @@ def train_valid_data_loader(cfg, distributed):

valid_dataset = ModelNetPairDataset(
cfg.data.dataset_root,
'val',
"val",
num_points=cfg.data.num_points,
voxel_size=cfg.data.voxel_size,
rotation_magnitude=cfg.data.rotation_magnitude,
Expand Down Expand Up @@ -82,7 +82,7 @@ def train_valid_data_loader(cfg, distributed):
def test_data_loader(cfg):
train_dataset = ModelNetPairDataset(
cfg.data.dataset_root,
'train',
"train",
num_points=cfg.data.num_points,
voxel_size=cfg.data.voxel_size,
rotation_magnitude=cfg.data.rotation_magnitude,
Expand All @@ -108,7 +108,7 @@ def test_data_loader(cfg):

test_dataset = ModelNetPairDataset(
cfg.data.dataset_root,
'test',
"test",
num_points=cfg.data.num_points,
voxel_size=cfg.data.voxel_size,
rotation_magnitude=cfg.data.rotation_magnitude,
Expand Down Expand Up @@ -146,7 +146,7 @@ def run_test():
import torch

from geotransformer.utils.torch import to_cuda
from geotransformer.utils.open3d import make_open3d_point_cloud, open3d_draw
from geotransformer.utils.open3d import make_open3d_point_cloud, draw_geometries
from geotransformer.modules.ops import get_point_to_node_indices, pairwise_distance, apply_transform
from config import make_cfg

Expand All @@ -155,7 +155,7 @@ def visualize(points_f, points_c):
pcd.paint_uniform_color([0, 0, 1])
ncd = make_open3d_point_cloud(points_c.detach().cpu().numpy())
ncd.paint_uniform_color([1, 0, 0])
open3d_draw(pcd, ncd)
draw_geometries(pcd, ncd)

cfg = make_cfg()
train_loader, val_loader, neighbor_limits = train_valid_data_loader(cfg, False)
Expand All @@ -169,14 +169,14 @@ def visualize(points_f, points_c):
pbar = tqdm(enumerate(val_loader), total=len(val_loader))
for i, data_dict in pbar:
data_dict = to_cuda(data_dict)
ref_length_c = data_dict['lengths'][-1][0].item()
src_length_c = data_dict['lengths'][-1][1].item()
ref_length_f = data_dict['lengths'][0][0].item()
src_length_f = data_dict['lengths'][0][1].item()
transform = data_dict['transform']

points_c = data_dict['points'][-1].detach()
points_f = data_dict['points'][0].detach()
ref_length_c = data_dict["lengths"][-1][0].item()
src_length_c = data_dict["lengths"][-1][1].item()
ref_length_f = data_dict["lengths"][0][0].item()
src_length_f = data_dict["lengths"][0][1].item()
transform = data_dict["transform"]

points_c = data_dict["points"][-1].detach()
points_f = data_dict["points"][0].detach()
ref_points_c = points_c[:ref_length_c]
src_points_c = points_c[ref_length_c:]
ref_points_f = points_f[:ref_length_f]
Expand Down Expand Up @@ -212,27 +212,27 @@ def visualize(points_f, points_c):
all_node_counts.append(src_node_sizes.shape[0])

print(
'matching_counts, mean: {:.3f}, min: {}, max: {}'.format(
"matching_counts, mean: {:.3f}, min: {}, max: {}".format(
np.mean(all_matching_counts), np.min(all_matching_counts), np.max(all_matching_counts)
)
)
print(
'lengths_c, mean: {:.3f}, min: {}, max: {}'.format(
"lengths_c, mean: {:.3f}, min: {}, max: {}".format(
np.mean(all_lengths_c), np.min(all_lengths_c), np.max(all_lengths_c)
)
)
print(
'lengths_f, mean: {:.3f}, min: {}, max: {}'.format(
"lengths_f, mean: {:.3f}, min: {}, max: {}".format(
np.mean(all_lengths_f), np.min(all_lengths_f), np.max(all_lengths_f)
)
)
print(
'node_counts, mean: {:.3f}, min: {}, max: {}'.format(
"node_counts, mean: {:.3f}, min: {}, max: {}".format(
np.mean(all_node_counts), np.min(all_node_counts), np.max(all_node_counts)
)
)
print(
'node_sizes, mean: {:.3f}, min: {}, max: {}'.format(
"node_sizes, mean: {:.3f}, min: {}, max: {}".format(
np.mean(all_node_sizes), np.min(all_node_sizes), np.max(all_node_sizes)
)
)
Expand All @@ -243,5 +243,5 @@ def visualize(points_f, points_c):
print(np.percentile(all_node_sizes, 99))


if __name__ == '__main__':
if __name__ == "__main__":
run_test()
27 changes: 12 additions & 15 deletions geotransformer/utils/open3d.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import numpy as np
import open3d as o3d
import matplotlib.colors as colors


def get_color(color_name):
if color_name == 'red':
return np.asarray([1.0, 0.0, 0.0])
elif color_name == 'blue':
return np.asarray([0.0, 0.0, 1.0])
elif color_name == 'green':
return np.asarray([0.0, 1.0, 0.0])
elif color_name == 'yellow':
return np.asarray([0.0, 1.0, 1.0])
else:
raise RuntimeError(f'Unsupported color: {color_name}.')
if color_name == "custom_yellow":
return np.asarray([255.0, 204.0, 102.0]) / 255.0
if color_name == "custom_blue":
return np.asarray([102.0, 153.0, 255.0]) / 255.0
assert color_name in colors.CSS4_COLORS
return np.asarray(colors.to_rgb(colors.CSS4_COLORS[color_name]))


def make_scaling_along_axis(points, axis=2, alpha=0):
Expand Down Expand Up @@ -92,7 +89,7 @@ def make_open3d_axis(axis_vector=None, origin=None, scale=1.0):
axes = o3d.geometry.LineSet()
axes.points = o3d.utility.Vector3dVector(points)
axes.lines = o3d.utility.Vector2iVector(line)
axes.paint_uniform_color(get_color('red'))
axes.paint_uniform_color(get_color("red"))
return axes


Expand Down Expand Up @@ -120,16 +117,16 @@ def make_open3d_corr_lines(ref_corr_points, src_corr_points, label):
corr_lines = o3d.geometry.LineSet()
corr_lines.points = o3d.utility.Vector3dVector(corr_points)
corr_lines.lines = o3d.utility.Vector2iVector(corr_indices)
if label == 'pos':
if label == "pos":
corr_lines.paint_uniform_color(np.asarray([0.0, 1.0, 0.0]))
elif label == 'neg':
elif label == "neg":
corr_lines.paint_uniform_color(np.asarray([1.0, 0.0, 0.0]))
else:
raise ValueError('Unsupported `label` {} for correspondences'.format(label))
raise ValueError("Unsupported `label` {} for correspondences".format(label))
return corr_lines


def open3d_draw(*geometries):
def draw_geometries(*geometries):
o3d.visualization.draw_geometries(geometries)


Expand Down

0 comments on commit 2232c06

Please sign in to comment.