Skip to content

Commit

Permalink
Fix incorrect usage of is_checkpoint_finalized in Checkpointer. The…
Browse files Browse the repository at this point in the history
… bug was reported when restoring a checkpoint on GCS, the checkpoint was always thought to be "not finalized".

PiperOrigin-RevId: 492206447
  • Loading branch information
cpgaffney1 authored and copybara-github committed Dec 8, 2022
1 parent 390b99d commit 47aa9ca
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 33 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.0.23] - 2022-12-08

### Added

- Option to customize metadata file name for Tensorstore.

### Fixed

- Restore failure on GCS due to misidentification of checkpoint as
"not finalized".

## [0.0.22] - 2022-12-05

### Added
Expand Down
2 changes: 1 addition & 1 deletion orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""Orbax API."""

# A new PyPI release will be pushed everytime `__version__` is increased.
__version__ = '0.0.22'
__version__ = '0.0.23'
2 changes: 1 addition & 1 deletion orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def save(self,
if force:
if jax.process_index() == 0:
logging.info('Specified `force`: removing existing directory.')
utils.rmtree(directory) # Post-sync handled by create_tmp_directory.
directory.rmtree() # Post-sync handled by create_tmp_directory.
else:
raise ValueError(f'Destination {directory} already exists.')
tmpdir = utils.create_tmp_directory(directory)
Expand Down
20 changes: 14 additions & 6 deletions orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def __init__(
elif isinstance(checkpointers, dict):
if METRIC_ITEM_NAME in checkpointers:
raise ValueError(
f'Found {METRIC_ITEM_NAME} in `checkpointers`; this is a reserved key.'
f'Found {METRIC_ITEM_NAME} in `checkpointers`; this is a reserved'
' key.'
)
else:
raise ValueError(
Expand Down Expand Up @@ -311,7 +312,8 @@ def update(self, step: int, metrics: Optional[PyTree] = None):
raise ValueError('Must provide metrics to update.')
if not self._track_best and metrics is not None:
raise ValueError(
'Requested update metrics without configuring the CheckpointManager to track metrics.'
'Requested update metrics without configuring the CheckpointManager'
' to track metrics.'
)

# Wait for ongoing saves to complete. Only applicable if some of the
Expand Down Expand Up @@ -683,7 +685,8 @@ def _cleanup_tmp_directories(self):

def _delete_directory(self, step: int):
if jax.process_index() == 0:
utils.rmtree(self._get_save_directory(step, self.directory))
# TODO(cpgaffney) Optimize tree removal if possible.
self._get_save_directory(step, self.directory).rmtree()

def _remove_old_checkpoints(self):
"""Keeps the `max_to_keep` most recent checkpoint steps."""
Expand All @@ -695,8 +698,10 @@ def _remove_old_checkpoints(self):
return
if self._track_best:
# Best steps (to keep) are at the end, after sorting.
checkpoints_without_metrics, sorted_checkpoints = self._sort_checkpoints_by_metrics(
self._checkpoints)
(
checkpoints_without_metrics,
sorted_checkpoints,
) = self._sort_checkpoints_by_metrics(self._checkpoints)
else:
# checkpoints already sorted by ascending step
checkpoints_without_metrics = []
Expand Down Expand Up @@ -728,7 +733,10 @@ def _remove_old_checkpoints(self):
kept_checkpoints.append(info)
continue

if self._options.keep_period is not None and info.step % self._options.keep_period == 0:
if (
self._options.keep_period is not None
and info.step % self._options.keep_period == 0
):
kept_checkpoints.append(info)
continue

Expand Down
4 changes: 2 additions & 2 deletions orbax/checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def save(self,
if force:
if jax.process_index() == 0:
logging.info('Specified `force`: removing existing directory.')
utils.rmtree(directory) # Post-sync handled by create_tmp_directory.
directory.rmtree() # Post-sync handled by create_tmp_directory.
else:
raise ValueError(f'Destination {directory} already exists.')
tmpdir = utils.create_tmp_directory(directory)
Expand All @@ -87,7 +87,7 @@ def restore(self,
directory = epath.Path(directory)
if not directory.exists():
raise FileNotFoundError(f'Checkpoint at {directory} not found.')
if not utils.is_checkpoint_finalized(directory):
if not utils.is_checkpoint_item_finalized(directory):
raise ValueError(f'Found incomplete checkpoint at {directory}.')
logging.info('Restoring item from %s.', directory)
return self._handler.restore(directory, *args, item=item, **kwargs)
Expand Down
47 changes: 47 additions & 0 deletions orbax/checkpoint/checkpointer_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

"""Common tests for AbstractCheckpointManager subclasses."""
from unittest import mock

from absl.testing import parameterized
from etils import epath
from flax import linen as nn
Expand Down Expand Up @@ -185,3 +187,48 @@ def init_state():

test_utils.assert_tree_equal(self, state.params, restored.params)
test_utils.assert_tree_equal(self, state.opt_state, restored.opt_state)

def test_save_preempted(self):
"""Simulate effects of preemption."""
# Simulates the effects of preemption by creating a tmp directory and
# ensuring it is cleaned up.
tmp_dir = test_utils.save_fake_tmp_dir(
self.directory, 0, 'params', subdirs=['subdir']
)
self.assertTrue(tmp_dir.exists())
self.assertTrue((tmp_dir / 'subdir').exists())

checkpointer = self.checkpointer(PyTreeCheckpointHandler())
with self.assertRaises(ValueError):
checkpointer.restore(tmp_dir)

def test_gcs(self):
"""Test normal operation in simulated GCS environment."""
with mock.patch.object(
utils, 'is_gcs_path', autospec=True, return_value=True
):
checkpointer = self.checkpointer(PyTreeCheckpointHandler())
path = self.directory / '0' / 'params'
checkpointer.save(path, self.pytree)
self.wait_if_async(checkpointer)
restored = checkpointer.restore(
path, restore_args=self.pytree_restore_args
)
test_utils.assert_tree_equal(self, self.pytree, restored)
self.assertTrue((path / utils._COMMIT_SUCCESS_FILE).exists()) # pylint: disable=protected-access

def test_save_preempted_gcs(self):
"""Simulate effects of preemption."""
with mock.patch.object(
utils, 'is_gcs_path', autospec=True, return_value=True
):
tmp_dir = test_utils.save_fake_tmp_dir(
self.directory, 0, 'params', subdirs=['subdir']
)
self.assertTrue(tmp_dir.exists())
self.assertTrue((tmp_dir / 'subdir').exists())

checkpointer = self.checkpointer(PyTreeCheckpointHandler())
with self.assertRaises(ValueError):
checkpointer.restore(tmp_dir)
self.assertFalse((tmp_dir / utils._COMMIT_SUCCESS_FILE).exists()) # pylint: disable=protected-access
114 changes: 91 additions & 23 deletions orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,6 @@ def is_dict(s):
override=True)


# TODO(cpgaffney) optimize.
def rmtree(path: epath.Path):
"""Recursively removes non-empty directory."""
for child in path.iterdir():
if child.is_file():
child.unlink()
else:
rmtree(child)
path.rmdir()


def leaf_is_placeholder(leaf: Any) -> bool:
"""Determines if `leaf` represents a placeholder for a non-aggregated value.
"""
Expand Down Expand Up @@ -187,7 +176,7 @@ def cleanup_tmp_directories(directory: epath.PathLike):
if jax.process_index() == 0:
tmp_files = tmp_checkpoints(directory)
for tmp_file in tmp_files:
rmtree(directory / tmp_file)
(directory / tmp_file).rmtree()

sync_global_devices('cleanup_tmp_dirs')

Expand Down Expand Up @@ -284,22 +273,85 @@ def is_scalar(x):
return isinstance(x, (int, float, np.number))


def is_checkpoint_finalized(path: epath.PathLike) -> bool:
def is_checkpoint_item_finalized(path: epath.PathLike) -> bool:
"""Determines if the checkpoint item path is finalized.
NOT TO BE CONFUSED WITH is_checkpoint_finalized. That method works on the step
level, while this method works on the item level.
Path takes the form:
<directory>/<step>/<item1>.orbax-checkpoint-tmp-<timestamp>/ # not finalized
# Checkpoint files
...
OR
<directory>/<step>/<item2>/ # finalized
...
Alternatively:
gs://<directory>/<step>/<item1>/ # finalized
commit_success.txt
...
OR
gs://<directory>/<step>/<item2>/ # not finalized
...
Args:
path: Path to item directory.
Returns:
True if the checkpoint item is finalized.
Raises:
ValueError if the provided path is not a directory. Valid checkpoint paths
must be a directory.
"""
path = epath.Path(path)
if not path.is_dir():
raise ValueError(f'Path {path} is not a directory.')
if is_gcs_path(path) and not (path / _COMMIT_SUCCESS_FILE).exists():
return False
if TMP_DIR_SUFFIX in path.name:
return False
return True


def is_checkpoint_step_finalized(path: epath.PathLike) -> bool:
"""Determines if the checkpoint path is finalized.
NOT TO BE CONFUSED WITH is_checkpoint_item_finalized. That method works on the
per-item level, while this method works on the per-step level.
Path takes the form:
<directory>/<step>/
<name1>.orbax-checkpoint-tmp-<timestamp> # not finalized
<name2> # finalized
<item1>.orbax-checkpoint-tmp-<timestamp>/ # not finalized
# Checkpoint files
...
<item2> # finalized
...
Alternatively:
gs://<directory>/<step>/
<name1> # finalized
<item1> # finalized
commit_success.txt
...
<name2> # not finalized
<item2> # not finalized
...
# not finalized
<directory>/checkpoint_<step>.orbax-checkpoint-tmp-<timestamp>/
checkpoint
a/
0.0
.zarray
b/
...
<directory>/checkpoint_<step>/ # finalized
checkpoint
...
Args:
path: Path to step directory.
Expand All @@ -312,23 +364,39 @@ def is_checkpoint_finalized(path: epath.PathLike) -> bool:
"""
path = epath.Path(path)
if not path.is_dir():
raise ValueError(f'Path {path} is not a directory')
raise ValueError(f'Path {path} is not a directory.')
for subpath in path.iterdir():
if is_gcs_path(subpath) and not (subpath / _COMMIT_SUCCESS_FILE).exists():
return False
if TMP_DIR_SUFFIX in subpath.name:
if not is_checkpoint_item_finalized(subpath):
return False
return True


def _is_step_checkpoint(path: epath.Path) -> bool:
"""Determines if the path resembles an Orbax step directory."""
return path.is_dir() and os.fspath(path.name).isdigit()


def checkpoint_steps(checkpoint_dir: epath.PathLike) -> List[int]:
"""Returns a list of finalized checkpoint steps in the directory."""
checkpoint_dir = epath.Path(checkpoint_dir)
return [
int(os.fspath(s.name)) for s in checkpoint_dir.iterdir() if s.is_dir() and
os.fspath(s.name).isdigit() and is_checkpoint_finalized(s)
int(os.fspath(s.name))
for s in checkpoint_dir.iterdir()
if _is_step_checkpoint(s) and is_checkpoint_finalized(s)
]


def is_checkpoint_finalized(path: epath.PathLike) -> bool:
"""Branches to step_finalized/item_finalized depending on the path."""
path = epath.Path(path)
if not path.is_dir():
raise ValueError(f'Checkpoint path {path} must be a directory.')
if _is_step_checkpoint(path):
return is_checkpoint_step_finalized(path)
else:
return is_checkpoint_item_finalized(path)


def tmp_checkpoints(checkpoint_dir: epath.PathLike) -> List[str]:
checkpoint_dir = epath.Path(checkpoint_dir)
return [
Expand Down

0 comments on commit 47aa9ca

Please sign in to comment.