Commit 73e1439: reorganize code; support Kitti & ModelNet

qinzheng93 committed Apr 2, 2022
1 parent d14834c commit 73e1439
Showing 311 changed files with 10,989 additions and 6,783 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
*.pth.tar filter=lfs diff=lfs merge=lfs -text
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
.idea
.vscode
141 changes: 126 additions & 15 deletions README.md
@@ -14,6 +14,10 @@ We study the problem of extracting accurate correspondences for point cloud regi

## News

2022.03.30: Code and pretrained models on KITTI and ModelNet40 are released.

2022.03.29: This work is selected for an **ORAL** presentation at CVPR 2022.

2022.03.02: This work is accepted by CVPR 2022. Code and models on ModelNet40 and KITTI will be released soon.

2022.02.15: Paper is available at [arXiv](https://arxiv.org/abs/2202.06688).
@@ -33,20 +37,17 @@ conda activate geotransformer
pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html

# Install packages and other dependencies
pip install -r requirements.txt
python setup.py build develop

# Compile C++ wrappers
cd geotransformer/cpp_wrappers
sh ./compile_wrappers.sh
```

Code has been tested with Ubuntu 20.04, GCC 9.3.0, Python 3.8, PyTorch 1.7.1, CUDA 11.1 and cuDNN 8.1.0.
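For a quick sanity check after installation, the following should run without errors (a minimal sketch; it assumes only the `geotransformer` package installed above and a working CUDA build of PyTorch):

```python
# Sanity check: the package should import and PyTorch should see the GPU.
import torch
import geotransformer

print(torch.__version__, torch.cuda.is_available())
```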

## Data preparation
## 3DMatch

We provide code for training and testing on 3DMatch.
### Data preparation

The dataset can be download from [PREDATOR](https://github.com/overlappredator/OverlapPredator). The data should be organized as follows:
The dataset can be downloaded from [PREDATOR](https://github.com/prs-eth/OverlapPredator). The data should be organized as follows:

```text
--data--3DMatch--metadata
Expand All @@ -57,14 +58,12 @@ The dataset can be download from [PREDATOR](https://github.com/overlappredator/O
| |--...
|--...
```

## Training

The code for GeoTransformer is in `experiments/geotransformer.3dmatch`. Use the following command for training.
The code for 3DMatch is in `experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn`. Use the following command for training.

```bash
CUDA_VISIBLE_DEVICES=0 python trainval.py
# use "--snapshot=path/to/snapshot" to resume training.
```

## Testing
@@ -84,18 +83,121 @@ We also provide pretrained weights in `weights`, use the following command to te

```bash
CUDA_VISIBLE_DEVICES=0 python test.py --snapshot=../../weights/geotransformer-3dmatch.pth.tar --benchmark=3DMatch
CUDA_VISIBLE_DEVICES=0 python eval.py --run_matching --run_registration --benchmark=3DMatch
CUDA_VISIBLE_DEVICES=0 python eval.py --benchmark=3DMatch --method=lgr
```

Replace `3DMatch` with `3DLoMatch` to evaluate on 3DLoMatch.
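Here `--method=lgr` selects the local-to-global registration (LGR) estimator proposed in the paper to compute the final transformation.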

## Kitti odometry

### Data preparation

Download the data from the [Kitti official website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) into `data/Kitti` and run `data/Kitti/downsample_pcd.py` to generate the downsampled point clouds. The data should be organized as follows:

```text
--data--Kitti--metadata
            |--sequences--00--velodyne--000000.bin
            |           |            |--...
            |           |--...
            |--downsampled--00--000000.npy
                         |    |--...
                         |--...
```
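Each file under `downsampled` is written by `data/Kitti/downsample_pcd.py` as a float32 xyz array, so a frame can be inspected directly (a minimal sketch, assuming the layout above):

```python
import numpy as np

# One voxel-downsampled LiDAR frame: an (N, 3) float32 array of xyz coordinates.
points = np.load('data/Kitti/downsampled/00/000000.npy')
print(points.shape, points.dtype)
```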

### Training

The code for Kitti is in `experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn`. Use the following command for training.

```bash
CUDA_VISIBLE_DEVICES=0 python trainval.py
```

### Testing

Use the following command for testing.

```bash
CUDA_VISIBLE_DEVICES=0 ./eval.sh EPOCH
```

`EPOCH` is the epoch id.

We also provide pretrained weights in `weights`; use the following command to test them.

```bash
CUDA_VISIBLE_DEVICES=0 python test.py --snapshot=../../weights/geotransformer-kitti.pth.tar
CUDA_VISIBLE_DEVICES=0 python eval.py --method=lgr
```

## ModelNet

### Data preparation

Download the [data](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip) and run `data/ModelNet/split_data.py` to generate the train/val/test splits. The data should be organized as follows:

```text
--data--ModelNet--modelnet_ply_hdf5_2048--...
               |--train.pkl
               |--val.pkl
               |--test.pkl
```
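Each pickle written by `split_data.py` is a list of dicts with `points`, `normals`, and `label` entries (2048 points per model), so the splits can be inspected directly (a minimal sketch, assuming the layout above):

```python
import pickle

# Each entry: points (2048, 3), normals (2048, 3), label (scalar class id).
with open('data/ModelNet/train.pkl', 'rb') as f:
    train_data = pickle.load(f)
print(len(train_data), train_data[0]['points'].shape, train_data[0]['label'])
```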

### Training

The code for ModelNet is in `experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn`. Use the following command for training.

```bash
CUDA_VISIBLE_DEVICES=0 python trainval.py
```

### Testing

Use the following command for testing.

```bash
CUDA_VISIBLE_DEVICES=0 python test.py --test_iter=ITER
```

`ITER` is the iteration id.

We also provide pretrained weights in `weights`; use the following command to test them.

```bash
CUDA_VISIBLE_DEVICES=0 python test.py --snapshot=../../weights/geotransformer-modelnet.pth.tar
```

## Results

| Benchmark | FMR | IR | RR |
| --------- | --- | -- | -- |
| 3DMatch | 97.7 | 70.3 | 91.5 |
| 3DLoMatch | 88.1 | 43.3 | 74.0 |
### 3DMatch

We evaluate GeoTransformer on the standard 3DMatch/3DLoMatch benchmarks as in [PREDATOR](https://arxiv.org/abs/2011.13005).

| Benchmark | FMR (%) | IR (%) | RR (%) |
|:----------|:-------:|:------:|:------:|
| 3DMatch   | 98.2    | 70.9   | 92.5   |
| 3DLoMatch | 87.1    | 43.5   | 74.2   |

### Kitti odometry

We evaluate GeoTransformer on the standard Kitti benchmark as in [PREDATOR](https://arxiv.org/abs/2011.13005).

| Benchmark | RRE (°) | RTE (cm) | RR (%) |
|:----------|:-------:|:--------:|:------:|
| Kitti     | 0.230   | 6.2      | 99.8   |

### ModelNet

We evaluate GeoTransformer on ModelNet with two settings:

1. Standard setting: [0°, 45°] rotation, [-0.5, 0.5] translation, Gaussian noise clipped to 0.05.
2. Full-range setting: [0°, 180°] rotation, [-0.5, 0.5] translation, Gaussian noise clipped to 0.05.

We remove symmetric classes and use the data augmentation from [RPMNet](https://arxiv.org/abs/2003.13479), which is more difficult than that in [PRNet](https://arxiv.org/abs/1910.12240).
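For reference, a minimal sketch of this kind of augmentation (illustrative only, not the repository's exact implementation; the noise standard deviation of 0.01 is an assumption, since only the 0.05 clip is specified above, and `max_angle` is 45° or 180° depending on the setting):

```python
import numpy as np
from scipy.spatial.transform import Rotation


def augment(points, max_angle=45.0, max_trans=0.5, noise_sigma=0.01, noise_clip=0.05):
    """Apply a random rigid transform plus clipped Gaussian noise to (N, 3) points."""
    # Random rotation about a random axis, with angle in [0, max_angle] degrees.
    axis = np.random.randn(3)
    axis /= np.linalg.norm(axis)
    angle = np.random.uniform(0.0, np.radians(max_angle))
    rotation = Rotation.from_rotvec(angle * axis).as_matrix()
    # Random translation with each component in [-max_trans, max_trans].
    translation = np.random.uniform(-max_trans, max_trans, size=(3,))
    points = points @ rotation.T + translation
    # Gaussian noise clipped to [-noise_clip, noise_clip].
    noise = np.clip(np.random.normal(0.0, noise_sigma, points.shape), -noise_clip, noise_clip)
    return points + noise
```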

| Setting            | RRE (°) | RTE   | RMSE  |
|:-------------------|:-------:|:-----:|:-----:|
| seen (45$^\circ$)  | 1.577   | 0.018 | 0.017 |
| seen (180$^\circ$) | 6.830   | 0.044 | 0.042 |

## Citation

@@ -109,3 +211,12 @@
primaryClass={cs.CV}
}
```

## Acknowledgements

- [D3Feat](https://github.com/XuyangBai/D3Feat.pytorch)
- [PREDATOR](https://github.com/prs-eth/OverlapPredator)
- [RPMNet](https://github.com/yewzijian/RPMNet)
- [CoFiNet](https://github.com/haoyu94/Coarse-to-fine-correspondences)
- [huggingface-transformers](https://github.com/huggingface/transformers)
- [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork)
26 changes: 26 additions & 0 deletions data/Kitti/downsample_pcd.py
@@ -0,0 +1,26 @@
import os
import os.path as osp
import glob

import numpy as np
import open3d as o3d
from tqdm import tqdm


def main():
    # Sequences 00-10 are the ones with ground-truth poses in the Kitti odometry benchmark.
    for i in range(11):
        seq_id = '{:02d}'.format(i)
        # Make sure the output directory exists before saving.
        os.makedirs(osp.join('downsampled', seq_id), exist_ok=True)
        file_names = glob.glob(osp.join('sequences', seq_id, 'velodyne', '*.bin'))
        for file_name in tqdm(file_names):
            frame = osp.splitext(osp.basename(file_name))[0]
            new_file_name = osp.join('downsampled', seq_id, frame + '.npy')
            # Each .bin frame stores float32 (x, y, z, reflectance) tuples; keep only xyz.
            points = np.fromfile(file_name, dtype=np.float32).reshape(-1, 4)
            points = points[:, :3]
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(points)
            # Voxel-downsample with a 0.3 m voxel size to reduce point density.
            pcd = pcd.voxel_down_sample(0.3)
            points = np.array(pcd.points).astype(np.float32)
            np.save(new_file_name, points)


if __name__ == '__main__':
    main()
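Note: the script resolves `sequences` and `downsampled` relative to the working directory, so run it from inside `data/Kitti`.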
Binary file added data/Kitti/metadata/test.pkl
Binary file added data/Kitti/metadata/train.pkl
Binary file added data/Kitti/metadata/val.pkl
47 changes: 47 additions & 0 deletions data/ModelNet/split_data.py
@@ -0,0 +1,47 @@
import pickle

import h5py
import numpy as np


def dump_pickle(data, filename):
    with open(filename, 'wb') as f:
        pickle.dump(data, f)


def process(subset):
    with open(f'modelnet40_ply_hdf5_2048/{subset}_files.txt') as f:
        lines = f.readlines()
    all_points = []
    all_normals = []
    all_labels = []
    for line in lines:
        filename = line.strip()
        with h5py.File(f'modelnet40_ply_hdf5_2048/{filename}', 'r') as h5file:
            all_points.append(h5file['data'][:])
            all_normals.append(h5file['normal'][:])
            # np.int is deprecated in NumPy >= 1.20; use a fixed-width dtype instead.
            all_labels.append(h5file['label'][:].flatten().astype(np.int64))
    points = np.concatenate(all_points, axis=0)
    normals = np.concatenate(all_normals, axis=0)
    labels = np.concatenate(all_labels, axis=0)
    print(f'{subset} data loaded.')
    all_data = []
    num_data = points.shape[0]
    for i in range(num_data):
        all_data.append(dict(points=points[i], normals=normals[i], label=labels[i]))
    if subset == 'train':
        # Randomly split the official training set 80/20 into train and val.
        indices = np.random.permutation(num_data)
        num_train = int(num_data * 0.8)
        train_indices = indices[:num_train]
        val_indices = indices[num_train:]
        train_data = [all_data[i] for i in train_indices.tolist()]
        dump_pickle(train_data, 'train.pkl')
        val_data = [all_data[i] for i in val_indices.tolist()]
        dump_pickle(val_data, 'val.pkl')
    else:
        dump_pickle(all_data, 'test.pkl')


if __name__ == '__main__':
    for subset in ['train', 'test']:
        process(subset)
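As with the Kitti script, the paths are relative, so run it from inside `data/ModelNet` after extracting the zip. Note that the train/val split comes from an unseeded `np.random.permutation`, so regenerating the pickles yields a different split each time.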
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn

from geotransformer.modules.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample


class KPConvFPN(nn.Module):
    def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
        super(KPConvFPN, self).__init__()

        # Encoder: four stages; each strided block moves to a coarser point set,
        # and the convolution radius/sigma double at every stage.
        self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm)
        self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm)

        self.encoder2_1 = ResidualBlock(
            init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True
        )
        self.encoder2_2 = ResidualBlock(
            init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
        )
        self.encoder2_3 = ResidualBlock(
            init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
        )

        self.encoder3_1 = ResidualBlock(
            init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm, strided=True
        )
        self.encoder3_2 = ResidualBlock(
            init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
        )
        self.encoder3_3 = ResidualBlock(
            init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
        )

        self.encoder4_1 = ResidualBlock(
            init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm, strided=True
        )
        self.encoder4_2 = ResidualBlock(
            init_dim * 8, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm
        )
        self.encoder4_3 = ResidualBlock(
            init_dim * 16, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm
        )

        # Decoder: fuse upsampled coarse features with encoder skip connections.
        self.decoder3 = UnaryBlock(init_dim * 24, init_dim * 8, group_norm)
        self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim)

    def forward(self, feats, data_dict):
        feats_list = []

        # Per-stage point coordinates and precomputed neighbor/pooling indices.
        points_list = data_dict['points']
        neighbors_list = data_dict['neighbors']
        subsampling_list = data_dict['subsampling']
        upsampling_list = data_dict['upsampling']

        feats_s1 = feats
        feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0])
        feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0])

        feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0])
        feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1])
        feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1])

        feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1])
        feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2])
        feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2])

        feats_s4 = self.encoder4_1(feats_s3, points_list[3], points_list[2], subsampling_list[2])
        feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3], neighbors_list[3])
        feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3], neighbors_list[3])

        latent_s4 = feats_s4
        feats_list.append(feats_s4)

        latent_s3 = nearest_upsample(latent_s4, upsampling_list[2])
        latent_s3 = torch.cat([latent_s3, feats_s3], dim=1)
        latent_s3 = self.decoder3(latent_s3)
        feats_list.append(latent_s3)

        latent_s2 = nearest_upsample(latent_s3, upsampling_list[1])
        latent_s2 = torch.cat([latent_s2, feats_s2], dim=1)
        latent_s2 = self.decoder2(latent_s2)
        feats_list.append(latent_s2)

        # Reorder from coarse-to-fine to fine-to-coarse: feats_list[0] is the finest (stage-2) level.
        feats_list.reverse()

        return feats_list
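In short, `forward` takes per-stage `points`, `neighbors`, `subsampling`, and `upsampling` index lists from `data_dict` and returns `feats_list` ordered fine to coarse: the decoded stage-2 features first and the raw stage-4 features last.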