From 8abbbf2c69d5183ff75a52de38d1d1c3619ac02e Mon Sep 17 00:00:00 2001
From: Dan Morris
Date: Fri, 26 Apr 2019 09:26:57 -0700
Subject: [PATCH] Detector post-processing

---
 .../postprocess_batch_results.py | 149 +++++++++++++++---
 1 file changed, 130 insertions(+), 19 deletions(-)

diff --git a/api/detector_batch_processing/postprocess_batch_results.py b/api/detector_batch_processing/postprocess_batch_results.py
index d0e5b0895..75377f50f 100644
--- a/api/detector_batch_processing/postprocess_batch_results.py
+++ b/api/detector_batch_processing/postprocess_batch_results.py
@@ -12,7 +12,7 @@
 #
 # * Sample true/false positives/negatives and render to html (requires ground truth)
 #
-# * Sample detections/non-detections
+# * Sample detections/non-detections and render to html (when ground truth isn't available)
 #
 ########
 
@@ -22,11 +22,13 @@
 import os
 import sys
 import argparse
+import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
 from enum import Enum
 from tqdm import tqdm
 from collections import defaultdict
-
+from sklearn.metrics import precision_recall_curve, confusion_matrix, average_precision_score
 
 #%% To be moved into options/inputs
 
@@ -38,6 +40,9 @@
 negative_classes = ['empty']
 confidence_threshold = 0.85
 
+# Used for summary statistics only
+target_recall = 0.9
+
 
 #%% Helper classes and functions
 
@@ -159,9 +164,16 @@ def mark_detection_status(indexed_db,negative_classes=['empty']):
         elif image_status == DetectionStatus.DS_AMBIGUOUS:
             nAmbiguous += 1
 
+        im['_detection_status'] = image_status
+
     return (nNegative,nPositive,nUnknown,nAmbiguous)
 
 
+#%% Prepare output dir
+
+os.makedirs(output_dir,exist_ok=True)
+
+
 #%% Load ground truth if available
 
 ground_truth_indexed_db = None
@@ -203,27 +215,126 @@ def mark_detection_status(indexed_db,negative_classes=['empty']):
 
 #%% Find suspicious detections
 
+
+#%% Fork here depending on whether or not ground truth is available
+
+# If we have ground truth, we'll compute precision/recall and sample tp/fp/tn/fn.
+#
+# Otherwise we'll just visualize detections/non-detections.
+
+if ground_truth_indexed_db is not None:
 
-#%% If ground truth is available, match it to the detection results
-
-class DetectionGroundTruth:
-
-    gt_image_id = None
-    gt_presence_label = None
-    gt_class_label = None
-
-
-# For now, error on any matching failures
-
-# Add columns gt_image_id, gt_presence_label, gt_class_label
-
-
+    #%% Make sure we can match ground truth to detection results
 
-#%% Evaluate precision/recall, optionally rendering results
+    detector_files = detection_results['image_path'].to_list()
+
+    # For now, error on any matching failures; at some point we can decide
+    # how to handle "partial" ground truth.  All or none for now.
+    for fn in detector_files:
+        assert fn in ground_truth_indexed_db.filename_to_id
+
+    print('Confirmed filename matches to ground truth for {} files'.format(len(detector_files)))
+
+
+    #%% Compute precision/recall
+
+    # numpy array of detection probabilities
+    p_detection = detection_results['max_confidence'].values
+    n_detections = len(p_detection)
+
+    # numpy array of ground truth labels (1.0 = positive, 0.0 = negative, -1.0 = unknown/ambiguous)
+    gt_detections = np.zeros(n_detections,dtype=float)
+
+    for iDetection,fn in enumerate(detector_files):
+        image_id = ground_truth_indexed_db.filename_to_id[fn]
+        image = ground_truth_indexed_db.image_id_to_image[image_id]
+        detection_status = image['_detection_status']
+
+        if detection_status == DetectionStatus.DS_NEGATIVE:
+            gt_detections[iDetection] = 0.0
+        elif detection_status == DetectionStatus.DS_POSITIVE:
+            gt_detections[iDetection] = 1.0
+        else:
+            gt_detections[iDetection] = -1.0
+
+    # Don't include ambiguous/unknown ground truth in precision/recall analysis
+    b_valid_ground_truth = gt_detections >= 0.0
+
+    p_detection_pr = p_detection[b_valid_ground_truth]
+    gt_detections_pr = gt_detections[b_valid_ground_truth]
+
+    print('Including {} of {} values in p/r analysis'.format(np.sum(b_valid_ground_truth),
+          len(b_valid_ground_truth)))
+
+    precisions, recalls, thresholds = precision_recall_curve(gt_detections_pr, p_detection_pr)
+
+    # For completeness, include the result at a confidence threshold of 1.0
+    thresholds = np.append(thresholds, [1.0])
+
+    precisions_recalls = pd.DataFrame(data={
+        'confidence_threshold': thresholds,
+        'precision': precisions,
+        'recall': recalls
+    })
+
+    # Compute and print summary statistics
+    average_precision = average_precision_score(gt_detections_pr, p_detection_pr)
+    print('Average precision: {}'.format(average_precision))
+
+    # Thresholds increase (and recall decreases) throughout precisions/recalls/thresholds; find
+    # the last value where recall is at or above the target.  That's our precision @ target recall.
+    target_recall = 0.9
+    b_above_target_recall = np.where(recalls >= target_recall)[0]
+    if len(b_above_target_recall) == 0:
+        precision_at_target_recall = 0.0
+    else:
+        i_target_recall = b_above_target_recall[-1]
+        precision_at_target_recall = precisions[i_target_recall]
+    print('Precision at {} recall: {}'.format(target_recall,precision_at_target_recall))
+
+    cm = confusion_matrix(gt_detections_pr, np.array(p_detection_pr) > confidence_threshold)
 
-#%% Sample true/false positives/negatives and render to html
+    # Flatten the confusion matrix into true/false positive/negative counts
+    tn, fp, fn, tp = cm.ravel()
+    precision = tp / (tp + fp)
+    recall = tp / (tp + fn)
+    f1 = 2 * (precision * recall) / (precision + recall)
+
+    print('At a confidence threshold of {:.2f}, precision={:.2f}, recall={:.2f}, f1={:.2f}'.format(
+        confidence_threshold, precision, recall, f1))
+
+
+    #%% Render output
+
+    # Write p/r table to .csv file in output directory
+    pr_table_filename = os.path.join(output_dir, 'prec_recall.csv')
+    precisions_recalls.to_csv(pr_table_filename, index=False)
+
+    # Write precision/recall plot to .png file in output directory
+    step_kwargs = {'step': 'post'}
+    plt.step(recalls, precisions, color='b', alpha=0.2,
+             where='post')
+    plt.fill_between(recalls, precisions, alpha=0.2, color='b', **step_kwargs)
+
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')
+    plt.ylim([0.0, 1.05])
+    plt.xlim([0.0, 1.05])
+    t = 'Precision-Recall curve: AP={:0.2f}, P@{:0.2f}={:0.2f}'.format(
+        average_precision, target_recall, precision_at_target_recall)
+    plt.title(t)
+    pr_figure_filename = os.path.join(output_dir, 'prec_recall.png')
+    plt.savefig(pr_figure_filename)
+    # plt.show()
+
+
+    #%% Sample true/false positives/negatives and render to html
+
 
-#%% Sample detections/non-detections
\ No newline at end of file
+else:
+
+    #%% Sample detections/non-detections
+
+    pass
\ No newline at end of file
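
For reference, the precision-at-target-recall logic this patch adds can be exercised in isolation. Below is a minimal sketch using synthetic confidences and labels in place of real detector output and ground truth; the variable names (gt_labels, confidences) are illustrative only and not part of the patch, and the sketch assumes scikit-learn's standard metrics API:

    import numpy as np
    from sklearn.metrics import (precision_recall_curve, average_precision_score,
                                 confusion_matrix)

    # Synthetic stand-ins for the patch's gt_detections_pr / p_detection_pr arrays
    rng = np.random.RandomState(0)
    gt_labels = rng.randint(0, 2, size=200)                              # 1 = animal present, 0 = empty
    confidences = np.clip(gt_labels * 0.6 + rng.rand(200) * 0.5, 0, 1)   # scores loosely tied to labels

    precisions, recalls, thresholds = precision_recall_curve(gt_labels, confidences)
    average_precision = average_precision_score(gt_labels, confidences)

    # recalls is non-increasing, so the indices where recall >= target form a
    # prefix [0..k]; the last of them gives precision at the target recall.
    target_recall = 0.9
    above_target = np.where(recalls >= target_recall)[0]
    precision_at_target_recall = precisions[above_target[-1]] if above_target.size else 0.0

    # Confusion matrix at a fixed confidence threshold, as in the patch
    confidence_threshold = 0.85
    tn, fp, fn, tp = confusion_matrix(gt_labels,
                                      (confidences > confidence_threshold).astype(int)).ravel()

    print('AP={:.2f}, P@{:.2f}={:.2f}, precision={:.2f}, recall={:.2f} at threshold {:.2f}'.format(
        average_precision, target_recall, precision_at_target_recall,
        tp / (tp + fp), tp / (tp + fn), confidence_threshold))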