Skip to content

Commit

Permalink
Rename _single_item to _default_item in CheckpointManager.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684673200
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Oct 11, 2024
1 parent 7ffdc88 commit fef8e18
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 53 deletions.
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Add a `SaveArgs` option that allows disabling pinned host transfer on a per-array basis.

### Changed
- Rename `CheckpointManager._single_item` to `CheckpointManager._default_item`.

## [0.7.0] - 2024-10-07

### Removed
Expand Down
90 changes: 37 additions & 53 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _create_root_directory(
)


def _determine_single_item_mode_from_args(
def _determine_default_item_mode_from_args(
args: args_lib.CheckpointArgs,
) -> bool:
if isinstance(args, args_lib.Composite):
Expand All @@ -422,7 +422,7 @@ def _determine_single_item_mode_from_args(
return True


def _determine_single_item_mode_from_directory(step_path: epath.Path) -> bool:
def _determine_default_item_mode_from_directory(step_path: epath.Path) -> bool:
return (step_path / DEFAULT_ITEM_NAME).exists()


Expand Down Expand Up @@ -460,10 +460,9 @@ def __init__(
Example::
# Multiple items.
with CheckpointManager(
'path/to/dir/',
# Multiple items.
item_names=('train_state', 'custom_metadata'),
metadata={'version': 1.1, 'lang': 'en'},
) as mngr:
mngr.save(0, args=args.Composite(
Expand All @@ -481,7 +480,7 @@ def __init__(
print(restored.train_state)
print(restored.custom_metadata) # Error, not restored
# Single item, no need to specify `item_names`.
# Single, unnamed (default) item.
with CheckpointManager(
'path/to/dir/',
options = CheckpointManagerOptions(max_to_keep=5, ...),
Expand Down Expand Up @@ -518,14 +517,7 @@ def __init__(
Args:
directory: the top level directory in which to save all files.
checkpointers: a mapping of object name to Checkpointer object. For
example, `items` provided to `save` below should have keys matching the
keys in this argument. Alternatively, a single Checkpointer may be
provided, in which case `save` and `restore` should always be called
with a single item rather than a dictionary of items. See below for more
details. `item_names` and `checkpointers` are mutually exclusive - do
not use together. Also, please don't use `checkpointers` and
`item_handlers` together.
checkpointers: Deprecated, do not use. Use `handler_registry` instead.
options: CheckpointManagerOptions. May be provided to specify additional
arguments. If None, uses default values of CheckpointManagerOptions.
metadata: High-level metadata that does not depend on step number. If
Expand All @@ -536,16 +528,8 @@ def __init__(
expected. A CheckpointManager instance with a read-only `directory` uses
the metadata if already present, otherwise always uses the current given
metadata.
item_names: Names of distinct items that may be saved/restored with this
`CheckpointManager`. `item_names` and `checkpointers` are mutually
exclusive - do not use together. Also see `item_handlers` below.
item_handlers: A mapping of item name to `CheckpointHandler`. The mapped
CheckpointHandler must be registered against the `CheckpointArgs` input
in save/restore operations. Please don't use `checkpointers` and
`item_handlers` together. It can be used with or without `item_names`.
The item name key may or may not be present in `item_names`.
Alternatively, a single CheckpointHandler may be provided, in which case
`save` and `restore` should always be called in a single item context.
item_names: Deprecated, do not use. Use `handler_registry` instead.
item_handlers: Deprecated, do not use. Use `handler_registry` instead.
logger: A logger to log checkpointing events.
handler_registry: A registry of handlers to use for checkpointing. This
option is mutually exclusive with `checkpointers`, `item_handlers`, and
Expand Down Expand Up @@ -584,7 +568,7 @@ def __init__(
)
if item_names and isinstance(item_handlers, CheckpointHandler):
raise ValueError(
'`item_handlers` in single item mode and `item_names` should not be'
'`item_handlers` in default item mode and `item_names` should not be'
' provided together.'
)
if checkpointers is not None and handler_registry is not None:
Expand Down Expand Up @@ -628,38 +612,38 @@ def __init__(
' https://orbax.readthedocs.io/en/latest/api_refactor.html to'
' migrate.'
)
self._single_item = isinstance(checkpointers, AbstractCheckpointer)
self._default_item = isinstance(checkpointers, AbstractCheckpointer)
self._checkpointer = self._configure_checkpointer_legacy_init(
checkpointers, self._options
)
elif handler_registry is not None:
# There is no way to know if this is a single item or not; determine this
# lazily on the first call to `save`, `restore`, or
# `item_metadata`. Once locked in, the value of `_single_item` will not
# change.
self._single_item = None
# There is no way to know if this is a single, unnamed (default) item or
# not; determine this lazily on the first call to `save`, `restore`,
# or `item_metadata`. Once locked in, the value of `_default_item` will
# not change.
self._default_item = None
self._checkpointer = self._configure_checkpointer_from_handler_registry(
handler_registry,
self._options,
)
elif item_names is None and item_handlers is None:
# In this case, we can just default construct the
# CheckpointHandlerRegistry and allow the user to lazily specify single
# CheckpointHandlerRegistry and allow the user to lazily specify default
# vs. multi-item mode.
self._single_item = None
self._default_item = None
handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
self._checkpointer = self._configure_checkpointer_from_handler_registry(
handler_registry,
self._options,
)
else:
self._single_item = isinstance(item_handlers, CheckpointHandler)
self._default_item = isinstance(item_handlers, CheckpointHandler)
self._checkpointer = (
self._configure_checkpointer_from_item_names_and_handlers(
item_names,
item_handlers,
self._options,
self._single_item,
self._default_item,
)
)

Expand Down Expand Up @@ -870,7 +854,7 @@ def _configure_checkpointer_from_item_names_and_handlers(
item_names: Optional[Sequence[str]],
item_handlers: Optional[Union[CheckpointHandler, CheckpointHandlersDict]],
options: CheckpointManagerOptions,
single_item: bool,
default_item: bool,
) -> Checkpointer:
"""Initializes _CompositeCheckpointer given `item_names`."""
if (
Expand All @@ -881,7 +865,7 @@ def _configure_checkpointer_from_item_names_and_handlers(
'When primary_host is set to None, item_handlers must be provided to'
' match with the primary_host setting.'
)
if single_item:
if default_item:
item_handler = (
item_handlers
if isinstance(item_handlers, CheckpointHandler)
Expand Down Expand Up @@ -1088,11 +1072,11 @@ def _validate_args(
raise ValueError(
f'Expected args of type `CheckpointArgs`; found {type(args)}.'
)
if self._single_item:
if self._default_item:
if isinstance(args, args_lib.Composite):
raise ValueError(
'Cannot provide `args` of type `Composite` when dealing with a'
' single checkpointable object.'
' single, unnamed (default) checkpointable object.'
)
else:
if not isinstance(args, args_lib.Composite):
Expand All @@ -1119,8 +1103,8 @@ def save(

if items is None and args is None:
raise ValueError('Must provide `args` for `save`.')
if self._single_item is None:
self._single_item = _determine_single_item_mode_from_args(args)
if self._default_item is None:
self._default_item = _determine_default_item_mode_from_args(args)
self._validate_args(items, args)
if not force and not self.should_save(step):
return False
Expand Down Expand Up @@ -1156,7 +1140,7 @@ def save(
items = {}
if save_kwargs is None:
save_kwargs = {}
if self._single_item:
if self._default_item:
items = {DEFAULT_ITEM_NAME: items}
save_kwargs = {DEFAULT_ITEM_NAME: save_kwargs}

Expand All @@ -1172,10 +1156,10 @@ def save(
)
extra_args = save_kwargs[key] if key in save_kwargs else {}
extra_args = extra_args or {}
args_dict[key] = save_ckpt_arg_cls(item, **extra_args) # pytype: disable=wrong-arg-count
args_dict[key] = save_ckpt_arg_cls(item, **extra_args) # pylint: disable=too-many-function-args # pytype: disable=wrong-arg-count
args = args_lib.Composite(**args_dict)
else:
if self._single_item:
if self._default_item:
args = args_lib.Composite(**{DEFAULT_ITEM_NAME: args})
else:
if not isinstance(args, args_lib.Composite):
Expand Down Expand Up @@ -1302,7 +1286,7 @@ def save(
return True

def _maybe_get_default_item(self, composite_result: args_lib.Composite):
if self._single_item:
if self._default_item:
if DEFAULT_ITEM_NAME not in composite_result:
raise ValueError(
'Unable to retrieve default item. Please ensure that a handler for'
Expand Down Expand Up @@ -1332,19 +1316,19 @@ def restore(
step_stats.checkpoint_manager_start_time = time.time()
step_stats.directory = str(directory)

if self._single_item is None:
self._single_item = _determine_single_item_mode_from_directory(
if self._default_item is None:
self._default_item = _determine_default_item_mode_from_directory(
self._get_read_step_directory(step, directory)
)
self._validate_args(items, args)

if items is None:
items = {}
elif self._single_item:
elif self._default_item:
items = {DEFAULT_ITEM_NAME: items}
if restore_kwargs is None:
restore_kwargs = {}
elif self._single_item:
elif self._default_item:
restore_kwargs = {DEFAULT_ITEM_NAME: restore_kwargs}

if args is None:
Expand All @@ -1358,10 +1342,10 @@ def restore(
item = items[key] if key in items else None
extra_args = restore_kwargs[key] if key in restore_kwargs else {}
extra_args = extra_args or {}
args_dict[key] = restore_ckpt_arg_cls(item, **extra_args) # pytype: disable=wrong-arg-count
args_dict[key] = restore_ckpt_arg_cls(item, **extra_args) # pylint: disable=too-many-function-args # pytype: disable=wrong-arg-count
args = args_lib.Composite(**args_dict)
else:
if self._single_item:
if self._default_item:
args = args_lib.Composite(**{DEFAULT_ITEM_NAME: args})
else:
args = typing.cast(args_lib.Composite, args)
Expand Down Expand Up @@ -1392,15 +1376,15 @@ def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]:
step: The step to retrieve metadata for.
Returns:
Either metadata for the item itself, if in single-item mode, or a
Either metadata for the item itself, if in default-item mode, or a
Composite of metadata for each item.
"""
assert isinstance(self._checkpointer.handler, CompositeCheckpointHandler)
read_step_directory = self._get_read_step_directory(step, self.directory)

result = self._checkpointer.metadata(read_step_directory)
if self._single_item is None:
self._single_item = _determine_single_item_mode_from_directory(
if self._default_item is None:
self._default_item = _determine_default_item_mode_from_directory(
read_step_directory
)
return self._maybe_get_default_item(result)
Expand Down

0 comments on commit fef8e18

Please sign in to comment.