Skip to content

Commit

Permalink
[train] Storage: Change class trainable save_checkpoint implementatio…
Browse files Browse the repository at this point in the history
…ns (ray-project#38554)

Some of our (internal) class trainables return a subpath in `save_checkpoint`. This path has been deprecated with the new storage refactor. This PR updates the implementations to reconstruct the checkpoint path in `load_checkpoint` and return None in `save_checkpoint` so as to adhere to the new API.

Signed-off-by: Kai Fricke <[email protected]>
  • Loading branch information
krfricke authored Aug 21, 2023
1 parent 03ae779 commit f2a9ca8
Show file tree
Hide file tree
Showing 15 changed files with 72 additions and 63 deletions.
6 changes: 3 additions & 3 deletions python/ray/tune/examples/bohb_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path

def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "r") as f:
self.timestep = json.loads(f.read())["timestep"]


Expand Down
4 changes: 2 additions & 2 deletions python/ray/tune/examples/mnist_pytorch_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def step(self):
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path

def load_checkpoint(self, checkpoint_path):
def load_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))


Expand Down
4 changes: 2 additions & 2 deletions python/ray/tune/examples/pbt_convnet_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def step(self):
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path

def load_checkpoint(self, checkpoint_path):
def load_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))

def reset_config(self, new_config):
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/examples/pbt_memnn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,12 @@ def step(self):
def save_checkpoint(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
self.model.save(file_path)
return file_path

def load_checkpoint(self, path):
def load_checkpoint(self, checkpoint_dir):
# See https://stackoverflow.com/a/42763323
del self.model
self.model = load_model(path)
file_path = checkpoint_dir + "/model"
self.model = load_model(file_path)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/examples/pbt_tune_cifar10_with_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ def step(self):
def save_checkpoint(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
self.model.save(file_path)
return file_path

def load_checkpoint(self, path):
def load_checkpoint(self, checkpoint_dir):
# See https://stackoverflow.com/a/42763323
del self.model
self.model = load_model(path)
file_path = checkpoint_dir + "/model"
self.model = load_model(file_path)

def cleanup(self):
# If need, save your model when exit.
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/examples/xgboost_dynamic_resources_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "wb") as outputFile:
pickle.dump((self.config, self.nthread, self.model.save_raw()), outputFile)
return path

def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path, "rb") as inputFile:
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "rb") as inputFile:
self.config, self.nthread, raw_model = pickle.load(inputFile)
self.model = Booster()
self.model.load_model(bytearray(raw_model))
Expand Down
21 changes: 18 additions & 3 deletions python/ray/tune/tests/test_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
upload_to_uri,
delete_at_uri,
)
from ray.train._internal.storage import StorageContext, _use_storage_context
from ray.tune.logger import NoopLogger
from ray.tune.syncer import _DefaultSyncer
from ray.tune.trainable import wrap_function
Expand Down Expand Up @@ -95,15 +96,29 @@ def function_trainable_directory(config):
train.report({"metric": 4}, checkpoint=Checkpoint.from_directory(tmpdir))


@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_path_class(ray_start_2_cpus, return_type):
@pytest.mark.parametrize(
"return_type",
["object", "root"]
# Do not test subdir/checkpoint path in new storage context
+ (["subdir", "checkpoint"] if not _use_storage_context() else []),
)
def test_save_load_checkpoint_path_class(ray_start_2_cpus, return_type, tmpdir):
"""Assert that restoring from a Trainable.save() future works with
class trainables.
Needs Ray cluster so we get actual futures.
"""
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)
trainable = ray.remote(SavingTrainable).remote(
return_type=return_type,
storage=StorageContext(
storage_path=str(tmpdir), experiment_dir_name="test", trial_dir_name="test0"
),
)

# Train one step
ray.get(trainable.train.remote())

# Save checkpoint
saving_future = trainable.save.remote()

# Check for errors
Expand Down
13 changes: 6 additions & 7 deletions python/ray/tune/tests/test_trial_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,10 +1669,10 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"iter": self.iter, "replayed": self.replayed}))
return path

def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "r") as f:
checkpoint_json = json.loads(f.read())
self.iter = checkpoint_json["iter"]
self.replayed = checkpoint_json["replayed"]
Expand Down Expand Up @@ -1841,10 +1841,10 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"iter": self.iter, "replayed": self.replayed}))
return path

def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "r") as f:
checkpoint_json = json.loads(f.read())
self.iter = checkpoint_json["iter"]
self.replayed = checkpoint_json["replayed"]
Expand Down Expand Up @@ -2026,7 +2026,6 @@ def save_checkpoint(self, path):
checkpoint = os.path.join(path, "checkpoint")
with open(checkpoint, "w") as f:
f.write("OK")
return checkpoint

def reset_config(self, config):
return True
Expand Down
18 changes: 9 additions & 9 deletions python/ray/tune/tests/test_trial_scheduler_pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ def save_checkpoint(self, checkpoint_dir):

with open(file_path, "wb") as fp:
pickle.dump((self.large_object, self.iter, self.a), fp)
return file_path

def load_checkpoint(self, path):
with open(path, "rb") as fp:
def load_checkpoint(self, checkpoint_dir):
file_path = os.path.join(checkpoint_dir, "model.mock")

with open(file_path, "rb") as fp:
self.large_object, self.iter, self.a = pickle.load(fp)

class CheckObjectMemoryUsage(Callback):
Expand Down Expand Up @@ -130,10 +131,11 @@ def save_checkpoint(self, checkpoint_dir):

with open(file_path, "wb") as fp:
pickle.dump((self.iter, self.a), fp)
return file_path

def load_checkpoint(self, path):
with open(path, "rb") as fp:
def load_checkpoint(self, checkpoint_dir):
file_path = os.path.join(checkpoint_dir, "model.mock")

with open(file_path, "rb") as fp:
self.iter, self.a = pickle.load(fp)

from ray.tune.callback import Callback
Expand Down Expand Up @@ -306,9 +308,8 @@ def step(self):

def save_checkpoint(self, checkpoint_dir):
checkpoint = Checkpoint.from_dict({"a": self.a})
checkpoint_path = checkpoint.to_directory(path=checkpoint_dir)
checkpoint.to_directory(path=checkpoint_dir)
time.sleep(self.saving_time)
return checkpoint_path

def load_checkpoint(self, checkpoint_dir):
checkpoint_dict = Checkpoint.from_directory(checkpoint_dir).to_dict()
Expand Down Expand Up @@ -449,7 +450,6 @@ def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
with open(checkpoint_path, "wb") as fp:
pickle.dump((self.a, self.b, self.iter), fp)
return tmp_checkpoint_dir

def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
Expand Down
12 changes: 5 additions & 7 deletions python/ray/tune/tests/test_tune_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@ def step(self):
return {"timesteps_this_iter": 1, "done": True}

def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(
checkpoint_dir, "checkpoint-{}".format(self._iteration)
)
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pkl")
with open(checkpoint_path, "wb") as f:
pickle.dump(self.state, f)
return checkpoint_path

def load_checkpoint(self, checkpoint_path):
def load_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pkl")
with open(checkpoint_path, "rb") as f:
extra_data = pickle.load(f)
self.state.update(extra_data)
Expand Down Expand Up @@ -91,7 +89,7 @@ def _train(self, exp_name, local_dir, absolute_local_dir):
self.assertTrue(os.path.isdir(abs_trial_dir))
self.assertTrue(
os.path.isfile(
os.path.join(abs_trial_dir, "checkpoint_000001/checkpoint-1")
os.path.join(abs_trial_dir, "checkpoint_000001/checkpoint.pkl")
)
)

Expand All @@ -101,7 +99,7 @@ def _restore(self, exp_name, local_dir, absolute_local_dir):
)

checkpoint_path = os.path.join(
local_dir, exp_name, trial_name, "checkpoint_000001/checkpoint-1"
local_dir, exp_name, trial_name, "checkpoint_000001/checkpoint.pkl"
) # Relative checkpoint path

# The file tune would find. The absolute checkpoint path.
Expand Down
7 changes: 4 additions & 3 deletions python/ray/tune/tests/test_tuner_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,10 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"idx": self.idx}))
return path

def load_checkpoint(self, checkpoint_path):
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")

self._is_restored = True
with open(self.tag_file_path, "r") as f:
retried_num = json.loads(f.read())["retried_num"]
Expand All @@ -574,7 +575,7 @@ def load_checkpoint(self, checkpoint_path):

if retried_num < self.retry_num_to_fail:
raise RuntimeError(f"===== Failing restore #{retried_num + 1} =====")
with open(checkpoint_path) as f:
with open(path, "r") as f:
self.idx = json.loads(f.read())["idx"]

# Set environment variable just for this test
Expand Down
19 changes: 8 additions & 11 deletions python/ray/tune/trainable/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ def save(
raise ValueError(
"The returned checkpoint path from `save_checkpoint` "
"must be None or the same as the provided path argument."
f"Got {checkpoint_dict_or_path} != {checkpoint_dir}"
)

local_checkpoint = NewCheckpoint.from_directory(checkpoint_dir)
Expand Down Expand Up @@ -1369,7 +1370,7 @@ def step(self):
"""
raise NotImplementedError

def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
def save_checkpoint(self, checkpoint_dir: str) -> Optional[Dict]:
"""Subclasses should override this to implement ``save()``.
Warning:
Expand All @@ -1393,11 +1394,9 @@ def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
the provided path may be temporary and moved.
Returns:
A dict or string. If string, the return value is expected to be
prefixed by `checkpoint_dir`. If dict, the return value will
be automatically serialized by Tune. In both cases, the return value
is exactly what will be passed to ``Trainable.load_checkpoint()``
upon restore.
A dict or None. If dict, the return value will
be automatically serialized by Tune. In that case,
``Trainable.load_checkpoint()`` will receive the dict upon restore.
Example:
>>> trainable, trainable1, trainable2 = ... # doctest: +SKIP
Expand All @@ -1410,7 +1409,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
"""
raise NotImplementedError

def load_checkpoint(self, checkpoint: Union[Dict, str]):
def load_checkpoint(self, checkpoint: Optional[Dict]):
"""Subclasses should override this to implement restore().
Warning:
Expand Down Expand Up @@ -1460,10 +1459,8 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]):
Args:
checkpoint: If dict, the return value is as
returned by `save_checkpoint`. If a string, then it is
a checkpoint path that may have a different prefix than that
returned by `save_checkpoint`. The directory structure
underneath the `checkpoint_dir` from `save_checkpoint` is preserved.
returned by ``save_checkpoint``. Otherwise, the directory
the checkpoint was stored in.
"""
raise NotImplementedError

Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/utils/mock_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path

def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path) as f:
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "r") as f:
self.timestep = json.loads(f.read())["timestep"]
1 change: 0 additions & 1 deletion python/ray/tune/utils/release_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_data = np.random.uniform(0, 1, size=self._checkpoint_num_items)
with open(checkpoint_file, "wb") as fp:
pickle.dump(checkpoint_data, fp)
return tmp_checkpoint_dir

def load_checkpoint(self, checkpoint):
pass
Expand Down
6 changes: 3 additions & 3 deletions rllib/algorithms/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
with open(path, "wb") as f:
pickle.dump(self.info, f)
return path

@override(Algorithm)
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path, "rb") as f:
def load_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
with open(path, "rb") as f:
info = pickle.load(f)
self.info = info
self.restored = True
Expand Down

0 comments on commit f2a9ca8

Please sign in to comment.