[tune] Clean up temporary checkpoint directories for class trainables (ex: RLlib) (ray-project#44366)

Clean up temporary checkpoint folders of class trainables after they've been persisted to storage.

---------

Signed-off-by: Justin Yu <[email protected]>
justinvyu authored Apr 19, 2024
1 parent a0b0c9d commit 5d60d8d
Showing 2 changed files with 72 additions and 65 deletions.
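
In short: `Trainable.save()` now writes class-trainable checkpoints to a temporary directory when no `checkpoint_dir` is passed, persists them to storage, and then deletes the temporary directory. A minimal sketch of the class-trainable API this affects (the `MyTrainable` subclass is illustrative, not part of this commit):

    from ray.tune import Trainable

    class MyTrainable(Trainable):  # hypothetical example subclass
        def step(self):
            # Metrics returned here become `self._last_result`.
            return {"score": 1.0}

        def save_checkpoint(self, checkpoint_dir: str):
            # A class trainable may return a dict (which Tune pickles into
            # the checkpoint directory for you), or write files into
            # `checkpoint_dir` and return that same path.
            return {"weights": [1, 2, 3]}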
python/ray/tune/trainable/function_trainable.py (5 changes: 0 additions & 5 deletions)
@@ -143,11 +143,6 @@ def save_checkpoint(self, checkpoint_dir: str = ""):
         # so `_last_training_result.checkpoint` holds onto the latest ckpt.
         return self._last_training_result
 
-    def _create_checkpoint_dir(
-        self, checkpoint_dir: Optional[str] = None
-    ) -> Optional[str]:
-        return None
-
     def load_checkpoint(self, checkpoint_result: _TrainingResult):
         # TODO(justinvyu): This currently breaks the `load_checkpoint` interface.
         session = get_session()
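
For contrast, function trainables never needed `save()` to create a checkpoint directory for them: they hand checkpoints to Tune through `train.report`, and `save()` now calls `save_checkpoint(None)` for them directly (see the second file below). A sketch of that reporting pattern, assuming the Ray 2.x `ray.train` API:

    import tempfile

    from ray import train

    def train_fn(config: dict):
        for step in range(3):
            with tempfile.TemporaryDirectory() as tmpdir:
                # ... write checkpoint files into `tmpdir` here ...
                train.report(
                    {"step": step},
                    checkpoint=train.Checkpoint.from_directory(tmpdir),
                )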
python/ray/tune/trainable/trainable.py (132 changes: 72 additions & 60 deletions)
@@ -2,6 +2,7 @@
 import logging
 import os
 import platform
+import shutil
 import sys
 import tempfile
 import time
@@ -407,16 +408,61 @@ def get_state(self):
             "ray_version": ray.__version__,
         }
 
-    def _create_checkpoint_dir(
-        self, checkpoint_dir: Optional[str] = None
-    ) -> Optional[str]:
-        # NOTE: There's no need to supply the checkpoint directory inside
-        # the local trial dir, since it'll get persisted to the right location.
-        if checkpoint_dir:
-            os.makedirs(checkpoint_dir, exist_ok=True)
-            return checkpoint_dir
-        else:
-            return tempfile.mkdtemp()
+    def _report_class_trainable_checkpoint(
+        self, checkpoint_dir: str, checkpoint_dict_or_path: Union[str, Dict]
+    ) -> _TrainingResult:
+        """Report a checkpoint saved via Trainable.save_checkpoint.
+
+        Need to handle both dict or path checkpoint returned by the user's
+        `save_checkpoint` method.
+
+        This is to get class trainables to work with storage backend used by
+        function trainables.
+        This basically re-implements `train.report` for class trainables,
+        making sure to persist the checkpoint to storage.
+        """
+        if isinstance(checkpoint_dict_or_path, dict):
+            with Path(checkpoint_dir, _DICT_CHECKPOINT_FILE_NAME).open("wb") as f:
+                ray_pickle.dump(checkpoint_dict_or_path, f)
+        elif isinstance(checkpoint_dict_or_path, str):
+            if checkpoint_dict_or_path != checkpoint_dir:
+                raise ValueError(
+                    "The returned checkpoint path from `save_checkpoint` "
+                    "must be None or the same as the provided path argument."
+                    f"Got {checkpoint_dict_or_path} != {checkpoint_dir}"
+                )
+
+        local_checkpoint = Checkpoint.from_directory(checkpoint_dir)
+
+        metrics = self._last_result.copy() if self._last_result else {}
+
+        if self._storage:
+            # The checkpoint index is updated with the current result.
+            # NOTE: This is no longer using "iteration" as the folder indexing
+            # to be consistent with fn trainables.
+            self._storage._update_checkpoint_index(metrics)
+
+            persisted_checkpoint = self._storage.persist_current_checkpoint(
+                local_checkpoint
+            )
+
+            checkpoint_result = _TrainingResult(
+                checkpoint=persisted_checkpoint, metrics=metrics
+            )
+            # Persist trial artifacts to storage.
+            self._storage.persist_artifacts(
+                force=self._storage.sync_config.sync_artifacts_on_checkpoint
+            )
+        else:
+            # `storage=None` only happens when initializing the
+            # Trainable manually, outside of Tune/Train.
+            # In this case, no storage is set, so the default behavior
+            # is to just not upload anything and report a local checkpoint.
+            # This is fine for the main use case of local debugging.
+            checkpoint_result = _TrainingResult(
+                checkpoint=local_checkpoint, metrics=metrics
+            )
+
+        return checkpoint_result
 
     @DeveloperAPI
     def save(self, checkpoint_dir: Optional[str] = None) -> _TrainingResult:
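
To illustrate the two checkpoint forms the new helper accepts: a dict return value is pickled to `_DICT_CHECKPOINT_FILE_NAME` inside the checkpoint directory, while a path return value must equal the directory that was passed in. A hypothetical load-side counterpart to the `ray_pickle.dump` call above (not part of this commit):

    from pathlib import Path

    from ray import cloudpickle as ray_pickle

    def load_dict_checkpoint(checkpoint_dir: str, filename: str) -> dict:
        # Read back a dict checkpoint that was pickled into the checkpoint
        # directory by `_report_class_trainable_checkpoint`.
        with Path(checkpoint_dir, filename).open("rb") as f:
            return ray_pickle.load(f)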
@@ -432,64 +478,30 @@ def save(self, checkpoint_dir: Optional[str] = None) -> _TrainingResult:
         Note the return value matches up with what is expected of `restore()`.
         """
-        checkpoint_dir = self._create_checkpoint_dir(checkpoint_dir=checkpoint_dir)
-
-        # User saves checkpoint
-        checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)
-
         if not isinstance(self, ray.tune.trainable.FunctionTrainable):
-            # TODO(justinvyu): [cls_trainable_support]
-            # This is to get class Trainables to work in the new persistence mode.
-            # Need to handle checkpoint_dict_or_path == path, dict, or None
-            # Also need to upload to cloud, since `train.report` never gets called.
-            if isinstance(checkpoint_dict_or_path, dict):
-                with Path(checkpoint_dir, _DICT_CHECKPOINT_FILE_NAME).open("wb") as f:
-                    ray_pickle.dump(checkpoint_dict_or_path, f)
-            elif isinstance(checkpoint_dict_or_path, str):
-                if checkpoint_dict_or_path != checkpoint_dir:
-                    raise ValueError(
-                        "The returned checkpoint path from `save_checkpoint` "
-                        "must be None or the same as the provided path argument."
-                        f"Got {checkpoint_dict_or_path} != {checkpoint_dir}"
-                    )
-
-            local_checkpoint = Checkpoint.from_directory(checkpoint_dir)
-
-            metrics = self._last_result.copy() if self._last_result else {}
-
-            if self._storage:
-                # The checkpoint index is updated with the current result.
-                # NOTE: This is no longer using "iteration" as the folder indexing
-                # to be consistent with fn trainables.
-                self._storage._update_checkpoint_index(metrics)
-
-                persisted_checkpoint = self._storage.persist_current_checkpoint(
-                    local_checkpoint
-                )
-
-                checkpoint_result = _TrainingResult(
-                    checkpoint=persisted_checkpoint, metrics=metrics
-                )
-                # Persist trial artifacts to storage.
-                self._storage.persist_artifacts(
-                    force=self._storage.sync_config.sync_artifacts_on_checkpoint
-                )
-            else:
-                # `storage=None` only happens when initializing the
-                # Trainable manually, outside of Tune/Train.
-                # In this case, no storage is set, so the default behavior
-                # is to just not upload anything and report a local checkpoint.
-                # This is fine for the main use case of local debugging.
-                checkpoint_result = _TrainingResult(
-                    checkpoint=local_checkpoint, metrics=metrics
-                )
+            # Use a temporary directory if no checkpoint_dir is provided.
+            use_temp_dir = not checkpoint_dir
+            checkpoint_dir = checkpoint_dir or tempfile.mkdtemp()
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
+            # User saves checkpoint
+            checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)
+            checkpoint_result = self._report_class_trainable_checkpoint(
+                checkpoint_dir, checkpoint_dict_or_path
+            )
+
+            # Clean up the temporary directory, since it's already been
+            # reported + persisted to storage. If no storage is set, the user is
+            # running the Trainable locally and is responsible for cleaning
+            # up the checkpoint directory themselves.
+            if use_temp_dir and self._storage:
+                shutil.rmtree(checkpoint_dir, ignore_errors=True)
         else:
-            checkpoint_result: _TrainingResult = checkpoint_dict_or_path
+            checkpoint_result: _TrainingResult = self.save_checkpoint(None)
             assert isinstance(checkpoint_result, _TrainingResult)
             assert self._last_result
             # Update the checkpoint result to include auto-filled metrics.
             checkpoint_result.metrics.update(self._last_result)
 
-        assert isinstance(checkpoint_result, _TrainingResult)
         return checkpoint_result
 
     @DeveloperAPI
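
For local debugging outside of Tune (the `storage=None` branch above), the new behavior could play out as in this sketch, reusing the hypothetical `MyTrainable` from the top of this page:

    import shutil

    trainable = MyTrainable()
    trainable.train()          # runs one step() and records the result
    result = trainable.save()  # no checkpoint_dir -> tempfile.mkdtemp()

    # With no storage attached, the checkpoint stays local and the temporary
    # directory is NOT removed; the caller is responsible for cleaning it up.
    print(result.checkpoint.path)
    shutil.rmtree(result.checkpoint.path, ignore_errors=True)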
