-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcoco_evaluation.py
539 lines (467 loc) · 21.7 KB
/
coco_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import copy
import io
import itertools
import json
import logging
import numpy as np
import os
import pickle
from collections import OrderedDict
import pycocotools.mask as mask_util
import torch
from fvcore.common.file_io import PathManager
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tabulate import tabulate
import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog
from detectron2.data.datasets.coco import convert_to_coco_json
from detectron2.evaluation.fast_eval_api import COCOeval_opt
from detectron2.structures import Boxes, BoxMode, pairwise_iou
from detectron2.utils.logger import create_small_table
from .evaluator import DatasetEvaluator
class COCOEvaluator(DatasetEvaluator):
    """
    Evaluate AR for object proposals, AP for instance detection/segmentation, AP
    for keypoint detection outputs using COCO's metrics.
    See http://cocodataset.org/#detection-eval and
    http://cocodataset.org/#keypoints-eval to understand its metrics.

    In addition to COCO, this evaluator is able to support any bounding box detection,
    instance segmentation, or keypoint detection dataset.
    """

    def __init__(self, dataset_name, cfg, distributed, output_dir=None, *, use_fast_impl=True):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
                It must have either the following corresponding metadata:

                    "json_file": the path to the COCO format annotation

                Or it must be in detectron2's standard dataset format
                so it can be converted to COCO format automatically.
            cfg (CfgNode): config instance
            distributed (bool): if True, will collect results from all ranks and run
                evaluation in the main process.
                Otherwise, will only evaluate the results in the current process.
            output_dir (str): optional, an output directory to dump all
                results predicted on the dataset. The dump contains two files:

                1. "instance_predictions.pth" a file in torch serialization
                   format that contains all the raw original predictions.
                2. "coco_instances_results.json" a json file in COCO's result
                   format.

                It is required when the dataset has no "json_file" metadata, because
                the generated COCO json is cached inside this directory.
            use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
                Although the results should be very close to the official implementation in COCO
                API, it is still recommended to compute results with the official API for use in
                papers.

        Raises:
            ValueError: if the dataset has no "json_file" metadata and ``output_dir``
                is None, so there is nowhere to write the converted annotations.
        """
        self._tasks = self._tasks_from_config(cfg)
        self._distributed = distributed
        self._output_dir = output_dir
        self._use_fast_impl = use_fast_impl
        # Start with an empty buffer so that process()/evaluate() do not fail with
        # AttributeError when the caller never invoked reset().
        self._predictions = []

        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

        self._metadata = MetadataCatalog.get(dataset_name)
        if not hasattr(self._metadata, "json_file"):
            if output_dir is None:
                # Without this check os.path.join(None, ...) raises an opaque TypeError.
                raise ValueError(
                    "output_dir must be provided to COCOEvaluator "
                    "for datasets not in COCO format."
                )
            self._logger.info(
                f"'{dataset_name}' is not registered by `register_coco_instances`."
                " Therefore trying to convert it to COCO format ..."
            )

            cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
            self._metadata.json_file = cache_path
            convert_to_coco_json(dataset_name, cache_path)

        json_file = PathManager.get_local_path(self._metadata.json_file)
        # COCO() prints to stdout while loading; swallow that output.
        with contextlib.redirect_stdout(io.StringIO()):
            self._coco_api = COCO(json_file)

        self._kpt_oks_sigmas = cfg.TEST.KEYPOINT_OKS_SIGMAS
        # Test set json files do not contain annotations (evaluation must be
        # performed using the COCO evaluation server).
        self._do_evaluation = "annotations" in self._coco_api.dataset

    def reset(self):
        """Discard every prediction accumulated so far."""
        self._predictions = []

    def _tasks_from_config(self, cfg):
        """
        Returns:
            tuple[str]: tasks that can be evaluated under the given configuration.
        """
        tasks = ("bbox",)
        if cfg.MODEL.MASK_ON:
            tasks = tasks + ("segm",)
        if cfg.MODEL.KEYPOINT_ON:
            tasks = tasks + ("keypoints",)
        return tasks

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for image_input, output in zip(inputs, outputs):
            prediction = {"image_id": image_input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(
                    instances, image_input["image_id"]
                )
            if "proposals" in output:
                prediction["proposals"] = output["proposals"].to(self._cpu_device)
            self._predictions.append(prediction)

    def evaluate(self):
        """
        Run the evaluation over everything collected by :meth:`process`.

        Returns:
            dict: a copy of the results, keyed by task name ("box_proposals",
            "bbox", "segm", "keypoints"). An empty dict is returned on non-main
            processes in distributed mode, and when no predictions were received.
        """
        if self._distributed:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        if len(predictions) == 0:
            self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
            return {}

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        self._results = OrderedDict()
        # The first prediction decides which kinds of output get evaluated.
        if "proposals" in predictions[0]:
            self._eval_box_proposals(predictions)
        if "instances" in predictions[0]:
            self._eval_predictions(set(self._tasks), predictions)
        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)

    def _eval_predictions(self, tasks, predictions):
        """
        Evaluate predictions on the given tasks.
        Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            # Remap on shallow copies: the original dicts may still be referenced by
            # self._predictions, and rewriting them in place would make a second
            # evaluate() call remap ids that were already remapped.
            remapped_results = []
            for result in coco_results:
                category_id = result["category_id"]
                assert (
                    category_id in reverse_id_mapping
                ), "A prediction has category_id={}, which is not available in the dataset.".format(
                    category_id
                )
                remapped_results.append(
                    {**result, "category_id": reverse_id_mapping[category_id]}
                )
            coco_results = remapped_results

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info(
            "Evaluating predictions with {} COCO API...".format(
                "unofficial" if self._use_fast_impl else "official"
            )
        )
        for task in sorted(tasks):
            coco_eval = (
                _evaluate_predictions_on_coco(
                    self._coco_api,
                    coco_results,
                    task,
                    kpt_oks_sigmas=self._kpt_oks_sigmas,
                    use_fast_impl=self._use_fast_impl,
                )
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def _eval_box_proposals(self, predictions):
        """
        Evaluate the box proposals in predictions.
        Fill self._results with the metrics for "box_proposals" task.
        """
        if self._output_dir:
            # Saving generated box proposals to file.
            # Predicted box_proposals are in XYXY_ABS mode.
            bbox_mode = BoxMode.XYXY_ABS.value
            ids, boxes, objectness_logits = [], [], []
            for prediction in predictions:
                ids.append(prediction["image_id"])
                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())

            proposal_data = {
                "boxes": boxes,
                "objectness_logits": objectness_logits,
                "ids": ids,
                "bbox_mode": bbox_mode,
            }
            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
                pickle.dump(proposal_data, f)

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating bbox proposals ...")
        res = {}
        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
        for limit in [100, 1000]:
            for area, suffix in areas.items():
                stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit)
                key = "AR{}@{:d}".format(suffix, limit)
                res[key] = float(stats["ar"].item() * 100)
        self._logger.info("Proposal metrics: \n" + create_small_table(res))
        self._results["box_proposals"] = res

    def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
        """
        Derive the desired score numbers from summarized COCOeval.

        Args:
            coco_eval (None or COCOEval): None represents no predictions from model.
            iou_type (str): one of "bbox", "segm" or "keypoints"; any other value
                raises KeyError.
            class_names (None or list[str]): if provided, will use it to predict
                per-category AP.

        Returns:
            a dict of {metric name: score}
        """
        # Every task reports the same 20 numbers. They are read positionally from
        # coco_eval.stats, so the order here must match the order in which the
        # summary step fills that array.
        # NOTE(review): presumably the stats layout comes from the customised
        # `summarize2()` used by _evaluate_predictions_on_coco -- confirm.
        metric_names = [
            "AP", "AP10", "AP20", "AP30", "AP40", "AP50", "AP60", "AP70", "AP80", "AP90",
            "AR", "AR10", "AR20", "AR30", "AR40", "AR50", "AR60", "AR70", "AR80", "AR90",
        ]
        metrics = {
            "bbox": metric_names,
            "segm": metric_names,
            "keypoints": metric_names,
        }[iou_type]

        if coco_eval is None:
            self._logger.warning("No predictions from the model!")
            return {metric: float("nan") for metric in metrics}

        # the standard metrics; negative stats mean "not computable" and become NaN
        results = {
            metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
            for idx, metric in enumerate(metrics)
        }
        self._logger.info(
            "Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
        )
        if not np.isfinite(sum(results.values())):
            self._logger.info("Some metrics cannot be computed and is shown as NaN.")

        if class_names is None or len(class_names) <= 1:
            return results
        # Compute per-category AP
        # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
        precisions = coco_eval.eval["precision"]
        # precision has dims (iou, recall, cls, area range, max dets)
        assert len(class_names) == precisions.shape[2]

        results_per_category = []
        for idx, name in enumerate(class_names):
            # area range index 0: all area ranges
            # max dets index -1: the last (largest) configured max-dets setting
            precision = precisions[:, :, idx, 0, -1]
            precision = precision[precision > -1]
            ap = np.mean(precision) if precision.size else float("nan")
            results_per_category.append(("{}".format(name), float(ap * 100)))

        # tabulate it
        N_COLS = min(6, len(results_per_category) * 2)
        results_flatten = list(itertools.chain(*results_per_category))
        results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
        table = tabulate(
            results_2d,
            tablefmt="pipe",
            floatfmt=".3f",
            headers=["category", "AP"] * (N_COLS // 2),
            numalign="left",
        )
        self._logger.info("Per-category {} AP: \n".format(iou_type) + table)

        results.update({"AP-" + name: ap for name, ap in results_per_category})
        return results
def instances_to_coco_json(instances, img_id):
    """
    Dump an "Instances" object to a COCO-format json that's used for evaluation.

    The input is not modified: keypoint coordinates are shifted on a private copy.

    Args:
        instances (Instances): predictions for one image. ``pred_boxes``,
            ``scores`` and ``pred_classes`` are always read; ``pred_masks`` and
            ``pred_keypoints`` are used only when present.
        img_id (int): the image id

    Returns:
        list[dict]: list of json annotations in COCO format. Each entry has
        "image_id", "category_id", "bbox" and "score", plus "segmentation"
        (RLE) and/or "keypoints" (flat list) when available. Empty when there
        are no instances.
    """
    num_instance = len(instances)
    if num_instance == 0:
        return []

    boxes = instances.pred_boxes.tensor.numpy()
    boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    boxes = boxes.tolist()
    scores = instances.scores.tolist()
    classes = instances.pred_classes.tolist()

    has_mask = instances.has("pred_masks")
    if has_mask:
        # use RLE to encode the masks, because they are too large and takes memory
        # since this evaluator stores outputs of the entire dataset
        rles = [
            mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
            for mask in instances.pred_masks
        ]
        for rle in rles:
            # "counts" is an array encoded by mask_util as a byte-stream. Python3's
            # json writer which always produces strings cannot serialize a bytestream
            # unless you decode it. Thankfully, utf-8 works out (which is also what
            # the pycocotools/_mask.pyx does).
            rle["counts"] = rle["counts"].decode("utf-8")

    has_keypoints = instances.has("pred_keypoints")
    if has_keypoints:
        # In COCO annotations,
        # keypoints coordinates are pixel indices.
        # However our predictions are floating point coordinates.
        # Therefore we subtract 0.5 to be consistent with the annotation format.
        # This is the inverse of data loading logic in `datasets/coco.py`.
        #
        # Work on a clone: shifting `instances.pred_keypoints` in place would
        # corrupt the caller's tensor and apply the offset again on every call.
        # NOTE(review): assumes pred_keypoints is a torch.Tensor whose last axis
        # starts with (x, y) -- confirm against the model output format.
        keypoints = instances.pred_keypoints.clone()
        keypoints[:, :, :2] -= 0.5

    results = []
    for k in range(num_instance):
        result = {
            "image_id": img_id,
            "category_id": classes[k],
            "bbox": boxes[k],
            "score": scores[k],
        }
        if has_mask:
            result["segmentation"] = rles[k]
        if has_keypoints:
            result["keypoints"] = keypoints[k].flatten().tolist()
        results.append(result)
    return results
# inspired from Detectron:
# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None):
"""
Evaluate detection proposal recall metrics. This function is a much
faster alternative to the official COCO API recall evaluation code. However,
it produces slightly different results.
"""
# Record max overlap value for each gt box
# Return vector of overlap values
areas = {
"all": 0,
"small": 1,
"medium": 2,
"large": 3,
"96-128": 4,
"128-256": 5,
"256-512": 6,
"512-inf": 7,
}
area_ranges = [
[0 ** 2, 1e5 ** 2], # all
[0 ** 2, 32 ** 2], # small
[32 ** 2, 96 ** 2], # medium
[96 ** 2, 1e5 ** 2], # large
[96 ** 2, 128 ** 2], # 96-128
[128 ** 2, 256 ** 2], # 128-256
[256 ** 2, 512 ** 2], # 256-512
[512 ** 2, 1e5 ** 2],
] # 512-inf
assert area in areas, "Unknown area range: {}".format(area)
area_range = area_ranges[areas[area]]
gt_overlaps = []
num_pos = 0
for prediction_dict in dataset_predictions:
predictions = prediction_dict["proposals"]
# sort predictions in descending order
# TODO maybe remove this and make it explicit in the documentation
inds = predictions.objectness_logits.sort(descending=True)[1]
predictions = predictions[inds]
ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
anno = coco_api.loadAnns(ann_ids)
gt_boxes = [
BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
for obj in anno
if obj["iscrowd"] == 0
]
gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
gt_boxes = Boxes(gt_boxes)
gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
if len(gt_boxes) == 0 or len(predictions) == 0:
continue
valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
gt_boxes = gt_boxes[valid_gt_inds]
num_pos += len(gt_boxes)
if len(gt_boxes) == 0:
continue
if limit is not None and len(predictions) > limit:
predictions = predictions[:limit]
overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
_gt_overlaps = torch.zeros(len(gt_boxes))
for j in range(min(len(predictions), len(gt_boxes))):
# find which proposal box maximally covers each gt box
# and get the iou amount of coverage for each gt box
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
# find which gt box is 'best' covered (i.e. 'best' = most iou)
gt_ovr, gt_ind = max_overlaps.max(dim=0)
assert gt_ovr >= 0
# find the proposal box that covers the best covered gt box
box_ind = argmax_overlaps[gt_ind]
# record the iou coverage of this gt box
_gt_overlaps[j] = overlaps[box_ind, gt_ind]
assert _gt_overlaps[j] == gt_ovr
# mark the proposal box and the gt box as used
overlaps[box_ind, :] = -1
overlaps[:, gt_ind] = -1
# append recorded iou coverage level
gt_overlaps.append(_gt_overlaps)
gt_overlaps = (
torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
)
gt_overlaps, _ = torch.sort(gt_overlaps)
if thresholds is None:
step = 0.05
thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
recalls = torch.zeros_like(thresholds)
# compute recall for each iou threshold
for i, t in enumerate(thresholds):
recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
# ar = 2 * np.trapz(recalls, thresholds)
ar = recalls.mean()
return {
"ar": ar,
"recalls": recalls,
"thresholds": thresholds,
"gt_overlaps": gt_overlaps,
"num_pos": num_pos,
}
def _evaluate_predictions_on_coco(
coco_gt, coco_results, iou_type, kpt_oks_sigmas=None, use_fast_impl=True
):
"""
Evaluate the coco results using COCOEval API.
"""
assert len(coco_results) > 0
if iou_type == "segm":
coco_results = copy.deepcopy(coco_results)
# When evaluating mask AP, if the results contain bbox, cocoapi will
# use the box area as the area of the instance, instead of the mask area.
# This leads to a different definition of small/medium/large.
# We remove the bbox field to let mask AP use mask area.
for c in coco_results:
c.pop("bbox", None)
coco_dt = coco_gt.loadRes(coco_results)
coco_eval = (COCOeval_opt if use_fast_impl else COCOeval)(coco_gt, coco_dt, iou_type)
if iou_type == "keypoints":
# Use the COCO default keypoint OKS sigmas unless overrides are specified
if kpt_oks_sigmas:
assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "pycocotools is too old!"
coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
# COCOAPI requires every detection and every gt to have keypoints, so
# we just take the first entry from both
num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3
num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3
num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas)
assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, (
f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. "
f"Ground truth contains {num_keypoints_gt} keypoints. "
f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. "
"They have to agree with each other. For meaning of OKS, please refer to "
"http://cocodataset.org/#keypoints-eval."
)
coco_eval.params.maxDets = [500]
coco_eval.evaluate()
coco_eval.accumulate()
#coco_eval.summarize()
coco_eval.summarize2()
return coco_eval