f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation
This repository provides the official PyTorch implementation of the following paper, along with code for training and testing state-of-the-art models for interactive segmentation:
f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation
Konstantin Sofiiuk, Ilia Petrov, Olga Barinova, Anton Konushin
Samsung AI Center Moscow
https://arxiv.org/abs/2001.10331
Please see the video below explaining how our algorithm works:
We also provide a full MXNet implementation of our algorithm; see the mxnet branch.
- [2021-02-16] We have presented a new paper (+code) on interactive segmentation: Reviving Iterative Training with Mask Guidance for Interactive Segmentation. It is a simpler approach that achieves new SoTA results without any test-time optimization techniques.
This framework is built using Python 3.6 and relies on PyTorch 1.4.0+. The following command installs all necessary packages:
```shell
pip3 install -r requirements.txt
```
You can also use our Dockerfile to build a container with a preconfigured environment.
If you want to run training or testing, you must configure the paths to the datasets in config.yml (SBD for training and testing; GrabCut, Berkeley, DAVIS and COCO_MVal for testing only).
The GUI is based on the Tk library and its Python bindings (Tkinter). You can try an interactive demo with any of the provided models (see the section below). Our scripts automatically detect the architecture of the loaded model; just specify the path to the corresponding checkpoint.
Examples of the script usage:
```shell
# This command runs the interactive demo with the ResNet-34 model from cfg.INTERACTIVE_MODELS_PATH on the GPU with id=0
# --checkpoint can be a path relative to cfg.INTERACTIVE_MODELS_PATH or an absolute path to the checkpoint
python3 demo.py --checkpoint=resnet34_dh128_sbd --gpu=0
# This command runs the interactive demo with the ResNet-34 model from /home/demo/fBRS/weights/
# If you do not have much GPU memory, you can reduce --limit-longest-size (default=800)
python3 demo.py --checkpoint=/home/demo/fBRS/weights/resnet34_dh128_sbd --limit-longest-size=400
# You can also run the demo in CPU-only mode
python3 demo.py --checkpoint=resnet34_dh128_sbd --cpu
```
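The sketch below shows how a longest-side limit such as --limit-longest-size is commonly applied. The image is downscaled so that its longer side does not exceed the limit, and the aspect ratio is kept. This describes an assumed behavior, not code taken from demo.py, and the helper name limited_size is hypothetical.

```python
from typing import Tuple


def limited_size(height: int, width: int, limit_longest_size: int = 800) -> Tuple[int, int]:
    """Scale (height, width) so the longer side is at most limit_longest_size.

    Assumption: the image is only downscaled (never upscaled) and the aspect
    ratio is preserved.
    """
    longest = max(height, width)
    if longest <= limit_longest_size:
        return height, width
    scale = limit_longest_size / longest
    return round(height * scale), round(width * scale)


print(limited_size(1200, 1600))       # (600, 800) with the default limit
print(limited_size(1200, 1600, 400))  # (300, 400), as with --limit-longest-size=400
print(limited_size(300, 500))         # (300, 500), already within the limit
```

Under this assumption, lowering the limit from 800 to 400 on a 1600×1200 image cuts the processed pixel count from 480,000 to 120,000. That is roughly why a smaller value reduces GPU memory usage.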
You can also run the demo using the Docker image. To do this, allow connections to your X server and then run the container with a few additional flags:
```shell
# activate xhost
xhost +
docker run -v "$PWD":/tmp/ \
-v /tmp/.X11-unix:/tmp/.X11-unix \
-e DISPLAY=$DISPLAY <id-or-tag-docker-built-image> \
python3 demo.py --checkpoint resnet34_dh128_sbd --cpu
```
Controls:
- press left and right mouse buttons for positive and negative clicks, respectively;
- scroll wheel to zoom in and out;
- hold right mouse button and drag to move around an image (you can also use arrows and WASD);
- press space to finish the current object;
- when multiple files are open, pressing the left arrow key displays the previous image, and pressing the right arrow key displays the next image;
- use Ctrl+S to save the annotation you're currently editing ("original file name".png).
The GUI also exposes the following interactive segmentation options:
- ZoomIn (can be turned on/off using the checkbox)
    - Skip clicks - the number of clicks to skip before using ZoomIn.
    - Target size - the ZoomIn crop is resized so that its longer side matches this value (increase it for large objects).
    - Expand ratio - the object bbox is rescaled with this ratio before cropping.
- BRS parameters (the BRS type can be changed using the dropdown menu)
    - Network clicks - the number of initial clicks that are included in the network's input. Subsequent clicks are processed using BRS only (NoBRS ignores this option).
    - L-BFGS-B max iterations - the maximum number of function evaluations for each optimization step in BRS (increase it for better accuracy at the cost of longer computation time per click).
- Visualisation parameters
    - Prediction threshold slider adjusts the threshold for binarizing the probability map of the current object.
    - Alpha blending coefficient slider adjusts the intensity of all predicted masks.
    - Visualisation click radius slider adjusts the size of the red and green dots depicting clicks.
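The ZoomIn options can be read as a small piece of crop geometry. The sketch below works under stated assumptions. The bbox of the current object is expanded around its center by Expand ratio and clipped to the image. The crop is then resized so that its longer side equals Target size. ZoomIn is not applied until more than Skip clicks clicks have been made. The function names, the (x_min, y_min, x_max, y_max) box convention and the clipping step are all hypothetical, and the repository's implementation may differ.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max) in pixels (assumed convention)


def use_zoom_in(num_clicks: int, skip_clicks: int) -> bool:
    """Assumed rule: ZoomIn is applied only once more than skip_clicks clicks were made."""
    return num_clicks > skip_clicks


def zoom_in_crop(bbox: Box, image_size: Tuple[int, int],
                 expand_ratio: float, target_size: int) -> Tuple[Box, float]:
    """Expand bbox around its center, clip it to the image (width, height) and
    return the crop plus the factor that maps its longer side to target_size."""
    x_min, y_min, x_max, y_max = bbox
    img_w, img_h = image_size
    cx, cy = (x_min + x_max) / 2, (y_min + y_max) / 2
    half_w = (x_max - x_min) * expand_ratio / 2
    half_h = (y_max - y_min) * expand_ratio / 2
    crop = (max(0.0, cx - half_w), max(0.0, cy - half_h),
            min(float(img_w), cx + half_w), min(float(img_h), cy + half_h))
    longer_side = max(crop[2] - crop[0], crop[3] - crop[1])
    if longer_side <= 0:
        raise ValueError("empty crop")
    return crop, target_size / longer_side


# A 100x50 object in a 640x480 image, expanded 2x, resized to a 400 px longer side
crop, scale = zoom_in_crop((100, 100, 200, 150), (640, 480), expand_ratio=2.0, target_size=400)
print(crop, scale)  # (50.0, 75.0, 250.0, 175.0) 2.0
```

With a fixed Target size, a large object gives a scale factor below 1, which means its crop is downsampled. This is consistent with the advice to increase Target size for large objects.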
We train all our models on the SBD dataset and evaluate them on the GrabCut, Berkeley, DAVIS, SBD and COCO_MVal datasets. We additionally provide the results of models trained on a combination of the COCO and LVIS datasets.
The Berkeley dataset consists of 100 instances (96 unique images) provided by [K. McGuinness, 2010]. For DAVIS, we test on the same 345 images as [WD Jang, 2019]; the ground-truth mask for each image is the union of all objects' masks. For testing on the SBD dataset, we evaluate our algorithm on every instance in the test set separately, following the protocol of [WD Jang, 2019].
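The DAVIS and SBD protocols differ in what counts as one test sample. In this hypothetical sketch, a DAVIS image yields a single target formed by the union of its object masks, while an SBD image yields one target per instance. Masks are nested lists of 0/1 values, and the function names are invented for illustration.

```python
from typing import List

Mask = List[List[int]]  # binary mask stored as rows of 0/1 values


def union_mask(instance_masks: List[Mask]) -> Mask:
    """Merge all object masks of one image into a single binary mask (DAVIS-style)."""
    if not instance_masks:
        raise ValueError("need at least one instance mask")
    height, width = len(instance_masks[0]), len(instance_masks[0][0])
    merged = [[0] * width for _ in range(height)]
    for mask in instance_masks:
        for y in range(height):
            for x in range(width):
                if mask[y][x]:
                    merged[y][x] = 1
    return merged


def evaluation_targets(instance_masks: List[Mask], protocol: str) -> List[Mask]:
    """Return the ground-truth masks that are evaluated as separate samples."""
    if protocol == "DAVIS":
        return [union_mask(instance_masks)]  # one target per image
    if protocol == "SBD":
        return list(instance_masks)          # one target per instance
    raise ValueError(f"unknown protocol: {protocol}")


person = [[1, 1, 0], [1, 1, 0]]
dog = [[0, 0, 0], [0, 1, 1]]
print(evaluation_targets([person, dog], "DAVIS"))     # [[[1, 1, 0], [1, 1, 1]]]
print(len(evaluation_targets([person, dog], "SBD")))  # 2
```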
To construct the COCO_MVal dataset, we sample 800 object instances from the validation set of COCO 2017. Specifically, we sample 10 unique instances from each of the 80 categories. The only exception is the toaster class, which has only 9 unique instances in the instances_val2017 annotations, so to get 800 masks, one of the classes contains 11 objects. We provide this dataset for download so that everyone can reproduce our results.
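The COCO_MVal count works out as 79 × 10 + 9 = 799 masks, one short of 800. That shortfall is why a single category ends up with 11 objects. The sketch below shows one possible sampling procedure, assuming a mapping from category names to instance ids. The seed, the function name and the rule for choosing which category receives the extra instance are all hypothetical, because that choice is not specified here.

```python
import random
from typing import Dict, List


def sample_mval(instances_by_category: Dict[str, List[int]], per_category: int = 10,
                total: int = 800, seed: int = 0) -> Dict[str, List[int]]:
    """Sample per_category instances per class, then top up to `total` masks
    from classes that still have unused instances (hypothetical rule)."""
    rng = random.Random(seed)
    chosen = {}
    for category in sorted(instances_by_category):
        pool = instances_by_category[category]
        chosen[category] = rng.sample(pool, min(per_category, len(pool)))

    # Classes with too few instances (e.g. toaster) leave a shortfall to fill
    shortfall = total - sum(len(ids) for ids in chosen.values())
    for category in sorted(chosen):
        if shortfall <= 0:
            break
        taken = set(chosen[category])
        remaining = [i for i in instances_by_category[category] if i not in taken]
        extra = rng.sample(remaining, min(shortfall, len(remaining)))
        chosen[category].extend(extra)
        shortfall -= len(extra)
    return chosen


# 79 regular categories plus "toaster" with only 9 instances (toy ids)
data = {f"cat{i:02d}": list(range(i * 100, i * 100 + 20)) for i in range(79)}
data["toaster"] = list(range(9000, 9009))
subset = sample_mval(data)
print(sum(len(ids) for ids in subset.values()))         # 800
print(sorted(len(ids) for ids in subset.values())[-1])  # 11
```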
Dataset | Description | Download Link |
---|---|---|
SBD | 8498 images with 20172 instances for training and 2857 images with 6671 instances for testing | official site |
GrabCut | 50 images with one object each | GrabCut.zip (11 MB) |
Berkeley | 96 images with 100 instances | Berkeley.zip (7 MB) |
DAVIS | 345 images with one object each | DAVIS.zip (43 MB) |
COCO_MVal | 800 images with 800 instances | COCO_MVal.zip (127 MB) |
Don't forget to change the paths to the datasets in config.yml after downloading and unpacking.
We provide pretrained models with different backbones for interactive segmentation. The evaluation results differ from those presented in our paper because we retrained all models on the new codebase in this repository. We greatly accelerated the inference of the RGB-BRS algorithm: on the SBD dataset it now runs 2.5 to 4 times faster than the timings given in the paper. Nevertheless, the new results are sometimes even better.
Note that all ResNet models were trained using the MXNet branch and then converted to PyTorch (they have equivalent results). We provide the script that was used to convert the models. HRNet models were trained using PyTorch.
You can find model weights and test results in the tables below:
Backbone | Train Dataset | Link |
---|---|---|
ResNet-34 | SBD | resnet34_dh128_sbd.pth (GitHub, 89 MB) |
ResNet-50 | SBD | resnet50_dh128_sbd.pth (GitHub, 120 MB) |
ResNet-101 | SBD | resnet101_dh256_sbd.pth (GitHub, 223 MB) |
HRNetV2-W18+OCR | SBD | hrnet18_ocr64_sbd.pth (GitHub, 39 MB) |
HRNetV2-W32+OCR | SBD | hrnet32_ocr128_sbd.pth (GitHub, 119 MB) |
ResNet-50 | COCO+LVIS | resnet50_dh128_lvis.pth (GitHub, 120 MB) |
HRNetV2-W32+OCR | COCO+LVIS | hrnet32_ocr128_lvis.pth (GitHub, 119 MB) |
Model | BRS Type | GrabCut NoC 85% | GrabCut NoC 90% | Berkeley NoC 85% | Berkeley NoC 90% | DAVIS NoC 85% | DAVIS NoC 90% | SBD NoC 85% | SBD NoC 90% | COCO_MVal NoC 85% | COCO_MVal NoC 90% |
---|---|---|---|---|---|---|---|---|---|---|---|
ResNet-34 (SBD) | RGB-BRS | 2.04 | 2.50 | 2.22 | 4.49 | 5.34 | 7.91 | 4.19 | 6.83 | 4.16 | 5.52 |
ResNet-34 (SBD) | f-BRS-B | 2.06 | 2.48 | 2.40 | 4.17 | 5.34 | 7.73 | 4.47 | 7.28 | 4.31 | 5.79 |
ResNet-50 (SBD) | RGB-BRS | 2.16 | 2.56 | 2.17 | 4.27 | 5.27 | 7.51 | 4.00 | 6.59 | 4.12 | 5.61 |
ResNet-50 (SBD) | f-BRS-B | 2.20 | 2.64 | 2.17 | 4.22 | 5.44 | 7.81 | 4.55 | 7.45 | 4.31 | 6.26 |
ResNet-101 (SBD) | RGB-BRS | 2.10 | 2.46 | 2.34 | 3.91 | 5.19 | 7.23 | 3.78 | 6.28 | 3.98 | 5.45 |
ResNet-101 (SBD) | f-BRS-B | 2.30 | 2.68 | 2.61 | 4.22 | 5.32 | 7.35 | 4.20 | 7.10 | 4.11 | 5.91 |
HRNet-W18+OCR (SBD) | RGB-BRS | 1.68 | 1.94 | 1.99 | 3.81 | 5.49 | 7.98 | 4.19 | 6.84 | 3.62 | 5.04 |
HRNet-W18+OCR (SBD) | f-BRS-B | 1.86 | 2.18 | 2.07 | 3.96 | 5.62 | 8.08 | 4.70 | 7.65 | 3.87 | 5.57 |
HRNet-W32+OCR (SBD) | RGB-BRS | 1.80 | 2.16 | 2.00 | 3.58 | 5.40 | 7.59 | 3.87 | 6.33 | 3.61 | 5.12 |
HRNet-W32+OCR (SBD) | f-BRS-B | 1.78 | 2.16 | 2.13 | 3.69 | 5.54 | 7.62 | 4.31 | 7.08 | 3.82 | 5.44 |
ResNet-50 (COCO+LVIS) | RGB-BRS | 1.54 | 1.76 | 1.56 | 2.70 | 4.93 | 6.22 | 4.04 | 6.85 | 2.41 | 3.47 |
ResNet-50 (COCO+LVIS) | f-BRS-B | 1.52 | 1.74 | 1.56 | 2.61 | 4.94 | 6.36 | 4.29 | 7.20 | 2.34 | 3.43 |
HRNet-W32+OCR (COCO+LVIS) | RGB-BRS | 1.54 | 1.60 | 1.63 | 2.59 | 5.06 | 6.34 | 4.18 | 6.96 | 2.38 | 3.55 |
HRNet-W32+OCR (COCO+LVIS) | f-BRS-B | 1.54 | 1.69 | 1.64 | 2.44 | 5.17 | 6.50 | 4.37 | 7.26 | 2.35 | 3.44 |
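NoC 85% and NoC 90% are the mean Number of Clicks needed to reach 85% and 90% IoU with the ground-truth mask, so lower is better. A minimal sketch of the metric follows. It assumes the usual convention that a sample which never reaches the threshold is counted at a cap of 20 clicks. The helper names are hypothetical, and this is not the repository's evaluation code.

```python
from typing import List, Sequence


def clicks_to_reach(ious: Sequence[float], threshold: float, max_clicks: int = 20) -> int:
    """ious[k] is the IoU after k + 1 clicks; an unreached threshold counts as max_clicks."""
    for k, iou in enumerate(ious[:max_clicks]):
        if iou >= threshold:
            return k + 1
    return max_clicks


def noc(per_sample_ious: List[Sequence[float]], threshold: float, max_clicks: int = 20) -> float:
    """Mean number of clicks over all samples of a dataset."""
    counts = [clicks_to_reach(ious, threshold, max_clicks) for ious in per_sample_ious]
    return sum(counts) / len(counts)


# The third sample never reaches either threshold, so it is counted as 20 clicks
samples = [[0.80, 0.87, 0.91], [0.92], [0.50, 0.60]]
print(noc(samples, 0.85))  # (2 + 1 + 20) / 3, about 7.67
print(noc(samples, 0.90))  # 8.0
```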
We provide the script to test all the presented models in all possible configurations on the GrabCut, Berkeley, DAVIS, COCO_MVal and SBD datasets. To test a model, download its weights and put them in the ./weights folder (you can change this path in config.yml; see the INTERACTIVE_MODELS_PATH variable). To test any of our models, just specify the path to the corresponding checkpoint. Our scripts automatically detect the architecture of the loaded model.
The following command runs the model evaluation (other options are displayed using '-h'):
```shell
python3 scripts/evaluate_model.py <brs-mode> --checkpoint=<checkpoint-name>
```
Examples of the script usage:
```shell
# This command evaluates the ResNet-34 model in f-BRS-B mode on all datasets.
python3 scripts/evaluate_model.py f-BRS-B --checkpoint=resnet34_dh128_sbd
# This command evaluates the HRNetV2-W32+OCR model in f-BRS-B mode on all datasets.
python3 scripts/evaluate_model.py f-BRS-B --checkpoint=hrnet32_ocr128_sbd
# This command evaluates the ResNet-50 model in RGB-BRS mode on the GrabCut and Berkeley datasets.
python3 scripts/evaluate_model.py RGB-BRS --checkpoint=resnet50_dh128_sbd --datasets=GrabCut,Berkeley
# This command evaluates the ResNet-101 model in DistMap-BRS mode on the DAVIS dataset.
python3 scripts/evaluate_model.py DistMap-BRS --checkpoint=resnet101_dh256_sbd --datasets=DAVIS
```
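To evaluate several checkpoints and BRS modes in one go, the documented command line can be wrapped in a small driver. This sketch uses only the positional BRS mode and the --checkpoint and --datasets options shown above. The helper names are hypothetical. By default the driver only prints the commands instead of running them.

```python
import shlex
import subprocess
from typing import List, Optional, Sequence


def eval_command(brs_mode: str, checkpoint: str,
                 datasets: Optional[Sequence[str]] = None) -> List[str]:
    """Build an evaluate_model.py call; without datasets the script's default set is used."""
    cmd = ["python3", "scripts/evaluate_model.py", brs_mode, f"--checkpoint={checkpoint}"]
    if datasets:
        cmd.append("--datasets=" + ",".join(datasets))
    return cmd


def run_all(modes: Sequence[str], checkpoints: Sequence[str],
            datasets: Optional[Sequence[str]] = None, dry_run: bool = True) -> List[List[str]]:
    """Evaluate every checkpoint in every BRS mode; only print the commands when dry_run is set."""
    commands = [eval_command(mode, ckpt, datasets) for ckpt in checkpoints for mode in modes]
    for cmd in commands:
        if dry_run:
            print(" ".join(shlex.quote(arg) for arg in cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failing evaluation
    return commands


run_all(["RGB-BRS", "f-BRS-B"], ["resnet34_dh128_sbd", "hrnet32_ocr128_sbd"],
        datasets=["GrabCut", "Berkeley"])
```

To launch the evaluations for real, run it from the repository root with dry_run=False.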
You can also experiment with our models interactively using the test_any_model.ipynb Jupyter notebook.
We provide scripts for training our models on the SBD dataset. You can start training with the following commands:
```shell
# ResNet-34 model
python3 train.py models/sbd/r34_dh128.py --gpus=0,1 --workers=4 --exp-name=first-try
# ResNet-50 model
python3 train.py models/sbd/r50_dh128.py --gpus=0,1 --workers=4 --exp-name=first-try
# ResNet-101 model
python3 train.py models/sbd/r101_dh256.py --gpus=0,1,2,3 --workers=6 --exp-name=first-try
# HRNetV2-W32+OCR model
python3 train.py models/sbd/hrnet32_ocr128.py --gpus=0,1 --workers=4 --exp-name=first-try
```
For each experiment, a separate folder is created in the ./experiments folder, containing TensorBoard logs, text logs, visualizations and model checkpoints. You can specify another path in config.yml (see the EXPS_PATH variable).
Please note that we trained the ResNet-34 and ResNet-50 models on 2 GPUs and ResNet-101 on 4 GPUs (we used Nvidia Tesla P40 GPUs for training). If you train models with a different GPU configuration, you will need to set a different batch size. You can specify the batch size using the --batch-size command-line argument or change the default value in the model script.
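One simple way to choose a batch size for a different GPU configuration is to keep the per-GPU batch size of the reference setup. The sketch below assumes that heuristic. It does not adjust the learning rate or any other hyperparameter. It uses a placeholder reference batch size of 32 rather than the real default from the model script. It builds a train.py command with the documented --gpus, --workers, --exp-name and --batch-size options. Both helper names are hypothetical.

```python
from typing import List, Sequence


def scaled_batch_size(reference_batch_size: int, reference_gpus: int, new_gpus: int) -> int:
    """Keep the per-GPU batch size of the reference run (a heuristic, not taken from the repo)."""
    if reference_batch_size % reference_gpus:
        raise ValueError("reference batch size must split evenly across the reference GPUs")
    return reference_batch_size // reference_gpus * new_gpus


def train_command(model_script: str, gpus: Sequence[int], batch_size: int,
                  workers: int = 4, exp_name: str = "first-try") -> List[str]:
    """Build a train.py call with the documented command-line options."""
    return ["python3", "train.py", model_script,
            "--gpus=" + ",".join(str(g) for g in gpus),
            f"--workers={workers}", f"--exp-name={exp_name}",
            f"--batch-size={batch_size}"]


# ResNet-101 was trained on 4 GPUs; 32 is a placeholder, not the model script's real default
batch = scaled_batch_size(reference_batch_size=32, reference_gpus=4, new_gpus=2)
print(" ".join(train_command("models/sbd/r101_dh256.py", [0, 1], batch)))
# python3 train.py models/sbd/r101_dh256.py --gpus=0,1 --workers=4 --exp-name=first-try --batch-size=16
```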
We used pre-trained HRNetV2 models from the official repository. If you want to train interactive segmentation models with these backbones, download the weights and specify their paths in config.yml.
The code is released under the MPL 2.0 License. MPL is a copyleft license that is easy to comply with. You must make the source code for any of your changes available under MPL, but you can combine the MPL software with proprietary code, as long as you keep the MPL code in separate files.
If you find this work useful for your research, please cite our paper:
```bibtex
@inproceedings{fbrs2020,
title={f-brs: Rethinking backpropagating refinement for interactive segmentation},
author={Sofiiuk, Konstantin and Petrov, Ilia and Barinova, Olga and Konushin, Anton},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8623--8632},
year={2020}
}
```