Skip to content

Commit

Permalink
fixed overlapping plots issue
Browse files Browse the repository at this point in the history
Signed-off-by: Enoch Kan <[email protected]>
  • Loading branch information
enochkan authored and bfirsh committed Mar 8, 2021
1 parent 5607908 commit 9ea30f5
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 15 deletions.
32 changes: 17 additions & 15 deletions python/keepsake/experiment.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
try:
# backport is incompatible with 3.7+, so we must use built-in
from dataclasses import dataclass, InitVar, field
import dataclasses
from dataclasses import InitVar, dataclass, field
except ImportError:
from ._vendor.dataclasses import dataclass, InitVar, field
from ._vendor import dataclasses # type: ignore

import datetime
import getpass
import os
import math
import html
import datetime
import json
import math
import os
import shlex
import sys
from typing import (
Dict,
TYPE_CHECKING,
Any,
Optional,
Tuple,
Callable,
Dict,
List,
TYPE_CHECKING,
MutableSequence,
Callable,
Optional,
Tuple,
)

from . import console
from .checkpoint import (
Checkpoint,
PrimaryMetric,
CheckpointList,
)
from .metadata import rfc3339_datetime, parse_rfc3339
from .checkpoint import Checkpoint, CheckpointList, PrimaryMetric
from .metadata import parse_rfc3339, rfc3339_datetime
from .packages import get_imported_packages
from .system import get_python_version
from .validate import check_path
Expand Down Expand Up @@ -446,6 +443,11 @@ def plot(self, metric: Optional[str] = None, logy=False):
if metric is None:
metric = self.primary_metric()

plotted_label = plt.axes().yaxis.get_label().get_text() or metric

if metric != plotted_label:
plt.figure()

for exp in self:
exp.plot(metric, plot_only=True)

Expand Down
34 changes: 34 additions & 0 deletions python/tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import datetime

import matplotlib.pyplot as plt
from keepsake.checkpoint import Checkpoint, CheckpointList
from keepsake.experiment import Experiment, ExperimentList
from keepsake.project import Project, init


def test_num_plots(temp_workdir):
with open("keepsake.yaml", "w") as f:
f.write("repository: file://.keepsake/")

experiment = init(path=".", params={"learning_rate": 0.1, "num_epochs": 1},)

experiment.checkpoint(
path=".",
step=1,
metrics={"loss": 1.1836304664611816, "accuracy": 0.3333333432674408},
primary_metric=("loss", "minimize"),
)
experiment.checkpoint(
path=".",
step=2,
metrics={"loss": 1.1836304662222222, "accuracy": 0.4333333432674408},
primary_metric=("loss", "minimize"),
)

experiment_list = ExperimentList([experiment])
num_plots = 30
for rep in range(num_plots):
experiment_list.plot()
assert len(plt.get_fignums()) == 1
experiment_list.plot(metric="accuracy")
assert len(plt.get_fignums()) == 2
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ boto3==1.12.32
google-cloud-storage==1.32.0
waiting==1.4.1
python-dateutil==2.1
matplotlib==3.3.4

0 comments on commit 9ea30f5

Please sign in to comment.