Skip to content

Commit

Permalink
[Tune] Don't include nan metrics for best checkpoint (ray-project#23820)
Browse files Browse the repository at this point in the history
Nan values do not have a well defined ordering. When sorting metrics to determine the best checkpoint, we should always filter out checkpoints that are associated with nan values.

Closes ray-project#23812
  • Loading branch information
amogkam authored Apr 11, 2022
1 parent d8efa37 commit d33483d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
10 changes: 9 additions & 1 deletion python/ray/tune/analysis/experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ray.tune.syncer import SyncConfig
from ray.tune.utils import flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder
from ray.tune.utils.util import is_nan_or_inf
from ray.tune.utils.util import is_nan_or_inf, is_nan

try:
import pandas as pd
Expand Down Expand Up @@ -433,6 +433,8 @@ def get_best_checkpoint(
) -> Optional[Checkpoint]:
"""Gets best persistent checkpoint path of provided trial.
Any checkpoints with an associated metric value of ``nan`` will be filtered out.
Args:
trial: The log directory of a trial, or a trial instance.
metric: key of trial info to return, e.g. "mean_accuracy".
Expand All @@ -447,6 +449,12 @@ def get_best_checkpoint(
mode = self._validate_mode(mode)

checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

# Filter out nan. Sorting nan values leads to undefined behavior.
checkpoint_paths = [
(path, metric) for path, metric in checkpoint_paths if not is_nan(metric)
]

if not checkpoint_paths:
logger.error(f"No checkpoints have been found for trial {trial}.")
return None
Expand Down
34 changes: 33 additions & 1 deletion python/ray/tune/tests/test_experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ray.tune import ExperimentAnalysis
import ray.tune.registry
from ray.tune.utils.mock_trainable import MyTrainableClass
from ray.tune.utils.util import is_nan


class ExperimentAnalysisSuite(unittest.TestCase):
Expand Down Expand Up @@ -152,13 +153,44 @@ def testGetTrialCheckpointsPathsWithMetricByPath(self):

def testGetBestCheckpoint(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(
best_trial, metric=self.metric
)
expected_path = max(checkpoints_metrics, key=lambda x: x[1])[0]
best_checkpoint = self.ea.get_best_checkpoint(
best_trial, self.metric, mode="max"
)
assert expected_path == best_checkpoint

def testGetBestCheckpointNan(self):
"""Tests if nan values are excluded from best checkpoint."""
metric = "loss"

def train(config):
for i in range(config["steps"]):
if i == 0:
value = float("nan")
else:
value = i
result = {metric: value}
with tune.checkpoint_dir(step=i):
pass
tune.report(**result)

ea = tune.run(train, local_dir=self.test_dir, config={"steps": 3})
best_trial = ea.get_best_trial(metric, mode="min")
best_checkpoint = ea.get_best_checkpoint(best_trial, metric, mode="min")
checkpoints_metrics = ea.get_trial_checkpoints_paths(best_trial, metric=metric)
expected_checkpoint_no_nan = min(
[
checkpoint_metric
for checkpoint_metric in checkpoints_metrics
if not is_nan(checkpoint_metric[1])
],
key=lambda x: x[1],
)[0]
assert best_checkpoint == expected_checkpoint_no_nan

def testGetLastCheckpoint(self):
# one more experiment with 2 iterations
new_ea = tune.run(
Expand Down
6 changes: 5 additions & 1 deletion python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,12 @@ def date_str():
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")


def is_nan(value):
return np.isnan(value)


def is_nan_or_inf(value):
return np.isnan(value) or np.isinf(value)
return is_nan(value) or np.isinf(value)


def _to_pinnable(obj):
Expand Down

0 comments on commit d33483d

Please sign in to comment.