forked from facebookresearch/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reviewed By: theschnitz Differential Revision: D28016190 fbshipit-source-id: ba305adb1624b09194c7d63b960c07917f044e05
- Loading branch information
1 parent
ee4bf1a
commit 543fd07
Showing
2 changed files
with
154 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# An example config to train an MMDetection Mask R-CNN model using detectron2.
|
||
from ..common.data.coco import dataloader | ||
from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier | ||
from ..common.optim import SGD as optimizer | ||
from ..common.train import train | ||
|
||
from detectron2.modeling.mmdet_wrapper import MMDetDetector | ||
from detectron2.config import LazyCall as L | ||
|
||
model = L(MMDetDetector)(
    # The detector is written as nested ``dict(type=...)`` entries and handed
    # to MMDetDetector unchanged; every key and value below is part of that
    # config and is consumed by the wrapped library, not by detectron2.
    detector=dict(
        type="MaskRCNN",
        # Backbone weights are requested here rather than via a detectron2 checkpoint.
        pretrained="torchvision://resnet50",
        # ---- backbone: ResNet-50, outputs taken from all four stages ----
        backbone=dict(
            type="ResNet",
            depth=50,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            frozen_stages=1,
            norm_cfg=dict(type="BN", requires_grad=True),
            norm_eval=True,
            style="pytorch",
        ),
        # ---- neck: FPN over the four backbone outputs, producing five levels ----
        neck=dict(
            type="FPN",
            in_channels=[256, 512, 1024, 2048],
            out_channels=256,
            num_outs=5,
        ),
        # ---- region proposal head ----
        rpn_head=dict(
            type="RPNHead",
            in_channels=256,
            feat_channels=256,
            # One scale and three aspect ratios at each of the five strides.
            anchor_generator=dict(
                type="AnchorGenerator",
                scales=[8],
                ratios=[0.5, 1.0, 2.0],
                strides=[4, 8, 16, 32, 64],
            ),
            bbox_coder=dict(
                type="DeltaXYWHBBoxCoder",
                target_means=[0.0, 0.0, 0.0, 0.0],
                target_stds=[1.0, 1.0, 1.0, 1.0],
            ),
            loss_cls=dict(
                type="CrossEntropyLoss",
                use_sigmoid=True,
                loss_weight=1.0,
            ),
            loss_bbox=dict(type="L1Loss", loss_weight=1.0),
        ),
        # ---- RoI head: box branch and mask branch ----
        roi_head=dict(
            type="StandardRoIHead",
            # Box branch: RoIAlign with output_size=7 over the four listed strides.
            bbox_roi_extractor=dict(
                type="SingleRoIExtractor",
                roi_layer=dict(
                    type="RoIAlign",
                    output_size=7,
                    sampling_ratio=0,
                ),
                out_channels=256,
                featmap_strides=[4, 8, 16, 32],
            ),
            bbox_head=dict(
                type="Shared2FCBBoxHead",
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type="DeltaXYWHBBoxCoder",
                    target_means=[0.0, 0.0, 0.0, 0.0],
                    target_stds=[0.1, 0.1, 0.2, 0.2],
                ),
                reg_class_agnostic=False,
                loss_cls=dict(
                    type="CrossEntropyLoss",
                    use_sigmoid=False,
                    loss_weight=1.0,
                ),
                loss_bbox=dict(type="L1Loss", loss_weight=1.0),
            ),
            # Mask branch: same extractor type, but with output_size=14.
            mask_roi_extractor=dict(
                type="SingleRoIExtractor",
                roi_layer=dict(
                    type="RoIAlign",
                    output_size=14,
                    sampling_ratio=0,
                ),
                out_channels=256,
                featmap_strides=[4, 8, 16, 32],
            ),
            mask_head=dict(
                type="FCNMaskHead",
                num_convs=4,
                in_channels=256,
                conv_out_channels=256,
                num_classes=80,
                loss_mask=dict(
                    type="CrossEntropyLoss",
                    use_mask=True,
                    loss_weight=1.0,
                ),
            ),
        ),
        # ---- training-time settings ----
        train_cfg=dict(
            # Target assignment and sampling for the RPN.
            rpn=dict(
                assigner=dict(
                    type="MaxIoUAssigner",
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.3,
                    min_pos_iou=0.3,
                    match_low_quality=True,
                    ignore_iof_thr=-1,
                ),
                sampler=dict(
                    type="RandomSampler",
                    num=256,
                    pos_fraction=0.5,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=False,
                ),
                allowed_border=-1,
                pos_weight=-1,
                debug=False,
            ),
            # Proposal filtering applied while training.
            rpn_proposal=dict(
                nms_pre=2000,
                max_per_img=1000,
                nms=dict(type="nms", iou_threshold=0.7),
                min_bbox_size=0,
            ),
            # Target assignment and sampling for the RoI (box and mask) heads.
            rcnn=dict(
                assigner=dict(
                    type="MaxIoUAssigner",
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=True,
                    ignore_iof_thr=-1,
                ),
                sampler=dict(
                    type="RandomSampler",
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True,
                ),
                mask_size=28,
                pos_weight=-1,
                debug=False,
            ),
        ),
        # ---- inference-time settings ----
        test_cfg=dict(
            rpn=dict(
                nms_pre=1000,
                max_per_img=1000,
                nms=dict(type="nms", iou_threshold=0.7),
                min_bbox_size=0,
            ),
            rcnn=dict(
                score_thr=0.05,
                nms=dict(type="nms", iou_threshold=0.5),
                max_per_img=100,
                mask_thr_binary=0.5,
            ),
        ),
    ),
    # Per-channel input normalization statistics given to the wrapper.
    # NOTE(review): the channel order presumably follows the "RGB" image_format
    # that this file sets on the train mapper -- confirm against MMDetDetector.
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
)
|
||
# The backbone starts from a torchvision-pretrained model, so the train mapper
# is switched to RGB images.
dataloader.train.mapper.image_format = "RGB"
# No detectron2 checkpoint is loaded: the pretrained weights are requested
# inside the detector config via pretrained="torchvision://resnet50".
train.init_checkpoint = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters