Skip to content

Commit

Permalink
Update CompositeCheckpointHandler.metadata() to return StepMetadata.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704462484
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Dec 10, 2024
1 parent ac2d276 commit 29cbf35
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 37 deletions.
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ignore not-exists and not-dir errors while building step metadata in
`_StandardNameFormat`.

### Changed
- Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`.

## [0.10.2] - 2024-12-04

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,8 @@ def finalize(self, directory: epath.Path) -> None:
def close(self):
    """Closes the CheckpointHandler; this implementation does nothing."""
    # Intentionally a no-op: there is no state to tear down at this level.
    return None

@property
def typestr(self) -> str:
    """Identifier string for the CheckpointHandler type.

    Returns:
      The instance's `__module__` and its class's `__qualname__`, joined by a
      dot.
    """
    module_name = self.__module__
    class_name = self.__class__.__qualname__
    return f"{module_name}.{class_name}"
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import handler_registration
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
from orbax.checkpoint._src.path import atomicity_types
Expand All @@ -71,6 +73,8 @@
Future = future.Future
CheckpointArgs = checkpoint_args.CheckpointArgs
CheckpointHandler = checkpoint_handler.CheckpointHandler
StepMetadata = checkpoint.StepMetadata
ItemMetadata = checkpoint.ItemMetadata
AsyncCheckpointHandler = async_checkpoint_handler.AsyncCheckpointHandler
register_with_handler = checkpoint_args.register_with_handler
ProtoCheckpointHandler = proto_checkpoint_handler.ProtoCheckpointHandler
Expand Down Expand Up @@ -269,6 +273,7 @@ class CompositeCheckpointHandler(AsyncCheckpointHandler):
the items and handlers. Please use `handler_registry` instead.
Usage::
# The simplest use-case, with no handler registry provided on construction.
checkpointer = ocp.Checkpointer(
ocp.CompositeCheckpointHandler()
Expand All @@ -294,7 +299,7 @@ class CompositeCheckpointHandler(AsyncCheckpointHandler):
checkpointer.save(directory,
ocp.args.Composite(
# Will raise `ValueError: Item "state" and args "JsonSave [...]" does
not match with any registered handler! ...`.
# not match with any registered handler! ...`.
state=ocp.args.JsonSave(pytree),
)
)
Expand Down Expand Up @@ -496,6 +501,8 @@ def __init__(
self._handler_registry,
)

self._metadata_store = checkpoint.metadata_store(enable_write=False)

# TODO: b/359524229 - Remove this property as it has been deprecated.
@property
def _known_handlers(self) -> Dict[str, Optional[CheckpointHandler]]:
Expand Down Expand Up @@ -813,7 +820,7 @@ def restore(
)
return CompositeResults(**restored)

def metadata(self, directory: epath.Path) -> CompositeResults:
def metadata(self, directory: epath.Path) -> StepMetadata:
"""Metadata for each item in the checkpoint.
This has much the same logic as `restore`, in the sense that it tries to
Expand All @@ -825,14 +832,37 @@ def metadata(self, directory: epath.Path) -> CompositeResults:
directory: Path to the checkpoint.
Returns:
CompositeResults
StepMetadata
Raises:
FileNotFoundError: If the directory does not exist.
"""
if not directory.exists():
raise FileNotFoundError(f'Directory does not exist: {directory}')

items_to_handlers = dict(
self._get_all_registered_and_unregistered_items_and_handlers()
)
existing_items = self._existing_items(directory)
try:
existing_items = self._existing_items(directory)
except OSError:
existing_items = []
logging.warning(
'Failed to get existing items from directory %s. Will use items '
'provided during initialization: %s.',
directory, list(items_to_handlers.keys()),
)

serialized_metadata = self._metadata_store.read(
checkpoint.step_metadata_file_path(directory)
)
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata or {}
)
item_handlers = saved_metadata.item_handlers or {}
item_metadata = dict(saved_metadata.item_metadata or {})
assert item_handlers.keys() == item_metadata.keys()

metadata = {}
for item_name in existing_items:
if (
item_name not in items_to_handlers
Expand All @@ -844,14 +874,25 @@ def metadata(self, directory: epath.Path) -> CompositeResults:
' call `restore` with an appropriate `CheckpointArgs` subclass.',
item_name,
)
metadata[item_name] = None
if item_name not in item_handlers:
item_handlers[item_name] = None
if item_name not in item_metadata:
item_metadata[item_name] = None
continue
handler = items_to_handlers[item_name]
assert handler is not None
metadata[item_name] = handler.metadata(
self._get_item_directory(directory, item_name)
)
return CompositeResults(**metadata)
if item_handlers.get(item_name) is None:
item_handlers[item_name] = handler.typestr
if item_metadata.get(item_name) is None:
item_metadata[item_name] = handler.metadata(
self._get_item_directory(directory, item_name)
)

return dataclasses.replace(
saved_metadata,
item_handlers=item_handlers,
item_metadata=ItemMetadata(**item_metadata),
)

def finalize(self, directory: epath.Path):
if not self._current_temporary_paths:
Expand Down
Loading

0 comments on commit 29cbf35

Please sign in to comment.