[Feature] Support 3D semantic segmentation demo (open-mmlab#532)
* compress kitti unit test imgs

* add unit test for inference_multi_modality_detector

* fix typos

* rename init_detector to init_model

* show_result_meshlab support seg task

* add unit test for seg show_result_meshlab

* support inference_segmentor

* support pc seg demo

* add docs

* minor fix

* change function name

* compress demo data size

* update docs
Wuziyi616 authored May 19, 2021
1 parent b84111d commit 3bac800
Showing 13 changed files with 325 additions and 61 deletions.
Binary file added demo/data/scannet/scene0000_00.bin
7 changes: 4 additions & 3 deletions demo/multi_modality_demo.py
@@ -1,6 +1,6 @@
 from argparse import ArgumentParser

-from mmdet3d.apis import (inference_multi_modality_detector, init_detector,
+from mmdet3d.apis import (inference_multi_modality_detector, init_model,
                           show_result_meshlab)

@@ -26,7 +26,7 @@ def main():
     args = parser.parse_args()

     # build the model from a config file and a checkpoint file
-    model = init_detector(args.config, args.checkpoint, device=args.device)
+    model = init_model(args.config, args.checkpoint, device=args.device)
     # test a single image
     result, data = inference_multi_modality_detector(model, args.pcd,
                                                      args.image, args.ann)
@@ -37,7 +37,8 @@ def main():
         args.out_dir,
         args.score_thr,
         show=args.show,
-        snapshot=args.snapshot)
+        snapshot=args.snapshot,
+        task='multi_modality-det')


 if __name__ == '__main__':
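For context, the substantive change in this file is the new `task` keyword on `show_result_meshlab`; the call sites in this commit pass `'det'`, `'seg'`, or `'multi_modality-det'` so the helper knows which kind of result it is rendering. A minimal sketch of the resulting multi-modality flow, with hypothetical file paths standing in for a real config, checkpoint and KITTI-style sample:

```python
from mmdet3d.apis import (inference_multi_modality_detector, init_model,
                          show_result_meshlab)

# Hypothetical paths; substitute a real config, checkpoint and data files.
model = init_model('configs/mvxnet/example_config.py',
                   'checkpoints/example_checkpoint.pth', device='cuda:0')
result, data = inference_multi_modality_detector(
    model, 'demo/data/kitti/sample.bin', 'demo/data/kitti/sample.png',
    'demo/data/kitti/sample_infos.pkl')
show_result_meshlab(
    data,
    result,
    'demo',  # output directory
    0.3,  # score threshold
    show=False,
    snapshot=False,
    task='multi_modality-det')  # the keyword added in this commit
```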
39 changes: 39 additions & 0 deletions demo/pc_seg_demo.py
@@ -0,0 +1,39 @@
+from argparse import ArgumentParser
+
+from mmdet3d.apis import inference_segmentor, init_model, show_result_meshlab
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('pcd', help='Point cloud file')
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--out-dir', type=str, default='demo', help='dir to save results')
+    parser.add_argument(
+        '--show', action='store_true', help='show online visualization results')
+    parser.add_argument(
+        '--snapshot',
+        action='store_true',
+        help='whether to save online visualization results')
+    args = parser.parse_args()
+
+    # build the model from a config file and a checkpoint file
+    model = init_model(args.config, args.checkpoint, device=args.device)
+    # test a single point cloud file
+    result, data = inference_segmentor(model, args.pcd)
+    # show the results
+    show_result_meshlab(
+        data,
+        result,
+        args.out_dir,
+        show=args.show,
+        snapshot=args.snapshot,
+        task='seg',
+        palette=model.PALETTE)
+
+
+if __name__ == '__main__':
+    main()
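The new script also implies a programmatic path for segmentation. A minimal sketch mirroring it, using the ScanNet sample and the PointNet++ config referenced in the docs change below (the checkpoint is assumed to have been downloaded to `checkpoints/`):

```python
from mmdet3d.apis import inference_segmentor, init_model, show_result_meshlab

config_file = 'configs/pointnet2/pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class.py'
# assumed local copy of the checkpoint named in docs/0_demo.md
checkpoint_file = 'checkpoints/pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class_20210514_143644-ee73704a.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')
result, data = inference_segmentor(model, 'demo/data/scannet/scene0000_00.bin')
# write MeshLab-compatible visualization files under the 'demo' directory
show_result_meshlab(data, result, 'demo', task='seg', palette=model.PALETTE)
```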
7 changes: 4 additions & 3 deletions demo/pcd_demo.py
@@ -1,6 +1,6 @@
 from argparse import ArgumentParser

-from mmdet3d.apis import inference_detector, init_detector, show_result_meshlab
+from mmdet3d.apis import inference_detector, init_model, show_result_meshlab


 def main():
@@ -23,7 +23,7 @@ def main():
     args = parser.parse_args()

     # build the model from a config file and a checkpoint file
-    model = init_detector(args.config, args.checkpoint, device=args.device)
+    model = init_model(args.config, args.checkpoint, device=args.device)
     # test a single image
     result, data = inference_detector(model, args.pcd)
     # show the results
@@ -33,7 +33,8 @@ def main():
         args.out_dir,
         args.score_thr,
         show=args.show,
-        snapshot=args.snapshot)
+        snapshot=args.snapshot,
+        task='det')


 if __name__ == '__main__':
32 changes: 25 additions & 7 deletions docs/0_demo.md
@@ -2,19 +2,21 @@

 ## Introduction

-We provide scipts for multi-modality/single-modality and indoor/outdoor 3D detection demos. The pre-trained models can be downloaded from [model zoo](https://github.com/open-mmlab/mmdetection3d/blob/master/docs/model_zoo.md/). We provide pre-processed sample data from KITTI and SUN RGB-D dataset. You can use any other data following our pre-processing steps.
+We provide scripts for multi-modality/single-modality, indoor/outdoor 3D detection and 3D semantic segmentation demos. The pre-trained models can be downloaded from the [model zoo](https://github.com/open-mmlab/mmdetection3d/blob/master/docs/model_zoo.md/). We provide pre-processed sample data from the KITTI, SUN RGB-D and ScanNet datasets. You can use any other data following our pre-processing steps.

 ## Testing

-### Single-modality demo
+### 3D Detection
+
+#### Single-modality demo

 To test a 3D detector on point cloud data, simply run:

 ```shell
-python demo/pcd_demo.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}]
+python demo/pcd_demo.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}] [--show]
 ```

-The visualization results including a point cloud and predicted 3D bounding boxes will be saved in ```demo/PCD_NAME```, which you can open using [MeshLab](http://www.meshlab.net/).
+The visualization results, including the point cloud and the predicted 3D bounding boxes, will be saved in `${OUT_DIR}/PCD_NAME`, which you can open using [MeshLab](http://www.meshlab.net/). Note that if you set the flag `--show`, the prediction results will be displayed online using [Open3D](http://www.open3d.org/).

 Example on KITTI data using [SECOND](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/second) model:

@@ -30,15 +32,15 @@ python demo/pcd_demo.py demo/data/sunrgbd/sunrgbd_000017.bin configs/votenet/vot

 Remember to convert the VoteNet checkpoint if you are using mmdetection3d version >= 0.6.0. See its [README](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/votenet/README.md/) for detailed instructions on how to convert the checkpoint.

-### Multi-modality demo
+#### Multi-modality demo

 To test a 3D detector on multi-modality data (typically point cloud and image), simply run:

 ```shell
-python demo/multi_modality_demo.py ${PCD_FILE} ${IMAGE_FILE} ${ANNOTATION_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}]
+python demo/multi_modality_demo.py ${PCD_FILE} ${IMAGE_FILE} ${ANNOTATION_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}] [--show]
 ```

-where the ```ANNOTATION_FILE``` should provide the 3D to 2D projection matrix. The visualization results including a point cloud, an image, predicted 3D bounding boxes and their projection on the image will be saved in ```demo/PCD_NAME```.
+where the `ANNOTATION_FILE` should provide the 3D to 2D projection matrix. The visualization results including a point cloud, an image, predicted 3D bounding boxes and their projection on the image will be saved in `${OUT_DIR}/PCD_NAME`.

 Example on KITTI data using [MVX-Net](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/mvxnet) model:

@@ -51,3 +53,19 @@ Example on SUN RGB-D data using [ImVoteNet](https://github.com/open-mmlab/mmdete
 ```shell
 python demo/multi_modality_demo.py demo/data/sunrgbd/sunrgbd_000017.bin demo/data/sunrgbd/sunrgbd_000017.jpg demo/data/sunrgbd/sunrgbd_000017_infos.pkl configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py checkpoints/imvotenet_stage2_16x8_sunrgbd-3d-10class_20210323_184021-d44dcb66.pth
 ```
+
+### 3D Segmentation
+
+To test a 3D segmentor on point cloud data, simply run:
+
+```shell
+python demo/pc_seg_demo.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--out-dir ${OUT_DIR}] [--show]
+```
+
+The visualization results including a point cloud and its predicted 3D segmentation mask will be saved in `${OUT_DIR}/PCD_NAME`.
+
+Example on ScanNet data using [PointNet++ (SSG)](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/pointnet2) model:
+
+```shell
+python demo/pc_seg_demo.py demo/data/scannet/scene0000_00.bin configs/pointnet2/pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class.py checkpoints/pointnet2_ssg_16x2_cosine_200e_scannet_seg-3d-20class_20210514_143644-ee73704a.pth
+```
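As background for the `ANNOTATION_FILE` requirement above: a 3D-to-2D projection matrix maps homogeneous 3D points to pixel coordinates. A self-contained numpy illustration with made-up matrix values (not the repository's exact pipeline):

```python
import numpy as np

# Hypothetical 3x4 camera projection matrix of the kind an annotation file
# supplies (intrinsics composed with extrinsics); values are illustrative.
P = np.array([[700.0, 0.0, 600.0, 0.0],
              [0.0, 700.0, 180.0, 0.0],
              [0.0, 0.0, 1.0, 0.0]])

# Two 3D points in camera coordinates (x right, y down, z forward).
pts_3d = np.array([[2.0, 1.0, 10.0],
                   [-3.0, 0.5, 15.0]])

# Append a homogeneous 1, project, then divide by depth to get pixels.
pts_hom = np.hstack([pts_3d, np.ones((len(pts_3d), 1))])  # (N, 4)
uv = (P @ pts_hom.T).T                                    # (N, 3)
uv = uv[:, :2] / uv[:, 2:3]                               # (N, 2) pixel coords
print(uv)  # [[740., 250.], [460., 203.33...]]
```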
2 changes: 1 addition & 1 deletion docs/1_exist_data_model.md
@@ -141,7 +141,7 @@ According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you ne
 python tools/train.py ${CONFIG_FILE} [optional arguments]
 ```

-If you want to specify the working directory in the command, you can add an argument `--work_dir ${YOUR_WORK_DIR}`.
+If you want to specify the working directory in the command, you can add an argument `--work-dir ${YOUR_WORK_DIR}`.

 ### Train with multiple GPUs
4 changes: 2 additions & 2 deletions docs/getting_started.md
@@ -279,13 +279,13 @@ More demos about single/multi-modality and indoor/outdoor 3D detection can be fo
 Here is an example of building the model and test given point clouds.

 ```python
-from mmdet3d.apis import init_detector, inference_detector
+from mmdet3d.apis import init_model, inference_detector

 config_file = 'configs/votenet/votenet_8x8_scannet-3d-18class.py'
 checkpoint_file = 'checkpoints/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth'

 # build the model from a config file and a checkpoint file
-model = init_detector(config_file, checkpoint_file, device='cuda:0')
+model = init_model(config_file, checkpoint_file, device='cuda:0')

 # test a single image and show the results
 point_cloud = 'test.bin'
8 changes: 4 additions & 4 deletions mmdet3d/apis/__init__.py
@@ -1,11 +1,11 @@
 from .inference import (convert_SyncBN, inference_detector,
-                        inference_multi_modality_detector, init_detector,
-                        show_result_meshlab)
+                        inference_multi_modality_detector, inference_segmentor,
+                        init_model, show_result_meshlab)
 from .test import single_gpu_test
 from .train import train_model

 __all__ = [
-    'inference_detector', 'init_detector', 'single_gpu_test',
+    'inference_detector', 'init_model', 'single_gpu_test',
     'show_result_meshlab', 'convert_SyncBN', 'train_model',
-    'inference_multi_modality_detector'
+    'inference_multi_modality_detector', 'inference_segmentor'
 ]
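After this change the package exports `init_model` in place of `init_detector`; the call pattern (config, checkpoint, device) is unchanged, so migrating is a one-line import edit. A minimal before/after sketch, reusing the paths from the getting_started example above (the checkpoint must exist locally, and `test.bin` is a placeholder):

```python
from mmdet3d.apis import inference_detector, init_model

# before: model = init_detector(config_file, checkpoint_file, device='cuda:0')
config_file = 'configs/votenet/votenet_8x8_scannet-3d-18class.py'
checkpoint_file = 'checkpoints/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth'
model = init_model(config_file, checkpoint_file, device='cuda:0')

result, data = inference_detector(model, 'test.bin')  # placeholder point cloud
```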