Skip to content

Commit

Permalink
[tune] Add a time/timesteps since last restore metric (ray-project#2819)
Browse files Browse the repository at this point in the history
* rsm

* always log to avoid changing schema for csv writer

* add iter since restore

* update

* criteria warn
ericl authored Sep 6, 2018

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 995ac24 commit d81605e
Showing 4 changed files with 48 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/ray/tune/logger.py
Original file line number Diff line number Diff line change
@@ -140,6 +140,7 @@ def on_result(self, result):
}, ["ray", "tune"])
iteration_stats = tf.Summary(value=iteration_value)
self._file_writer.add_summary(iteration_stats, t)
self._file_writer.flush()

def flush(self):
self._file_writer.flush()
32 changes: 32 additions & 0 deletions python/ray/tune/test/trial_runner_test.py
Original file line number Diff line number Diff line change
@@ -1009,6 +1009,38 @@ def testCheckpointing(self):
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, path)

def testRestoreMetricsAfterCheckpointing(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()

runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
path = runner.trial_executor.save(trials[0])
runner.trial_executor.stop_trial(trials[0])
kwargs["restore_path"] = path

runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()

runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
runner.step()
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
runner.step()
self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
self.assertGreater(trials[1].last_result["time_since_restore"], 0)
self.addCleanup(os.remove, path)

def testCheckpointingAtEnd(self):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner(BasicVariantGenerator())
13 changes: 12 additions & 1 deletion python/ray/tune/trainable.py
Original file line number Diff line number Diff line change
@@ -75,6 +75,10 @@ def __init__(self, config=None, logger_creator=None):
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = None
self._time_since_restore = 0.0
self._timesteps_since_restore = 0
self._iterations_since_restore = 0
self._restored = False
self._setup()
self._initialize_ok = True
self._local_ip = ray.services.get_node_ip_address()
@@ -143,12 +147,14 @@ def train(self):
result = result.copy()

self._iteration += 1
self._iterations_since_restore += 1

if result.get(TIME_THIS_ITER_S) is not None:
time_this_iter = result[TIME_THIS_ITER_S]
else:
time_this_iter = time.time() - start
self._time_total += time_this_iter
self._time_since_restore += time_this_iter

result.setdefault(DONE, False)

@@ -157,6 +163,7 @@ def train(self):
if self._timesteps_total is None:
self._timesteps_total = 0
self._timesteps_total += result[TIMESTEPS_THIS_ITER]
self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]

# self._timesteps_total should not override user-provided total
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
@@ -176,7 +183,10 @@ def train(self):
pid=os.getpid(),
hostname=os.uname()[1],
node_ip=self._local_ip,
config=self.config)
config=self.config,
time_since_restore=self._time_since_restore,
timesteps_since_restore=self._timesteps_since_restore,
iterations_since_restore=self._iterations_since_restore)

self._result_logger.on_result(result)

@@ -248,6 +258,7 @@ def restore(self, checkpoint_path):
self._iteration = metadata[1]
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
self._restored = True

def restore_from_object(self, obj):
"""Restores training state from a checkpoint object.
4 changes: 3 additions & 1 deletion python/ray/tune/trial.py
Original file line number Diff line number Diff line change
@@ -199,7 +199,9 @@ def should_stop(self, result):

for criteria, stop_value in self.stopping_criterion.items():
if criteria not in result:
raise TuneError("Stopping Criteria not provided in result.")
raise TuneError(
"Stopping criteria {} not provided in result {}.".format(
criteria, result))
if result[criteria] >= stop_value:
return True

0 comments on commit d81605e

Please sign in to comment.