Skip to content

Commit

Permalink
Merge pull request mir-evaluation#178 from craffel/hierarchy_evaluator
Browse files Browse the repository at this point in the history
hierarchy.evaluate() supports labels (NO-OP)
  • Loading branch information
craffel committed Feb 15, 2016
2 parents d570418 + 368bc11 commit 545596a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 19 deletions.
11 changes: 8 additions & 3 deletions evaluators/segment_hier_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,16 @@ def process_arguments():
ref_files = parameters['reference_file']
est_files = parameters['estimated_file']

ref_intervals = [load_labeled_intervals(_)[0] for _ in ref_files]
est_intervals = [load_labeled_intervals(_)[0] for _ in est_files]
ref = [load_labeled_intervals(_) for _ in ref_files]
est = [load_labeled_intervals(_) for _ in est_files]
ref_intervals = [seg[0] for seg in ref]
ref_labels = [seg[1] for seg in ref]
est_intervals = [seg[0] for seg in est]
est_labels = [seg[1] for seg in est]

# Compute all the scores
scores = mir_eval.hierarchy.evaluate(ref_intervals, est_intervals,
scores = mir_eval.hierarchy.evaluate(ref_intervals, ref_labels,
est_intervals, est_labels,
window=parameters['window'])
print("{} [...] vs. {} [...]".format(
basename(parameters['reference_file'][0]),
Expand Down
34 changes: 23 additions & 11 deletions mir_eval/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,17 +362,20 @@ def tmeasure(reference_intervals_hier, estimated_intervals_hier,
return t_precision, t_recall, t_measure


def evaluate(ref_intervals_hier, est_intervals_hier, **kwargs):
def evaluate(ref_intervals_hier, ref_labels_hier,
est_intervals_hier, est_labels_hier, **kwargs):
'''Compute all hierarchical structure metrics for the given reference and
estimated annotations.
Examples
--------
A toy example with two two-layer annotations
>>> ref = [[[0, 30], [30, 60]], [[0, 15], [15, 30], [30, 45], [45, 60]]]
>>> est = [[[0, 45], [45, 60]], [[0, 15], [15, 30], [30, 45], [45, 60]]]
>>> scores = mir_eval.hierarchy.evaluate(ref, est)
>>> ref_i = [[[0, 30], [30, 60]], [[0, 15], [15, 30], [30, 45], [45, 60]]]
>>> est_i = [[[0, 45], [45, 60]], [[0, 15], [15, 30], [30, 45], [45, 60]]]
>>> ref_l = [ ['A', 'B'], ['a', 'b', 'a', 'c'] ]
>>> est_l = [ ['A', 'B'], ['a', 'a', 'b', 'b'] ]
>>> scores = mir_eval.hierarchy.evaluate(ref_i, ref_l, est_i, est_l)
>>> dict(scores)
{'T-Measure full': 0.94822745804853459,
'T-Measure reduced': 0.8732458222764804,
Expand All @@ -384,16 +387,22 @@ def evaluate(ref_intervals_hier, est_intervals_hier, **kwargs):
A more realistic example, using SALAMI pre-parsed annotations
>>> def load_salami(filename):
... "load SALAMI event format as unlabeled intervals"
... events = mir_eval.io.load_labeled_events(filename)[0]
... return mir_eval.util.boundaries_to_intervals(events)[0]
... "load SALAMI event format as labeled intervals"
... events, labels = mir_eval.io.load_labeled_events(filename)
... intervals = mir_eval.util.boundaries_to_intervals(events)[0]
... return intervals, labels[:len(intervals)]
>>> ref_files = ['data/10/parsed/textfile1_uppercase.txt',
... 'data/10/parsed/textfile1_lowercase.txt']
>>> est_files = ['data/10/parsed/textfile2_uppercase.txt',
... 'data/10/parsed/textfile2_lowercase.txt']
>>> ref_hier = [load_salami(fname) for fname in ref_files]
>>> est_hier = [load_salami(fname) for fname in est_files]
>>> scores = mir_eval.hierarchy.evaluate(ref_hier, est_hier)
>>> ref = [load_salami(fname) for fname in ref_files]
>>> ref_int = [seg[0] for seg in ref]
>>> ref_lab = [seg[1] for seg in ref]
>>> est = [load_salami(fname) for fname in est_files]
>>> est_int = [seg[0] for seg in est]
>>> est_lab = [seg[1] for seg in est]
>>> scores = mir_eval.hierarchy.evaluate(ref_int, ref_lab,
...                                      est_int, est_lab)
>>> dict(scores)
{'T-Measure full': 0.66029225561405358,
'T-Measure reduced': 0.62001868041578034,
Expand All @@ -406,10 +415,13 @@ def evaluate(ref_intervals_hier, est_intervals_hier, **kwargs):
Parameters
----------
ref_intervals_hier : list of list-like
ref_labels_hier : list of str
est_intervals_hier : list of list-like
est_labels_hier : list of str
Hierarchical annotations are encoded as an ordered list
of segmentations. Each segmentation itself is a list (or list-like)
of intervals.
of intervals (*_intervals_hier) and a list of lists of labels
(*_labels_hier).
kwargs
additional keyword arguments to the evaluation metrics.
Expand Down
17 changes: 12 additions & 5 deletions tests/test_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,19 @@ def test_tmeasure_regression():
est_files = sorted(glob('tests/data/hierarchy/est*.lab'))
out_files = sorted(glob('tests/data/hierarchy/output*.json'))

ref_hier = [mir_eval.io.load_labeled_intervals(_)[0] for _ in ref_files]
est_hier = [mir_eval.io.load_labeled_intervals(_)[0] for _ in est_files]
ref_hier = [mir_eval.io.load_labeled_intervals(_) for _ in ref_files]
est_hier = [mir_eval.io.load_labeled_intervals(_) for _ in est_files]

def __test(w, ref, est, target):
ref_ints = [seg[0] for seg in ref_hier]
ref_labs = [seg[1] for seg in ref_hier]
est_ints = [seg[0] for seg in est_hier]
est_labs = [seg[1] for seg in est_hier]

outputs = mir_eval.hierarchy.evaluate(ref, est, window=w)
def __test(w, ref_i, ref_l, est_i, est_l, target):

outputs = mir_eval.hierarchy.evaluate(ref_i, ref_l,
est_i, est_l,
window=w)

for key in target:
assert np.allclose(target[key], outputs[key], atol=A_TOL)
Expand All @@ -138,4 +145,4 @@ def __test(w, ref, est, target):

# Extract the window parameter
window = float(re.match('.*output_w=(\d+).json$', out).groups()[0])
yield __test, window, ref_hier, est_hier, target
yield __test, window, ref_ints, ref_labs, est_ints, est_labs, target

0 comments on commit 545596a

Please sign in to comment.