forked from facebookresearch/detr
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Detectron2 wrapper (facebookresearch#103)
- Loading branch information
alcinos
authored
Jun 28, 2020
1 parent
10a2c75
commit 3673ffe
Showing
9 changed files
with
692 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
Detectron2 wrapper for DETR | ||
======= | ||
|
||
We provide a Detectron2 wrapper for DETR, thus providing a way to better integrate it in the existing detection ecosystem. It can be used for example to easily leverage datasets or backbones provided in Detectron2. | ||
|
||
This wrapper currently supports only box detection, and is intended to be as close as possible to the original implementation, and we checked that it indeed match the results. Some notable facts and caveats: | ||
- The data augmentation matches DETR's original data augmentation. This required patching the RandomCrop augmentation from Detectron2, so you'll need a version from the master branch from June 24th 2020 or more recent. | ||
- To match DETR's original backbone initialization, we use the weights of a ResNet50 trained on imagenet using torchvision. This network uses a different pixel mean and std than most of the backbones available in Detectron2 by default, so extra care must be taken when switching to another one. Note that no other torchvision models are available in Detectron2 as of now, though it may change in the future. | ||
- The gradient clipping mode is "full_model", which is not the default in Detectron2. | ||
|
||
# Usage | ||
|
||
To install Detectron2, please follow the [official installation instructions](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). | ||
|
||
## Evaluating a model | ||
|
||
For convenience, we provide a conversion script to convert models trained by the main DETR training loop into the format of this wrapper. To download and convert the main Resnet50 model, simply do: | ||
|
||
``` | ||
python converter.py --source_model https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth --output_model converted_model.pth | ||
``` | ||
|
||
You can then evaluate it using: | ||
``` | ||
python train_net.py --eval-only --config configs/detr_256_6_6_torchvision.yaml MODEL.WEIGHTS "converted_model.pth" | ||
``` | ||
|
||
|
||
## Training | ||
|
||
To train DETR on a single node with 8 gpus, simply use: | ||
``` | ||
python train_net.py --config configs/detr_256_6_6_torchvision.yaml --num-gpus 8 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
MODEL: | ||
META_ARCHITECTURE: "Detr" | ||
WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" | ||
PIXEL_MEAN: [123.675, 116.280, 103.530] | ||
PIXEL_STD: [58.395, 57.120, 57.375] | ||
MASK_ON: False | ||
RESNETS: | ||
DEPTH: 50 | ||
STRIDE_IN_1X1: False | ||
OUT_FEATURES: ["res2", "res3", "res4", "res5"] | ||
DETR: | ||
GIOU_WEIGHT: 2.0 | ||
L1_WEIGHT: 5.0 | ||
NUM_OBJECT_QUERIES: 100 | ||
DATASETS: | ||
TRAIN: ("coco_2017_train",) | ||
TEST: ("coco_2017_val",) | ||
SOLVER: | ||
IMS_PER_BATCH: 64 | ||
BASE_LR: 0.0001 | ||
STEPS: (369600,) | ||
MAX_ITER: 554400 | ||
WARMUP_FACTOR: 1.0 | ||
WARMUP_ITERS: 10 | ||
WEIGHT_DECAY: 0.0001 | ||
OPTIMIZER: "ADAMW" | ||
BACKBONE_MULTIPLIER: 0.1 | ||
CLIP_GRADIENTS: | ||
ENABLED: True | ||
CLIP_TYPE: "full_model" | ||
CLIP_VALUE: 0.01 | ||
NORM_TYPE: 2.0 | ||
INPUT: | ||
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) | ||
CROP: | ||
ENABLED: True | ||
TYPE: "absolute_range" | ||
SIZE: (384, 600) | ||
FORMAT: "RGB" | ||
TEST: | ||
EVAL_PERIOD: 4000 | ||
DATALOADER: | ||
FILTER_EMPTY_ANNOTATIONS: False | ||
NUM_WORKERS: 4 | ||
VERSION: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
""" | ||
Helper script to convert models trained with the main version of DETR to be used with the Detectron2 version. | ||
""" | ||
import json | ||
import argparse | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser("D2 model converter") | ||
|
||
parser.add_argument("--source_model", default="", type=str, help="Path or url to the DETR model to convert") | ||
parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model") | ||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
# D2 expects contiguous classes, so we need to remap the 92 classes from DETR | ||
# fmt: off | ||
coco_idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, | ||
27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, | ||
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, | ||
78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91] | ||
# fmt: on | ||
|
||
coco_idx = np.array(coco_idx) | ||
|
||
if args.source_model.startswith("https"): | ||
checkpoint = torch.hub.load_state_dict_from_url(args.source_model, map_location="cpu", check_hash=True) | ||
else: | ||
checkpoint = torch.load(args.source_model, map_location="cpu") | ||
model_to_convert = checkpoint["model"] | ||
|
||
model_converted = {} | ||
for k in model_to_convert.keys(): | ||
old_k = k | ||
if "backbone" in k: | ||
k = k.replace("backbone.0.body.", "") | ||
if "layer" not in k: | ||
k = "stem." + k | ||
for t in [1, 2, 3, 4]: | ||
k = k.replace(f"layer{t}", f"res{t + 1}") | ||
for t in [1, 2, 3]: | ||
k = k.replace(f"bn{t}", f"conv{t}.norm") | ||
k = k.replace("downsample.0", "shortcut") | ||
k = k.replace("downsample.1", "shortcut.norm") | ||
k = "backbone.0.backbone." + k | ||
k = "detr." + k | ||
print(old_k, "->", k) | ||
if "class_embed" in old_k: | ||
v = model_to_convert[old_k].detach() | ||
if v.shape[0] == 92: | ||
shape_old = v.shape | ||
model_converted[k] = v[coco_idx] | ||
print("Head conversion: changing shape from {} to {}".format(shape_old, model_converted[k].shape)) | ||
continue | ||
model_converted[k] = model_to_convert[old_k].detach() | ||
|
||
model_to_save = {"model": model_converted} | ||
torch.save(model_to_save, args.output_model) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
from .config import add_detr_config | ||
from .detr import Detr | ||
from .dataset_mapper import DetrDatasetMapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
from detectron2.config import CfgNode as CN | ||
|
||
|
||
def add_detr_config(cfg): | ||
""" | ||
Add config for DETR. | ||
""" | ||
cfg.MODEL.DETR = CN() | ||
cfg.MODEL.DETR.NUM_CLASSES = 80 | ||
|
||
# LOSS | ||
cfg.MODEL.DETR.GIOU_WEIGHT = 2.0 | ||
cfg.MODEL.DETR.L1_WEIGHT = 5.0 | ||
cfg.MODEL.DETR.DEEP_SUPERVISION = True | ||
cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1 | ||
|
||
# TRANSFORMER | ||
cfg.MODEL.DETR.NHEADS = 8 | ||
cfg.MODEL.DETR.DROPOUT = 0.1 | ||
cfg.MODEL.DETR.DIM_FEEDFORWARD = 2048 | ||
cfg.MODEL.DETR.ENC_LAYERS = 6 | ||
cfg.MODEL.DETR.DEC_LAYERS = 6 | ||
cfg.MODEL.DETR.PRE_NORM = False | ||
cfg.MODEL.DETR.PASS_POS_AND_QUERY = True | ||
|
||
cfg.MODEL.DETR.HIDDEN_DIM = 256 | ||
cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100 | ||
|
||
cfg.SOLVER.OPTIMIZER = "ADAMW" | ||
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import copy | ||
import logging | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from detectron2.data import detection_utils as utils | ||
from detectron2.data import transforms as T | ||
from detectron2.data.transforms import TransformGen | ||
|
||
__all__ = ["DetrDatasetMapper"] | ||
|
||
|
||
def build_transform_gen(cfg, is_train): | ||
""" | ||
Create a list of :class:`TransformGen` from config. | ||
Returns: | ||
list[TransformGen] | ||
""" | ||
if is_train: | ||
min_size = cfg.INPUT.MIN_SIZE_TRAIN | ||
max_size = cfg.INPUT.MAX_SIZE_TRAIN | ||
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING | ||
else: | ||
min_size = cfg.INPUT.MIN_SIZE_TEST | ||
max_size = cfg.INPUT.MAX_SIZE_TEST | ||
sample_style = "choice" | ||
if sample_style == "range": | ||
assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) | ||
|
||
logger = logging.getLogger(__name__) | ||
tfm_gens = [] | ||
if is_train: | ||
tfm_gens.append(T.RandomFlip()) | ||
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) | ||
if is_train: | ||
logger.info("TransformGens used in training: " + str(tfm_gens)) | ||
return tfm_gens | ||
|
||
|
||
class DetrDatasetMapper: | ||
""" | ||
A callable which takes a dataset dict in Detectron2 Dataset format, | ||
and map it into a format used by DETR. | ||
The callable currently does the following: | ||
1. Read the image from "file_name" | ||
2. Applies geometric transforms to the image and annotation | ||
3. Find and applies suitable cropping to the image and annotation | ||
4. Prepare image and annotation to Tensors | ||
""" | ||
|
||
def __init__(self, cfg, is_train=True): | ||
if cfg.INPUT.CROP.ENABLED and is_train: | ||
self.crop_gen = [ | ||
T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), | ||
T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), | ||
] | ||
else: | ||
self.crop_gen = None | ||
|
||
assert not cfg.MODEL.MASK_ON, "Mask is not supported" | ||
|
||
self.tfm_gens = build_transform_gen(cfg, is_train) | ||
logging.getLogger(__name__).info( | ||
"Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen)) | ||
) | ||
|
||
self.img_format = cfg.INPUT.FORMAT | ||
self.is_train = is_train | ||
|
||
def __call__(self, dataset_dict): | ||
""" | ||
Args: | ||
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. | ||
Returns: | ||
dict: a format that builtin models in detectron2 accept | ||
""" | ||
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below | ||
image = utils.read_image(dataset_dict["file_name"], format=self.img_format) | ||
utils.check_image_size(dataset_dict, image) | ||
|
||
if self.crop_gen is None: | ||
image, transforms = T.apply_transform_gens(self.tfm_gens, image) | ||
else: | ||
if np.random.rand() > 0.5: | ||
image, transforms = T.apply_transform_gens(self.tfm_gens, image) | ||
else: | ||
image, transforms = T.apply_transform_gens( | ||
self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image | ||
) | ||
|
||
image_shape = image.shape[:2] # h, w | ||
|
||
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, | ||
# but not efficient on large generic data structures due to the use of pickle & mp.Queue. | ||
# Therefore it's important to use torch.Tensor. | ||
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) | ||
|
||
if not self.is_train: | ||
# USER: Modify this if you want to keep them for some reason. | ||
dataset_dict.pop("annotations", None) | ||
return dataset_dict | ||
|
||
if "annotations" in dataset_dict: | ||
# USER: Modify this if you want to keep them for some reason. | ||
for anno in dataset_dict["annotations"]: | ||
anno.pop("segmentation", None) | ||
anno.pop("keypoints", None) | ||
|
||
# USER: Implement additional transformations if you have other types of data | ||
annos = [ | ||
utils.transform_instance_annotations(obj, transforms, image_shape) | ||
for obj in dataset_dict.pop("annotations") | ||
if obj.get("iscrowd", 0) == 0 | ||
] | ||
instances = utils.annotations_to_instances(annos, image_shape) | ||
dataset_dict["instances"] = utils.filter_empty_instances(instances) | ||
return dataset_dict |
Oops, something went wrong.