Detector post-processing
agentmorris committed Apr 26, 2019
1 parent a68f144 commit 8abbbf2
Showing 1 changed file with 130 additions and 19 deletions.
149 changes: 130 additions & 19 deletions api/detector_batch_processing/postprocess_batch_results.py
@@ -12,7 +12,7 @@
#
# * Sample true/false positives/negatives and render to html (requires ground truth)
#
# * Sample detections/non-detections and render to html (when ground truth isn't available)
#
########

@@ -22,11 +22,13 @@
import os
import sys
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from enum import Enum
from tqdm import tqdm
from collections import defaultdict

from sklearn.metrics import precision_recall_curve, confusion_matrix, average_precision_score
from sklearn.utils.fixes import signature

#%% To be moved into options/inputs

@@ -38,6 +40,9 @@
negative_classes = ['empty']
confidence_threshold = 0.85

# Used for summary statistics only
target_recall = 0.9
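
# detection_results is loaded elsewhere in this script; based on its use
# below, it's assumed to be a pandas DataFrame with at least 'image_path'
# and 'max_confidence' columns.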


#%% Helper classes and functions

@@ -159,9 +164,16 @@ def mark_detection_status(indexed_db,negative_classes=['empty']):
        elif image_status == DetectionStatus.DS_AMBIGUOUS:
            nAmbiguous += 1

        im['_detection_status'] = image_status

    return (nNegative,nPositive,nUnknown,nAmbiguous)
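
# Example usage (an illustrative sketch; ground_truth_indexed_db isn't loaded
# until a later cell, so the call is shown as a comment):
#
# (nNegative,nPositive,nUnknown,nAmbiguous) = mark_detection_status(
#     ground_truth_indexed_db,negative_classes=negative_classes)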


#%% Prepare output dir

os.makedirs(output_dir,exist_ok=True)


#%% Load ground truth if available

ground_truth_indexed_db = None
@@ -203,27 +215,126 @@ def mark_detection_status(indexed_db,negative_classes=['empty']):

#%% Find suspicious detections



#%% Fork here depending on whether or not ground truth is available

# If we have ground truth, we'll compute precision/recall and sample tp/fp/tn/fn.
#
# Otherwise we'll just visualize detections/non-detections.

if ground_truth_indexed_db is not None:


#%% Make sure we can match ground truth to detection results

detector_files = detection_results['image_path'].to_list()

# For now, error on any matching failure; at some point we can decide how to
# handle "partial" ground truth. All or none for now.
for fn in detector_files:
    assert fn in ground_truth_indexed_db.filename_to_id

print('Confirmed filename matches to ground truth for {} files'.format(len(detector_files)))
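
# Note on the indexed db (an assumption based on its use here, not a spec):
# filename_to_id maps an image path as it appears in the detector output to an
# image id, e.g. filename_to_id['seq0001/img0003.jpg'] --> 'img_0042' (both
# values hypothetical), and that id indexes image_id_to_image below.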


#%% Compute precision/recall

# numpy array of detection probabilities
p_detection = detection_results['max_confidence'].values
n_detections = len(p_detection)

# numpy array of ground truth labels (1.0 = positive, 0.0 = negative,
# -1.0 = unknown/ambiguous)
gt_detections = np.zeros(n_detections,dtype=float)

for iDetection,fn in enumerate(detector_files):

    image_id = ground_truth_indexed_db.filename_to_id[fn]
    image = ground_truth_indexed_db.image_id_to_image[image_id]
    detection_status = image['_detection_status']

    if detection_status == DetectionStatus.DS_NEGATIVE:
        gt_detections[iDetection] = 0.0
    elif detection_status == DetectionStatus.DS_POSITIVE:
        gt_detections[iDetection] = 1.0
    else:
        gt_detections[iDetection] = -1.0

# Don't include ambiguous/unknown ground truth in precision/recall analysis
b_valid_ground_truth = gt_detections >= 0.0

p_detection_pr = p_detection[b_valid_ground_truth]
gt_detections_pr = gt_detections[b_valid_ground_truth]

print('Including {} of {} values in p/r analysis'.format(
    np.sum(b_valid_ground_truth),len(b_valid_ground_truth)))

precisions, recalls, thresholds = precision_recall_curve(gt_detections_pr, p_detection_pr)

# For completeness, include the result at a confidence threshold of 1.0
thresholds = np.append(thresholds, [1.0])
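
# Why the append is needed (a standalone illustration with made-up labels and
# scores; not part of the pipeline):
#
# p, r, t = precision_recall_curve([0, 1, 1], [0.1, 0.6, 0.9])
# assert len(p) == len(r) == len(t) + 1
#
# sklearn returns one more (precision, recall) pair than thresholds -- the
# final pair is (1.0, 0.0), with no associated threshold -- so padding
# thresholds with 1.0 aligns the three arrays for the table below.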

precisions_recalls = pd.DataFrame(data={
    'confidence_threshold': thresholds,
    'precision': precisions,
    'recall': recalls
})

# Compute and print summary statistics
average_precision = average_precision_score(gt_detections_pr, p_detection_pr)
print('Average precision: {}'.format(average_precision))

# Thresholds go up throughout precisions/recalls/thresholds; find the last
# index where recall is at or above target. That's our precision @ target
# recall. (target_recall is set in the options cell above.)
b_above_target_recall = recalls >= target_recall
if not np.any(b_above_target_recall):
    precision_at_target_recall = 0.0
else:
    # recalls is in decreasing order, so this is the highest-confidence
    # operating point that still achieves the target recall
    i_target_recall = np.where(b_above_target_recall)[0][-1]
    precision_at_target_recall = precisions[i_target_recall]
print('Precision at {} recall: {}'.format(target_recall,precision_at_target_recall))

cm = confusion_matrix(gt_detections_pr, np.array(p_detection_pr) > confidence_threshold)

# Flatten the confusion matrix
tn, fp, fn, tp = cm.ravel()
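
# Sanity check on sklearn's ordering (illustrative, not part of the pipeline):
# for binary labels, confusion_matrix returns [[tn, fp], [fn, tp]], e.g.:
#
# confusion_matrix([0, 1, 1, 0], [0, 1, 0, 0]).ravel() --> (2, 0, 1, 1)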

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * (precision * recall) / (precision + recall)

print('At a confidence threshold of {:.2f}, precision={:.2f}, recall={:.2f}, f1={:.2f}'.format(
    confidence_threshold,precision,recall,f1))


#%% Render output

# Write p/r table to .csv file in output directory
pr_table_filename = os.path.join(output_dir, 'prec_recall.csv')
precisions_recalls.to_csv(pr_table_filename, index=False)

# Write precision/recall plot to .png file in output directory
step_kwargs = {'step': 'post'}
plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
plt.fill_between(recalls, precisions, alpha=0.2, color='b', **step_kwargs)

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.05])
t = 'Precision-Recall curve: AP={:0.2f}, P@{:0.2f}={:0.2f}'.format(
    average_precision, target_recall, precision_at_target_recall)
plt.title(t)
pr_figure_filename = os.path.join(output_dir, 'prec_recall.png')
plt.savefig(pr_figure_filename)
# plt.show()


#%% Sample true/false positives/negatives and render to html


else:

    #%% Sample detections/non-detections

    pass
