Skip to content

Commit

Permalink
fixed bug in metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
ericguizzo committed Dec 10, 2021
1 parent d12b87e commit 9cac538
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,18 @@ def location_sensitive_detection(pred, true, n_frames=100, spatial_threshold=2.,
for frame in range(n_frames):
t = frames[frame]['t'] #all true events for frame i
p = frames[frame]['p'] #all predicted events for frame i
matched = 0 #counts the matching events

if len(t) == 0: #if there are PREDICTED but not TRUE events
FP += len(p) #all predicted are false positive
elif len(p) == 0: #if there are TRUE but not PREDICTED events
FN += len(t) #all predicted are false negative

num_true_items = len(t)
num_pred_items = len(p)
matched = 0 #counts the matching events
match_ids = [] #all pred ids that matched

if num_true_items == 0: #if there are PREDICTED but not TRUE events
FP += num_pred_items #all predicted are false positive
elif num_pred_items == 0: #if there are TRUE but not PREDICTED events
FN += num_true_items #all predicted are false negative
elif num_true_items == 0 and num_pred_items == 0: #if no true and no predicted, just do nothing
pass
else:
for i_t in range(len(t)): #iterate all true events
match = False #flag for matching events
Expand All @@ -168,18 +173,21 @@ def location_sensitive_detection(pred, true, n_frames=100, spatial_threshold=2.,
spat_error = np.linalg.norm(true_coord-pred_coord) #cartesian distance between spatial coords
if true_class == pred_class and spat_error < spatial_threshold: #if predicton is correct (same label + not exceeding spatial error threshold)
match = True
match_ids.append(i_p) #append to matched ids (to cound eventual duplicates)
if match:
matched += 1 #for each true event, match only once comparing all predicted events

num_true_items = len(t)
num_pred_items = len(p)
fn = num_true_items - matched
fp = num_pred_items - matched
unique_ids = np.unique(match_ids) #remove duplicates from matches ids list
duplicates = len(match_ids) - len(unique_ids) #compute number of duplicates
matched = matched - duplicates
fn = num_true_items - matched
fp = num_pred_items - matched

#add to counts
TP += matched #number of matches are directly true positives
FN += fn
FP += fp

#add to counts
TP += matched #number of matches are directly true positives
FN += fn
FP += fp

precision = TP / (TP + FP + sys.float_info.epsilon)
recall = TP / (TP + FN + sys.float_info.epsilon)
Expand Down

0 comments on commit 9cac538

Please sign in to comment.