[Model] Scene Graph Extraction Model with GluonCV (dmlc#1260)
* add working scripts

* add frcnn training script

* remove redundant files

* refactor validation computation, will optimize sgdet and training

* validation finally finished

* f-rcnn training

* test reldn

* rm file

* update reldn training

* data preprocess to h5

* temp

* use coco json

* fix conflict

* new obj dataset for detection

* update training

* before cleanup

* remove abundant files

* add arg parse to train

* cleanup code file

* update

* fix

* add readme

* add ipynb as demo

* add demo pic

* update readme

* add demo script

* improve paths

* improve readme

* add docstrings

* fix args description

* update readme

* add models from s3

* update README

Co-authored-by: Minjie Wang <[email protected]>
hetong007 and jermainewang authored Mar 5, 2020
1 parent ce93330 commit cbee427
Showing 23 changed files with 3,163 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -147,3 +147,6 @@ cscope.*
*.swo
*.un~
*~

# parameters
*.params
109 changes: 109 additions & 0 deletions examples/mxnet/scenegraph/README.md
@@ -0,0 +1,109 @@
# Scene Graph Extraction

Scene graph extraction aims not only to detect objects in a given image, but also to classify the relationships between pairs of them.

This example reproduces [Graphical Contrastive Losses for Scene Graph Parsing](https://arxiv.org/abs/1903.02728); the authors' code can be found [here](https://github.com/NVIDIA/ContrastiveLosses4VRD).

![DEMO](https://raw.githubusercontent.com/dmlc/web-data/master/dgl/examples/mxnet/scenegraph/old-couple-pred.png)

## Results

**VisualGenome**

| Model | Backbone | mAP@50 | SGDET@20 | SGDET@50 | SGDET@100 | PHRCLS@20 | PHRCLS@50 | PHRCLS@100 | PREDCLS@20 | PREDCLS@50 | PREDCLS@100 |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| RelDN, L0 | ResNet101 | 29.5 | 22.65 | 30.02 | 35.04 | 32.84 | 35.60 | 36.26 | 60.58 | 65.53 | 66.51 |

## Preparation

This implementation is based on GluonCV. Install GluonCV with

```
pip install gluoncv --upgrade
```

The implementation contains the following files:

```
.
|-- data
| |-- dataloader.py
| |-- __init__.py
| |-- object.py
| |-- prepare_visualgenome.py
| `-- relation.py
|-- demo_reldn.py
|-- model
| |-- faster_rcnn.py
| |-- __init__.py
| `-- reldn.py
|-- README.md
|-- train_faster_rcnn.py
|-- train_faster_rcnn.sh
|-- train_freq_prior.py
|-- train_reldn.py
|-- train_reldn.sh
|-- utils
| |-- build_graph.py
| |-- __init__.py
| |-- metric.py
| |-- sampling.py
| `-- viz.py
|-- validate_reldn.py
`-- validate_reldn.sh
```

- The folder `data` contains the data preparation script and the dataset definitions for object detection and scene graph extraction.
- The folder `model` contains the model definitions.
- The folder `utils` contains helper functions for training, validation, and visualization.
- The script `train_faster_rcnn.py` trains a Faster R-CNN model on the VisualGenome dataset, and `train_faster_rcnn.sh` includes preset parameters.
- The script `train_freq_prior.py` computes the frequency counts (frequency prior) used in RelDN training.
- The script `train_reldn.py` trains a RelDN model, and `train_reldn.sh` includes preset parameters.
- The script `validate_reldn.py` validates the trained Faster R-CNN and RelDN models, and `validate_reldn.sh` includes preset parameters.
- The script `demo_reldn.py` uses the trained parameters to extract a scene graph from an arbitrary input image.

The sections below walk through training your own models. We also provide pretrained model files for validation and the demo; one way to fetch them is sketched right after this list:

1. [Faster R-CNN Model for Object Detection](http://dgl-data/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params)
2. [RelDN Model](http://dgl-data/models/SceneGraph/reldn.params)
3. [Faster R-CNN Model for Edge Feature](http://dgl-data/models/SceneGraph/detector_feature.params)
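
A convenient way to fetch these files is GluonCV's `download` helper. This is only a minimal sketch: the URLs are copied verbatim from the list above, and the destination directory (here the current directory) is up to you.

```
# sketch: download the pretrained parameter files listed above into the current directory
from gluoncv.utils import download

urls = [
    'http://dgl-data/models/SceneGraph/faster_rcnn_resnet101_v1d_visualgenome.params',
    'http://dgl-data/models/SceneGraph/reldn.params',
    'http://dgl-data/models/SceneGraph/detector_feature.params',
]
for url in urls:
    download(url, path='.')
```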

## Data preparation

We provide a script to download and prepare the VisualGenome dataset. Run it with

```
python data/prepare_visualgenome.py
```
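
After the script finishes, all images end up in a single `VG_100K` folder, and the download directory is symlinked to `~/.mxnet/datasets/visualgenome` (see `data/prepare_visualgenome.py`). A quick sanity check, assuming the default paths:

```
import os

# the preparation script symlinks the download dir to this default location
vg_root = os.path.expanduser('~/.mxnet/datasets/visualgenome')
image_dir = os.path.join(vg_root, 'VG_100K')
assert os.path.isdir(image_dir), 'run data/prepare_visualgenome.py first'
print(len(os.listdir(image_dir)), 'images found')
```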

## Object Detector

First, one needs to train the object detection model on VisualGenome.

```
bash train_faster_rcnn.sh
```

It runs for about 20 hours on a machine with 64 CPU cores and 8 V100 GPUs.

## Training RelDN

With a trained Faster R-CNN model, one can start training the RelDN model by

```
bash train_reldn.sh
```

It runs for about 2 days on a single GPU with 8 CPU cores.

## Validate RelDN

After the training, one can evaluate the results with multiple commonly-used metrics:

```
bash validate_reldn.sh
```

## Demo

We provide a demo script that runs the model on real-world pictures. Note that trained models are needed to generate meaningful results; if you do not supply your own, the script will download the pre-trained models automatically.
3 changes: 3 additions & 0 deletions examples/mxnet/scenegraph/data/__init__.py
@@ -0,0 +1,3 @@
from .object import *
from .relation import *
from .dataloader import *
18 changes: 18 additions & 0 deletions examples/mxnet/scenegraph/data/dataloader.py
@@ -0,0 +1,18 @@
"""DataLoader utils."""
import dgl
from mxnet import nd
from gluoncv.data.batchify import Pad

def dgl_mp_batchify_fn(data):
    """Batchify function: collects DGLGraphs into a list and pads NDArrays into a batch."""
    if isinstance(data[0], tuple):
        data = zip(*data)
        return [dgl_mp_batchify_fn(i) for i in data]

    for dt in data:
        if dt is not None:
            if isinstance(dt, dgl.DGLGraph):
                return [d for d in data if isinstance(d, dgl.DGLGraph)]
            elif isinstance(dt, nd.NDArray):
                pad = Pad(axis=(1, 2), num_shards=1, ret_length=False)
                data_list = [dt for dt in data if dt is not None]
                return pad(data_list)
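
For reference, a minimal usage sketch of `dgl_mp_batchify_fn` as the `batchify_fn` of a Gluon `DataLoader`. The `ToyDataset` below is a made-up stand-in; the real datasets live in `data/object.py` and `data/relation.py`.

```
import dgl
from mxnet import gluon, nd
from data.dataloader import dgl_mp_batchify_fn  # path assumed relative to examples/mxnet/scenegraph

class ToyDataset(gluon.data.Dataset):
    """Stand-in dataset: each item is a (graph, image tensor) pair."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        g = dgl.DGLGraph()
        g.add_nodes(3)
        g.add_edges([0, 1], [1, 2])
        img = nd.zeros((3, 224, 224))  # CHW image placeholder
        return g, img

loader = gluon.data.DataLoader(ToyDataset(), batch_size=2,
                               batchify_fn=dgl_mp_batchify_fn)
for graphs, images in loader:
    # graphs is a plain list of DGLGraphs; images are collated by gluoncv's Pad
    print(len(graphs), type(images))
```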
46 changes: 46 additions & 0 deletions examples/mxnet/scenegraph/data/object.py
@@ -0,0 +1,46 @@
"""Pascal VOC object detection dataset."""
from __future__ import absolute_import
from __future__ import division
import os
import logging
import warnings
import json
import pickle
import numpy as np
import mxnet as mx
from gluoncv.data import COCODetection
from collections import Counter

class VGObject(COCODetection):
    """Visual Genome object detection dataset, stored in COCO-style json format."""
    CLASSES = ["airplane", "animal", "arm", "bag", "banana", "basket", "beach",
               "bear", "bed", "bench", "bike", "bird", "board", "boat", "book",
               "boot", "bottle", "bowl", "box", "boy", "branch", "building", "bus",
               "cabinet", "cap", "car", "cat", "chair", "child", "clock", "coat",
               "counter", "cow", "cup", "curtain", "desk", "dog", "door", "drawer",
               "ear", "elephant", "engine", "eye", "face", "fence", "finger", "flag",
               "flower", "food", "fork", "fruit", "giraffe", "girl", "glass", "glove",
               "guy", "hair", "hand", "handle", "hat", "head", "helmet", "hill",
               "horse", "house", "jacket", "jean", "kid", "kite", "lady", "lamp",
               "laptop", "leaf", "leg", "letter", "light", "logo", "man", "men",
               "motorcycle", "mountain", "mouth", "neck", "nose", "number", "orange",
               "pant", "paper", "paw", "people", "person", "phone", "pillow", "pizza",
               "plane", "plant", "plate", "player", "pole", "post", "pot", "racket",
               "railing", "rock", "roof", "room", "screen", "seat", "sheep", "shelf",
               "shirt", "shoe", "short", "sidewalk", "sign", "sink", "skateboard",
               "ski", "skier", "sneaker", "snow", "sock", "stand", "street",
               "surfboard", "table", "tail", "tie", "tile", "tire", "toilet",
               "towel", "tower", "track", "train", "tree", "truck", "trunk",
               "umbrella", "vase", "vegetable", "vehicle", "wave", "wheel",
               "window", "windshield", "wing", "wire", "woman", "zebra"]

    def __init__(self, **kwargs):
        super(VGObject, self).__init__(**kwargs)

    @property
    def annotation_dir(self):
        return ''

    def _parse_image_path(self, entry):
        dirname = 'VG_100K'
        filename = entry['file_name']
        abs_path = os.path.join(self._root, dirname, filename)
        return abs_path
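
A hypothetical instantiation sketch: `VGObject` behaves like GluonCV's `COCODetection`, with images read from `VG_100K`. The real split names come from the json annotation files unpacked by `data/prepare_visualgenome.py`; `'objects_train'` below is only a placeholder.

```
from data.object import VGObject  # path assumed relative to examples/mxnet/scenegraph

dataset = VGObject(root='~/.mxnet/datasets/visualgenome',
                   splits=('objects_train',))  # placeholder split name
print(len(VGObject.CLASSES))  # 150 Visual Genome object categories
image, label = dataset[0]     # image tensor and per-object boxes with class ids
```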
76 changes: 76 additions & 0 deletions examples/mxnet/scenegraph/data/prepare_visualgenome.py
@@ -0,0 +1,76 @@
"""Prepare Visual Genome datasets"""
import os
import shutil
import argparse
import zipfile
import random
import json
import tqdm
import pickle
from gluoncv.utils import download, makedirs

_TARGET_DIR = os.path.expanduser('~/.mxnet/datasets/visualgenome')

def parse_args():
    parser = argparse.ArgumentParser(
        description='Initialize Visual Genome dataset.',
        epilog='Example: python data/prepare_visualgenome.py --download-dir ~/visualgenome',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--download-dir', type=str, default='~/visualgenome/',
                        help='dataset directory on disk')
    parser.add_argument('--no-download', action='store_true',
                        help='disable automatic download if set')
    parser.add_argument('--overwrite', action='store_true',
                        help='overwrite downloaded files if set, in case they are corrupted')
    args = parser.parse_args()
    return args

def download_vg(path, overwrite=False):
    """Download and extract the Visual Genome image archives."""
    _DOWNLOAD_URLS = [
        ('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip',
         'a055367f675dd5476220e9b93e4ca9957b024b94'),
        ('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip',
         '2add3aab77623549e92b7f15cda0308f50b64ecf'),
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
        if filename.endswith('zip'):
            with zipfile.ZipFile(filename) as zf:
                zf.extractall(path=path)
    # move all images into folder `VG_100K`
    vg_100k_path = os.path.join(path, 'VG_100K')
    vg_100k_2_path = os.path.join(path, 'VG_100K_2')
    files_2 = os.listdir(vg_100k_2_path)
    for fl in files_2:
        shutil.move(os.path.join(vg_100k_2_path, fl),
                    os.path.join(vg_100k_path, fl))

def download_json(path, overwrite=False):
    """Download and extract the preprocessed json annotations."""
    url = 'https://data.dgl.ai/dataset/vg.zip'
    # open the archive by the path it was actually downloaded to
    filename = download(url, path=path, overwrite=overwrite)
    with zipfile.ZipFile(filename) as zf:
        zf.extractall(path=path)
    # flatten the extracted `vg` folder into the dataset root
    json_path = os.path.join(path, 'vg')
    json_files = os.listdir(json_path)
    for fl in json_files:
        shutil.move(os.path.join(json_path, fl),
                    os.path.join(path, fl))
    os.rmdir(json_path)

if __name__ == '__main__':
    args = parse_args()
    path = os.path.expanduser(args.download_dir)
    if not os.path.isdir(path):
        if args.no_download:
            raise ValueError(('{} is not a valid directory, make sure it is present,'
                              ' or do not pass "--no-download" so the script can fetch it.').format(path))
        else:
            download_vg(path, overwrite=args.overwrite)
            download_json(path, overwrite=args.overwrite)

    # make symlink so the dataset is found at the default location
    makedirs(os.path.expanduser('~/.mxnet/datasets'))
    if os.path.isdir(_TARGET_DIR):
        os.rmdir(_TARGET_DIR)
    os.symlink(path, _TARGET_DIR)
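
For reference, the same preparation can also be driven from Python by calling the helpers above directly. This is only a sketch: the module path is assumed to be importable from the example root, and the image download is very large.

```
import os
from data.prepare_visualgenome import download_vg, download_json

path = os.path.expanduser('~/visualgenome')
download_vg(path)    # downloads and extracts images.zip / images2.zip into VG_100K
download_json(path)  # downloads vg.zip with the json annotations
```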