-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Scene Graph Extraction Model with GluonCV (dmlc#1260)
* add working scripts * add frcnn training script * remove redundent files * refactor validation computation, will optimize sgdet and training * validation finally finished * f-rcnn training * test reldn * rm file * update reldn training * data preprocess to h5 * temp * use coco json * fix conflict * new obj dataset for detection * update training * before cleanup * remove abundant files * add arg parse to train * cleanup code file * update * fix * add readme * add ipynb as demo * add demo pic * update readme * add demo script * improve paths * improve readme * add docstrings * fix args description * update readme * add models from s3 * update README Co-authored-by: Minjie Wang <[email protected]>
- Loading branch information
1 parent
ce93330
commit cbee427
Showing
23 changed files
with
3,163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,3 +147,6 @@ cscope.* | |
*.swo | ||
*.un~ | ||
*~ | ||
|
||
# parameters | ||
*.params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Scene Graph Extraction | ||
|
||
Scene graph extraction aims at not only detect objects in the given image, but also classify the relationships between pairs of them. | ||
|
||
This example reproduces [Graphical Contrastive Losses for Scene Graph Parsing](https://arxiv.org/abs/1903.02728), author's code can be found [here](https://github.com/NVIDIA/ContrastiveLosses4VRD). | ||
|
||
![DEMO](https://raw.githubusercontent.com/dmlc/web-data/master/dgl/examples/mxnet/scenegraph/old-couple-pred.png) | ||
|
||
## Results | ||
|
||
**VisualGenome** | ||
|
||
| Model | Backbone | mAP@50 | SGDET@20 | SGDET@50 | SGDET@100 | PHRCLS@20 | PHRCLS@50 |PHRCLS@100 | PREDCLS@20 | PREDCLS@50 | PREDCLS@100 | | ||
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | | ||
| RelDN, L0 | ResNet101 | 29.5 | 22.65 | 30.02 | 35.04 | 32.84 | 35.60 | 36.26 | 60.58 | 65.53 | 66.51 | | ||
|
||
## Preparation | ||
|
||
This implementation is based on GluonCV. Install GluonCV with | ||
|
||
``` | ||
pip install gluoncv --upgrade | ||
``` | ||
|
||
The implementation contains the following files: | ||
|
||
``` | ||
. | ||
|-- data | ||
| |-- dataloader.py | ||
| |-- __init__.py | ||
| |-- object.py | ||
| |-- prepare_visualgenome.py | ||
| `-- relation.py | ||
|-- demo_reldn.py | ||
|-- model | ||
| |-- faster_rcnn.py | ||
| |-- __init__.py | ||
| `-- reldn.py | ||
|-- README.md | ||
|-- train_faster_rcnn.py | ||
|-- train_faster_rcnn.sh | ||
|-- train_freq_prior.py | ||
|-- train_reldn.py | ||
|-- train_reldn.sh | ||
|-- utils | ||
| |-- build_graph.py | ||
| |-- __init__.py | ||
| |-- metric.py | ||
| |-- sampling.py | ||
| `-- viz.py | ||
|-- validate_reldn.py | ||
`-- validate_reldn.sh | ||
``` | ||
|
||
- The folder `data` contains the data preparation script, and definition of datasets for object detection and scene graph extraction. | ||
- The folder `model` contains model definition. | ||
- The folder `utils` contains helper functions for training, validation, and visualization. | ||
- The script `train_faster_rcnn.py` trains a Faster R-CNN model on VisualGenome dataset, and `train_faster_rcnn.sh` includes preset parameters. | ||
- The script `train_freq_prior.py` trains the frequency counts for RelDN model training. | ||
- The script `train_reldn.py` trains a RelDN model, and `train_reldn.sh` includes preset parameters. | ||
- The script `validate_reldn.py` validate the trained Faster R-CNN and RelDN models, and `validate_reldn.sh` includes preset parameters. | ||
- The script `demo_reldh.py` makes use of trained parameters and extract an scene graph from an arbitrary input image. | ||
|
||
Below are further steps on training your own models. Besides, we also provide pretrained model files for validation and demo: | ||
|
||
1. [Faster R-CNN Model for Object Detection](http://dgl-data/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params) | ||
2. [RelDN Model](http://dgl-data/models/SceneGraph/reldn.params) | ||
3. [Faster R-CNN Model for Edge Feature](http://dgl-data/models/SceneGraph/detector_feature.params) | ||
|
||
## Data preparation | ||
|
||
We provide scripts to download and prepare the VisualGenome dataset. One can run with | ||
|
||
``` | ||
python data/prepare_visualgenome.py | ||
``` | ||
|
||
## Object Detector | ||
|
||
First one need to train the object detection model on VisualGenome. | ||
|
||
``` | ||
bash train_faster_rcnn.sh | ||
``` | ||
|
||
It runs for about 20 hours on a machine with 64 CPU cores and 8 V100 GPUs. | ||
|
||
## Training RelDN | ||
|
||
With a trained Faster R-CNN model, one can start the training of RelDN model by | ||
|
||
``` | ||
bash train_reldn.sh | ||
``` | ||
|
||
It runs for about 2 days with one single GPU and 8 CPU cores. | ||
|
||
## Validate RelDN | ||
|
||
After the training, one can evaluate the results with multiple commonly-used metrics: | ||
|
||
``` | ||
bash validate_reldn.sh | ||
``` | ||
|
||
## Demo | ||
|
||
We provide a demo script of running the model with real-world pictures. Be aware that you need trained model to generate meaningful results from the demo, otherwise the script will download the pre-trained model automatically. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .object import * | ||
from .relation import * | ||
from .dataloader import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""DataLoader utils.""" | ||
import dgl | ||
from mxnet import nd | ||
from gluoncv.data.batchify import Pad | ||
|
||
def dgl_mp_batchify_fn(data): | ||
if isinstance(data[0], tuple): | ||
data = zip(*data) | ||
return [dgl_mp_batchify_fn(i) for i in data] | ||
|
||
for dt in data: | ||
if dt is not None: | ||
if isinstance(dt, dgl.DGLGraph): | ||
return [d for d in data if isinstance(d, dgl.DGLGraph)] | ||
elif isinstance(dt, nd.NDArray): | ||
pad = Pad(axis=(1, 2), num_shards=1, ret_length=False) | ||
data_list = [dt for dt in data if dt is not None] | ||
return pad(data_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Pascal VOC object detection dataset.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
import os | ||
import logging | ||
import warnings | ||
import json | ||
import pickle | ||
import numpy as np | ||
import mxnet as mx | ||
from gluoncv.data import COCODetection | ||
from collections import Counter | ||
|
||
class VGObject(COCODetection): | ||
CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach", | ||
"bear", "bed", "bench", "bike", "bird", "board", "boat", "book", | ||
"boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus", | ||
"cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat", | ||
"counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer", | ||
"ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag", | ||
"flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove", | ||
"guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill", | ||
"horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp", | ||
"laptop", "leaf", "leg", "letter", "light", "logo", "man", "men", | ||
"motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange", | ||
"pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza", | ||
"plane", "plant", "plate", "player", "pole", "post", "pot", "racket", | ||
"railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf", | ||
"shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard", | ||
"ski", "skier", "sneaker", "snow", "sock", "stand", "street", | ||
"surfboard", "table", "tail", "tie", "tile", "tire", "toilet", | ||
"towel", "tower", "track", "train", "tree", "truck", "trunk", | ||
"umbrella", "vase", "vegetable", "vehicle", "wave", "wheel", | ||
"window", "windshield", "wing", "wire", "woman", "zebra"] | ||
def __init__(self, **kwargs): | ||
super(VGObject, self).__init__(**kwargs) | ||
|
||
@property | ||
def annotation_dir(self): | ||
return '' | ||
|
||
def _parse_image_path(self, entry): | ||
dirname = 'VG_100K' | ||
filename = entry['file_name'] | ||
abs_path = os.path.join(self._root, dirname, filename) | ||
return abs_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Prepare Visual Genome datasets""" | ||
import os | ||
import shutil | ||
import argparse | ||
import zipfile | ||
import random | ||
import json | ||
import tqdm | ||
import pickle | ||
from gluoncv.utils import download, makedirs | ||
|
||
_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome') | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='Initialize Visual Genome dataset.', | ||
epilog='Example: python visualgenome.py --download-dir ~/visualgenome', | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
parser.add_argument('--download-dir', type=str, default='~/visualgenome/', | ||
help='dataset directory on disk') | ||
parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') | ||
parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted') | ||
args = parser.parse_args() | ||
return args | ||
|
||
def download_vg(path, overwrite=False): | ||
_DOWNLOAD_URLS = [ | ||
('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip', | ||
'a055367f675dd5476220e9b93e4ca9957b024b94'), | ||
('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip', | ||
'2add3aab77623549e92b7f15cda0308f50b64ecf'), | ||
] | ||
makedirs(path) | ||
for url, checksum in _DOWNLOAD_URLS: | ||
filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) | ||
# extract | ||
if filename.endswith('zip'): | ||
with zipfile.ZipFile(filename) as zf: | ||
zf.extractall(path=path) | ||
# move all images into folder `VG_100K` | ||
vg_100k_path = os.path.join(path, 'VG_100K') | ||
vg_100k_2_path = os.path.join(path, 'VG_100K_2') | ||
files_2 = os.listdir(vg_100k_2_path) | ||
for fl in files_2: | ||
shutil.move(os.path.join(vg_100k_2_path, fl), | ||
os.path.join(vg_100k_path, fl)) | ||
|
||
def download_json(path, overwrite=False): | ||
url = 'https://data.dgl.ai/dataset/vg.zip' | ||
output = 'vg.zip' | ||
download(url, path=path) | ||
with zipfile.ZipFile(output) as zf: | ||
zf.extractall(path=path) | ||
json_path = os.path.join(path, 'vg') | ||
json_files = os.listdir(json_path) | ||
for fl in json_files: | ||
shutil.move(os.path.join(json_path, fl), | ||
os.path.join(path, fl)) | ||
os.rmdir(json_path) | ||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
path = os.path.expanduser(args.download_dir) | ||
if not os.path.isdir(path): | ||
if args.no_download: | ||
raise ValueError(('{} is not a valid directory, make sure it is present.' | ||
' Or you should not disable "--no-download" to grab it'.format(path))) | ||
else: | ||
download_vg(path, overwrite=args.overwrite) | ||
download_json(path, overwrite=args.overwrite) | ||
|
||
# make symlink | ||
makedirs(os.path.expanduser('~/.mxnet/datasets')) | ||
if os.path.isdir(_TARGET_DIR): | ||
os.rmdir(_TARGET_DIR) | ||
os.symlink(path, _TARGET_DIR) |
Oops, something went wrong.