Skip to content

Commit

Permalink
[tune] Add leading zeros to checkpoint directory (ray-project#14152)
Browse files Browse the repository at this point in the history
* [tune] Add leading zeros to checkpoint directory

* Fix exp analysis tests/support string indices

* Fix tests

* RLlib tests
  • Loading branch information
krfricke authored Mar 1, 2021
1 parent 8572774 commit 7f9340b
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/ray/serve/examples/doc/tutorial_rllib.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def train_ppo_model():
# Train for one iteration
trainer.train()
trainer.save("/tmp/rllib_checkpoint")
return "/tmp/rllib_checkpoint/checkpoint_1/checkpoint-1"
return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"


checkpoint_path = train_ppo_model()
Expand Down
9 changes: 5 additions & 4 deletions python/ray/tune/tests/test_experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,30 +108,31 @@ def testGetTrialCheckpointsPathsByTrial(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial)
logdir = self.ea.get_best_logdir(self.metric, mode="max")
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert checkpoints_metrics[0][0] == expected_path
assert checkpoints_metrics[0][1] == 1

def testGetTrialCheckpointsPathsByPath(self):
logdir = self.ea.get_best_logdir(self.metric, mode="max")
checkpoints_metrics = self.ea.get_trial_checkpoints_paths(logdir)
expected_path = os.path.join(logdir, "checkpoint_1/", "checkpoint")
expected_path = os.path.join(logdir, "checkpoint_000001/",
"checkpoint")
assert checkpoints_metrics[0][0] == expected_path
assert checkpoints_metrics[0][1] == 1

def testGetTrialCheckpointsPathsWithMetricByTrial(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
logdir = self.ea.get_best_logdir(self.metric, mode="max")
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert paths[0][0] == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]

def testGetTrialCheckpointsPathsWithMetricByPath(self):
best_trial = self.ea.get_best_trial(self.metric, mode="max")
logdir = self.ea.get_best_logdir(self.metric, mode="max")
paths = self.ea.get_trial_checkpoints_paths(best_trial, self.metric)
expected_path = os.path.join(logdir, "checkpoint_1", "checkpoint")
expected_path = os.path.join(logdir, "checkpoint_000001", "checkpoint")
assert paths[0][0] == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]

Expand Down
4 changes: 2 additions & 2 deletions python/ray/tune/tests/test_tune_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ def _train(self, exp_name, local_dir, absolute_local_dir):
self.assertTrue(os.path.isdir(abs_trial_dir))
self.assertTrue(
os.path.isfile(
os.path.join(abs_trial_dir, "checkpoint_1/checkpoint-1")))
os.path.join(abs_trial_dir, "checkpoint_000001/checkpoint-1")))

def _restore(self, exp_name, local_dir, absolute_local_dir):
trial_name, abs_trial_dir = self._get_trial_dir(
os.path.join(absolute_local_dir, exp_name))

checkpoint_path = os.path.join(
local_dir, exp_name, trial_name,
"checkpoint_1/checkpoint-1") # Relative checkpoint path
"checkpoint_000001/checkpoint-1") # Relative checkpoint path

# The file tune would find. The absolute checkpoint path.
tune_find_file = os.path.abspath(os.path.expanduser(checkpoint_path))
Expand Down
5 changes: 3 additions & 2 deletions python/ray/tune/utils/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,15 @@ def make_checkpoint_dir(checkpoint_dir, index, override=False):
Args:
checkpoint_dir (str): Path to checkpoint directory.
index (str): A subdirectory will be created
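index (int|str): A subdirectory will be created
at the checkpoint directory named 'checkpoint_{index}'. Integer
indices are zero-padded to at least six digits (e.g.
'checkpoint_000001'); string indices are used verbatim.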
override (bool): Deletes checkpoint_dir before creating
a new one.
"""
suffix = "checkpoint"
if index is not None:
suffix += "_{}".format(index)
suffix += f"_{index:06d}" if isinstance(index,
int) else f"_{index}"
checkpoint_dir = os.path.join(checkpoint_dir, suffix)

if override and os.path.exists(checkpoint_dir):
Expand Down
2 changes: 1 addition & 1 deletion rllib/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
rllib train --run DQN --env CartPole-v0
Example usage for rollout:
rllib rollout /trial_dir/checkpoint_1/checkpoint-1 --run DQN
rllib rollout /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
"""


Expand Down
2 changes: 1 addition & 1 deletion rllib/tests/test_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
"}' --stop='{\"training_iteration\": 1}'" +
" --env={}".format(env))

checkpoint_path = os.popen("ls {}/default/*/checkpoint_1/"
checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/"
"checkpoint-1".format(tmp_dir)).read()[:-1]
if not os.path.exists(checkpoint_path):
sys.exit(1)
Expand Down

0 comments on commit 7f9340b

Please sign in to comment.