High level load metrics (talmolab#573)
- Add `sleap.load_metrics()`
- Example notebook
talmo authored Aug 2, 2021
1 parent b3f71be commit 7ff2904
Showing 5 changed files with 652 additions and 0 deletions.
585 changes: 585 additions & 0 deletions docs/notebooks/Model_evaluation.ipynb

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions docs/notebooks/index.rst
@@ -18,6 +18,12 @@ This notebook can be a good place to start since you'll be able to see how train
Once you're ready to run training and inference on your own SLEAP dataset, this notebook walks you through the process of using `Google Drive <https://www.google.com/drive>`_ to copy data to and from Colab (as well as running training and inference on your dataset).


`Model evaluation <./Model_evaluation.html>`_
------------------------------------------------

After you've trained several models, you may want to compute some metrics for benchmarking and comparisons. This notebook walks through some of the types of metrics that SLEAP can compute for you, as well as how to recompute them (see the usage sketch after this diff).


`Analysis examples <./Analysis_examples.html>`_
------------------------------------------------

@@ -29,4 +35,5 @@ Once you've used SLEAP to successfully estimate animal pose and track animals in

Training_and_inference_on_an_example_dataset
Training_and_inference_using_Google_Drive
Model_evaluation
Analysis_examples
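
For reference, the benchmarking workflow described above boils down to loading each model's saved metrics and comparing a summary score. A minimal sketch using the `sleap.load_metrics()` API added in this commit (the model folder paths are hypothetical placeholders):

import sleap

# Hypothetical model folders; substitute paths to your own trained models.
model_paths = [
    "models/baseline.centroid",
    "models/baseline.centered_instance",
]

for model_path in model_paths:
    # load_metrics() reads the metrics.<split>.npz file saved in the model folder.
    metrics = sleap.load_metrics(model_path, split="val")
    print(f"{model_path}: OKS mAP = {metrics['oks_voc.mAP']:.3f}")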
1 change: 1 addition & 0 deletions sleap/__init__.py
@@ -17,3 +17,4 @@
from sleap.nn.system import use_cpu_only, disable_preallocation
from sleap.nn.system import summary as system_summary
from sleap.nn.config import TrainingJobConfig, load_config
from sleap.nn.evals import load_metrics
40 changes: 40 additions & 0 deletions sleap/nn/evals.py
@@ -724,3 +724,43 @@ def evaluate_model(
    logger.info("OKS mAP: %f", metrics["oks_voc.mAP"])

    return labels_pr, metrics


def load_metrics(model_path: str, split: str = "val") -> Dict[str, Any]:
    """Load metrics for a model.

    Args:
        model_path: Path to a model folder or metrics file (.npz).
        split: Name of the split to load the metrics for. Must be `"train"`,
            `"val"` or `"test"` (default: `"val"`). Ignored if a path to a
            metrics NPZ file is provided.

    Returns:
        The loaded metrics as a dictionary with keys:

        - `"vis.tp"`: Visibility - True Positives
        - `"vis.fp"`: Visibility - False Positives
        - `"vis.tn"`: Visibility - True Negatives
        - `"vis.fn"`: Visibility - False Negatives
        - `"vis.precision"`: Visibility - Precision
        - `"vis.recall"`: Visibility - Recall
        - `"dist.avg"`: Average Distance (ground truth vs prediction)
        - `"dist.p50"`: Distance for 50th percentile
        - `"dist.p75"`: Distance for 75th percentile
        - `"dist.p90"`: Distance for 90th percentile
        - `"dist.p95"`: Distance for 95th percentile
        - `"dist.p99"`: Distance for 99th percentile
        - `"dist.dists"`: All distances
        - `"pck.mPCK"`: Mean Percentage of Correct Keypoints (PCK)
        - `"oks.mOKS"`: Mean Object Keypoint Similarity (OKS)
        - `"oks_voc.mAP"`: VOC with OKS scores - mean Average Precision (mAP)
        - `"oks_voc.mAR"`: VOC with OKS scores - mean Average Recall (mAR)
        - `"pck_voc.mAP"`: VOC with PCK scores - mean Average Precision (mAP)
        - `"pck_voc.mAR"`: VOC with PCK scores - mean Average Recall (mAR)
    """
    if os.path.isdir(model_path):
        metrics_path = os.path.join(model_path, f"metrics.{split}.npz")
    else:
        metrics_path = model_path
    with np.load(metrics_path, allow_pickle=True) as data:
        return data["metrics"].item()
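
For example, to pull out the localization-error summaries from the distance keys documented above (a sketch; the model path is a hypothetical placeholder):

import sleap

# Hypothetical model folder; point this at one of your trained models.
metrics = sleap.load_metrics("models/baseline.centered_instance", split="val")
print("Average distance (px):", metrics["dist.avg"])
print("95th percentile distance (px):", metrics["dist.p95"])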
19 changes: 19 additions & 0 deletions tests/nn/test_evals.py
@@ -0,0 +1,19 @@
import numpy as np
import sleap
from sleap.nn.evals import load_metrics


sleap.use_cpu_only()


def test_load_metrics(min_centered_instance_model_path):
    model_path = min_centered_instance_model_path

    metrics = load_metrics(f"{model_path}/metrics.val.npz")
    assert "oks_voc.mAP" in metrics

    metrics = load_metrics(model_path, split="val")
    assert "oks_voc.mAP" in metrics

    metrics = load_metrics(model_path, split="train")
    assert "oks_voc.mAP" in metrics
