Skip to content

Commit

Permalink
Support MR evaluation hook.
Browse files Browse the repository at this point in the history
* Add improved official citypersons evaluation script.
* Add CocoDistEvalMRHook.
* Support configuring eval_hook in config file.
* Assign MR eval_hook in cityperson/faster_rcnn_hrnet.py.
  • Loading branch information
Jokoe66 committed Oct 21, 2020
1 parent 18b984a commit 012e69b
Show file tree
Hide file tree
Showing 5 changed files with 568 additions and 15 deletions.
1 change: 1 addition & 0 deletions configs/elephant/cityperson/faster_rcnn_hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
evaluation = dict(interval=1, eval_hook='CocoDistEvalMRHook')
# yapf:disable
log_config = dict(
interval=50,
Expand Down
17 changes: 4 additions & 13 deletions mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mmcv.runner import DistSamplerSeedHook, obj_from_dict
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet import datasets
from mmdet import datasets, core
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook,
Fp16OptimizerHook)
Expand Down Expand Up @@ -170,18 +170,9 @@ def _dist_train(model, dataset, cfg, validate=False):
if validate:
val_dataset_cfg = cfg.data.val
eval_cfg = cfg.get('evaluation', {})
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(
CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
else:
dataset_type = getattr(datasets, val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(
CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
else:
runner.register_hook(
DistEvalmAPHook(val_dataset_cfg, **eval_cfg))
eval_hook = eval_cfg.pop('eval_hook', 'CocoDistEvalmAPHook')
EvalHook= getattr(core, eval_hook)
runner.register_hook(EvalHook(val_dataset_cfg, **eval_cfg))

if cfg.resume_from:
runner.resume(cfg.resume_from)
Expand Down
4 changes: 2 additions & 2 deletions mmdet/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
get_classes)
from .coco_utils import coco_eval, fast_eval_recall, results2json
from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
CocoDistEvalmAPHook, CocoDistEvalMRHook)
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
plot_iou_recall)
Expand All @@ -14,5 +14,5 @@
'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook',
'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall'
'plot_num_recall', 'plot_iou_recall', 'CocoDistEvalMRHook'
]
45 changes: 45 additions & 0 deletions mmdet/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .coco_utils import results2json, fast_eval_recall
from .mean_ap import eval_map
from .eval_mr import COCOeval as COCOMReval
from mmdet import datasets


Expand Down Expand Up @@ -163,3 +164,47 @@ def evaluate(self, runner, results):
runner.log_buffer.ready = True
for res_type in res_types:
os.remove(result_files[res_type])


class CocoDistEvalMRHook(DistEvalHook):
""" EvalHook for MR evaluation.
Args:
res_types(list): detection type, currently support 'bbox'
and 'vis_bbox'.
"""
def __init__(self, dataset, interval=1, res_types=['bbox']):
super().__init__(dataset, interval)
self.res_types = res_types

def evaluate(self, runner, results):
tmp_file = osp.join(runner.work_dir, 'temp_0')
result_files = results2json(self.dataset, results, tmp_file)

cocoGt = self.dataset.coco
imgIds = cocoGt.getImgIds()
for res_type in self.res_types:
assert res_type in ['bbox', 'vis_bbox']
try:
cocoDt = cocoGt.loadRes(result_files['bbox'])
except IndexError:
print('No prediction found.')
break
metrics = ['MR_Reasonable', 'MR_Small', 'MR_Middle', 'MR_Large',
'MR_Bare', 'MR_Partial', 'MR_Heavy', 'MR_R+HO']
cocoEval = COCOMReval(cocoGt, cocoDt, res_type)
cocoEval.params.imgIds = imgIds
for id_setup in range(0,8):
cocoEval.evaluate(id_setup)
cocoEval.accumulate()
cocoEval.summarize(id_setup)

key = '{}'.format(metrics[id_setup])
val = float('{:.3f}'.format(cocoEval.stats[id_setup]))
runner.log_buffer.output[key] = val
runner.log_buffer.output['{}_MR_copypaste'.format(res_type)] = (
'{mr[0]:.3f} {mr[1]:.3f} {mr[2]:.3f} {mr[3]:.3f} '
'{mr[4]:.3f} {mr[5]:.3f} {mr[6]:.3f} {mr[7]:.3f} ').format(
mr=cocoEval.stats[:8])
runner.log_buffer.ready = True
os.remove(result_files['bbox'])
Loading

0 comments on commit 012e69b

Please sign in to comment.