Skip to content

Commit

Permalink
[tune] Fix tests for Function API for better consistency (ray-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw authored Mar 21, 2019
1 parent 80ef8c1 commit 828dc08
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 24 deletions.
6 changes: 3 additions & 3 deletions python/ray/tune/function_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TIME_THIS_ITER_S
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,8 +249,8 @@ def _trainable_func(self, config, reporter):
output = train_func(config, reporter)
# If train_func returns, we need to notify the main event loop
# of the last result while avoiding double logging. This is done
# with the keyword "__duplicate__" -- see tune/trial_runner.py,
reporter(done=True, __duplicate__=True)
# with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
reporter(**{RESULT_DUPLICATE: True})
return output

return WrappedFunc
4 changes: 4 additions & 0 deletions python/ray/tune/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
# __sphinx_doc_end__
# yapf: enable

# __duplicate__ is a magic keyword used internally to
# avoid double-logging results when using the Function API.
RESULT_DUPLICATE = "__duplicate__"

# Where Tune writes result files by default
DEFAULT_RESULTS_DIR = (os.environ.get("TUNE_RESULT_DIR")
or os.path.expanduser("~/ray_results"))
Expand Down
36 changes: 25 additions & 11 deletions python/ray/tune/tests/test_trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
EPISODES_TOTAL, TRAINING_ITERATION,
TIMESTEPS_THIS_ITER)
HOSTNAME, NODE_IP, PID, EPISODES_TOTAL,
TRAINING_ITERATION, TIMESTEPS_THIS_ITER,
TIME_THIS_ITER_S, TIME_TOTAL_S)
from ray.tune.logger import Logger
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.experiment import Experiment
Expand Down Expand Up @@ -109,15 +110,28 @@ def _function_trainable(config, reporter):
raise_on_failed_trial=False,
scheduler=MockScheduler())

# Only compare these result fields. Metadata handling
# may be different across APIs.
COMPARE_FIELDS = {field for res in results for field in res}
# Ignore these fields
NO_COMPARE_FIELDS = {
HOSTNAME,
NODE_IP,
PID,
TIME_THIS_ITER_S,
TIME_TOTAL_S,
DONE, # This is ignored because FunctionAPI has different handling
"timestamp",
"time_since_restore",
"experiment_id",
"date",
}

self.assertEqual(len(class_output), len(results))
self.assertEqual(len(function_output), len(results))

def as_comparable_result(result):
return {k: v for k, v in result.items() if k in COMPARE_FIELDS}
return {
k: v
for k, v in result.items() if k not in NO_COMPARE_FIELDS
}

function_comparable = [
as_comparable_result(result) for result in function_output
Expand All @@ -133,6 +147,11 @@ def as_comparable_result(result):
as_comparable_result(scheduler_notif[0]),
as_comparable_result(scheduler_notif[1]))

# Make sure the last result is the same for both trials.
self.assertEqual(
as_comparable_result(trials[0].last_result),
as_comparable_result(trials[1].last_result))

return function_output, trials

def testPinObject(self):
Expand Down Expand Up @@ -583,11 +602,6 @@ def testNoDoneReceived(self):
# check if the correct number of results were reported.
self.assertEqual(len(logs1), len(results1))

# We should not double-log
trial = [t for t in trials if "function" in str(t)][0]
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[DONE], False)

def check_no_missing(reported_result, result):
common_results = [reported_result[k] == result[k] for k in result]
return all(common_results)
Expand Down
11 changes: 8 additions & 3 deletions python/ray/tune/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import ray
from ray.tune.logger import UnifiedLogger
from ray.tune.result import (
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION)
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
EPISODES_THIS_ITER, EPISODES_TOTAL,
TRAINING_ITERATION, RESULT_DUPLICATE)
from ray.tune.trial import Resources

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -150,6 +151,10 @@ def train(self):
result = self._train()
assert isinstance(result, dict), "_train() needs to return a dict."

# If the result is a duplicate, return it as-is without modifying internal state.
if RESULT_DUPLICATE in result:
return result

result = result.copy()

self._iteration += 1
Expand Down
19 changes: 12 additions & 7 deletions python/ray/tune/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trial import Trial, Checkpoint
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.util import warn_if_slow
Expand Down Expand Up @@ -407,6 +407,16 @@ def _process_events(self):
def _process_trial(self, trial):
try:
result = self.trial_executor.fetch_result(trial)

is_duplicate = RESULT_DUPLICATE in result
# TrialScheduler and SearchAlgorithm still receive a
# notification because there may be special handling for
# the `on_trial_complete` hook.
if is_duplicate:
logger.debug("Trial finished without logging 'done'.")
result = trial.last_result
result.update(done=True)

self._total_time += result[TIME_THIS_ITER_S]

if trial.should_stop(result):
Expand All @@ -426,12 +436,7 @@ def _process_trial(self, trial):
self._search_alg.on_trial_complete(
trial.trial_id, early_terminated=True)

# __duplicate__ is a magic keyword used internally to
# avoid double-logging results when using the Function API.
# TrialScheduler and SearchAlgorithm still receive a
# notification because there may be special handling for
# the `on_trial_complete` hook.
if "__duplicate__" not in result:
if not is_duplicate:
trial.update_last_result(
result, terminate=(decision == TrialScheduler.STOP))

Expand Down

0 comments on commit 828dc08

Please sign in to comment.