forked from longcw/MOTDT
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
315 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from .mot_seq import MOTSeq | ||
|
||
|
||
__factory = { | ||
# 'kitti': KITTISeq, | ||
'mot': MOTSeq, | ||
} | ||
|
||
|
||
def get_names(): | ||
return tuple(__factory.keys()) | ||
|
||
|
||
def init_dataset(name, *args, **kwargs): | ||
if name not in get_names(): | ||
raise KeyError("Unknown dataset: {}".format(name)) | ||
return __factory[name](*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ cffi | |
sklearn | ||
numba | ||
torchvision | ||
fire | ||
fire | ||
motmetrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import os | ||
import numpy as np | ||
import copy | ||
import motmetrics as mm | ||
|
||
from utils.io import read_results, unzip_objs | ||
|
||
|
||
class Evaluator(object): | ||
|
||
def __init__(self, data_root, seq_name, data_type): | ||
self.data_root = data_root | ||
self.seq_name = seq_name | ||
self.data_type = data_type | ||
|
||
self.load_annotations() | ||
self.reset_accumulator() | ||
|
||
def load_annotations(self): | ||
assert self.data_type == 'mot' | ||
|
||
gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt') | ||
self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True) | ||
self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True) | ||
|
||
def reset_accumulator(self): | ||
self.acc = mm.MOTAccumulator(auto_id=True) | ||
|
||
def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False): | ||
# results | ||
trk_tlwhs = np.copy(trk_tlwhs) | ||
trk_ids = np.copy(trk_ids) | ||
|
||
# gts | ||
gt_objs = self.gt_frame_dict.get(frame_id, []) | ||
gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2] | ||
|
||
# ignore boxes | ||
ignore_objs = self.gt_ignore_frame_dict.get(frame_id, []) | ||
ignore_tlwhs = unzip_objs(ignore_objs)[0] | ||
|
||
# remove ignored results | ||
keep = np.ones(len(trk_tlwhs), dtype=bool) | ||
iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5) | ||
match_is, match_js = mm.lap.linear_sum_assignment(iou_distance) | ||
match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js]) | ||
match_ious = iou_distance[match_is, match_js] | ||
|
||
match_js = np.asarray(match_js, dtype=int) | ||
match_js = match_js[np.logical_not(np.isnan(match_ious))] | ||
keep[match_js] = False | ||
trk_tlwhs = trk_tlwhs[keep] | ||
trk_ids = trk_ids[keep] | ||
|
||
# get distance matrix | ||
iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5) | ||
|
||
# acc | ||
self.acc.update(gt_ids, trk_ids, iou_distance) | ||
|
||
if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'): | ||
events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics | ||
else: | ||
events = None | ||
return events | ||
|
||
def eval_file(self, filename): | ||
self.reset_accumulator() | ||
|
||
result_frame_dict = read_results(filename, self.data_type, is_gt=False) | ||
frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys()))) | ||
for frame_id in frames: | ||
trk_objs = result_frame_dict.get(frame_id, []) | ||
trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2] | ||
self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False) | ||
|
||
return self.acc | ||
|
||
@staticmethod | ||
def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')): | ||
names = copy.deepcopy(names) | ||
if metrics is None: | ||
metrics = mm.metrics.motchallenge_metrics | ||
metrics = copy.deepcopy(metrics) | ||
|
||
mh = mm.metrics.create() | ||
summary = mh.compute_many( | ||
accs, | ||
metrics=metrics, | ||
names=names, | ||
generate_overall=True | ||
) | ||
|
||
return summary | ||
|
||
@staticmethod | ||
def save_summary(summary, filename): | ||
import pandas as pd | ||
writer = pd.ExcelWriter(filename) | ||
summary.to_excel(writer) | ||
writer.save() |
Oops, something went wrong.