Skip to content

Commit

Permalink
Minor fixes to metadata serialization.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 700721116
  • Loading branch information
BlaziusMaximus authored and Orbax Authors committed Nov 27, 2024
1 parent 8a34b74 commit 6a66183
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 30 deletions.
22 changes: 13 additions & 9 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,29 @@
SerializedMetadata = TypeVar('SerializedMetadata', bound=dict[str, Any])


def step_metadata_file_path(path: epath.PathLike) -> epath.Path:
"""The path to step metadata file for a given checkpoint directory."""
def _sanitize_metadata_path(path: epath.PathLike) -> epath.Path:
"""Sanitizes the path and returns it as an `epath.Path`."""
path = epath.Path(path)
if not path.exists():
raise FileNotFoundError(f'Path does not exist: {path}')
if not path.is_dir():
raise ValueError(f'Path is not a directory: {path}')
return path / _STEP_METADATA_FILENAME
raise NotADirectoryError(f'Path is not a directory: {path}')
return path


def step_metadata_file_path(path: epath.PathLike) -> epath.Path:
"""The path to step metadata file for a given checkpoint directory."""
return _sanitize_metadata_path(path) / _STEP_METADATA_FILENAME


def root_metadata_file_path(
path: epath.PathLike, *, legacy: bool = False
) -> epath.Path:
"""The path to root metadata file for a given checkpoint directory."""
path = epath.Path(path)
if not path.is_dir():
raise ValueError(f'Path is not a directory: {path}')
filename = (
_LEGACY_ROOT_METADATA_FILENAME if legacy else _ROOT_METADATA_FILENAME
)
return path / filename
return _sanitize_metadata_path(path) / filename


@dataclasses.dataclass
Expand Down Expand Up @@ -97,7 +101,7 @@ class RootMetadata:
"""

format: str | None = None
custom: dict[str, Any] = dataclasses.field(default_factory=dict)
custom: dict[str, Any] | None = dataclasses.field(default_factory=dict)


class MetadataStore(Protocol):
Expand Down
74 changes: 69 additions & 5 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def test_read_unknown_path(self, blocking_write: bool):
checkpoint.step_metadata_file_path,
checkpoint.root_metadata_file_path,
)
def test_unkown_metadata_path(
def test_metadata_path_does_not_exist(
self, file_path_fn: Callable[[epath.PathLike], epath.Path]
):
with self.assertRaisesRegex(ValueError, 'Path is not a directory'):
file_path_fn('unknown_metadata_path')
with self.assertRaisesRegex(FileNotFoundError, 'Path does not exist'):
file_path_fn('non_existent_metadata_path')

def test_legacy_root_metadata_file_path(self):
self.assertEqual(
Expand Down Expand Up @@ -475,15 +475,15 @@ def test_metadata_file_path(
self.directory / self.get_metadata_filename(metadata_class),
)

with self.assertRaisesRegex(ValueError, 'Path is not a directory'):
with self.assertRaisesRegex(IOError, 'Path does not exist'):
self.get_metadata_file_path(metadata_class, path=self.directory / '0')

metadata_file = self.get_metadata_file_path(metadata_class)
self.write_metadata_store(blocking_write=True).write(
file_path=metadata_file,
metadata=self.serialize_metadata(self.get_metadata(metadata_class)),
)
with self.assertRaisesRegex(ValueError, 'Path is not a directory'):
with self.assertRaisesRegex(IOError, 'Path is not a directory'):
self.get_metadata_file_path(metadata_class, path=metadata_file)

@parameterized.parameters(True, False)
Expand All @@ -492,6 +492,70 @@ def test_pickle(self, blocking_write: bool):
_ = pickle.dumps(self.write_metadata_store(blocking_write))
_ = pickle.dumps(self.read_metadata_store(blocking_write))

@parameterized.parameters(
    ({'format': 0},),
    ({'custom': []},),
    ({'custom': {0: None}},),
)
def test_deserialize_wrong_types_root_metadata(
    self, wrong_metadata: checkpoint.SerializedMetadata
):
  """Deserializing `RootMetadata` from mistyped fields raises ValueError."""
  # Cases: an int `format`, a list `custom`, and an int key inside `custom`.
  self.assertRaises(
      ValueError, self.deserialize_metadata, RootMetadata, wrong_metadata
  )

@parameterized.parameters(
    ({'format': 0},),
    ({'item_handlers': []},),
    ({'item_handlers': {0: None}},),
    ({'item_metadata': []},),
    ({'item_metadata': {0: None}},),
    ({'metrics': []},),
    ({'metrics': {0: None}},),
    ({'performance_metrics': []},),
    ({'performance_metrics': {0: 0.0}},),
    ({'performance_metrics': {'': 0}},),
    ({'init_timestamp_nsecs': 0.0},),
    ({'commit_timestamp_nsecs': 0.0},),
    ({'custom': []},),
    ({'custom': {0: None}},),
)
def test_deserialize_wrong_types_step_metadata(
    self, wrong_metadata: checkpoint.SerializedMetadata
):
  """Deserializing `StepMetadata` from mistyped fields raises ValueError."""
  # Each case sets exactly one field to a value of an unexpected type, or a
  # dict containing a key/value of an unexpected type.
  self.assertRaises(
      ValueError, self.deserialize_metadata, StepMetadata, wrong_metadata
  )

@parameterized.parameters(
    (RootMetadata(custom={'a': None}), {'custom': {'a': None}}),
    (RootMetadata(format=_SAMPLE_FORMAT), {'format': _SAMPLE_FORMAT}),
    (StepMetadata(format=_SAMPLE_FORMAT), {'format': _SAMPLE_FORMAT}),
    (
        StepMetadata(item_handlers={'a': 'a_handler'}),
        {'item_handlers': {'a': 'a_handler'}},
    ),
    (StepMetadata(custom={'blah': 123}), {'custom': {'blah': 123}}),
)
def test_only_serialize_non_default_metadata_values(
    self,
    metadata: StepMetadata | RootMetadata,
    expected_serialized_metadata: dict[str, Any],
):
  """Serialized output contains only the fields that were explicitly set."""
  # Every case sets a single field; the expected dict holds just that key.
  actual_serialized_metadata = self.serialize_metadata(metadata)
  self.assertEqual(actual_serialized_metadata, expected_serialized_metadata)


# Script entry point: calls absltest.main() when this file is executed directly.
if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

def serialize(metadata: RootMetadata) -> SerializedMetadata:
"""Serializes `metadata` to a dictionary."""
return {
'format': metadata.format,
'custom': metadata.custom,
}
serialized_metadata = {}
if metadata.format is not None:
serialized_metadata['format'] = metadata.format
if metadata.custom:
serialized_metadata['custom'] = metadata.custom
return serialized_metadata


def deserialize(metadata_dict: SerializedMetadata) -> RootMetadata:
Expand All @@ -40,7 +42,7 @@ def deserialize(metadata_dict: SerializedMetadata) -> RootMetadata:

if 'custom' in metadata_dict:
utils.validate_field(metadata_dict, 'custom', dict)
for k in metadata_dict['custom']:
for k in metadata_dict.get('custom', {}) or {}:
utils.validate_dict_entry(metadata_dict, 'custom', k, str)
validated_metadata_dict['custom'] = metadata_dict.get('custom', {})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,27 @@ def serialize(metadata: StepMetadata) -> SerializedMetadata:
if isinstance(val, float)
}

return {
'format': metadata.format,
'item_handlers': metadata.item_handlers,
'item_metadata': just_item_names,
'metrics': metadata.metrics,
'performance_metrics': float_metrics,
'init_timestamp_nsecs': metadata.init_timestamp_nsecs,
'commit_timestamp_nsecs': metadata.commit_timestamp_nsecs,
'custom': metadata.custom,
}
serialized_metadata = {}
if metadata.format is not None:
serialized_metadata['format'] = metadata.format
if metadata.item_handlers:
serialized_metadata['item_handlers'] = metadata.item_handlers
if just_item_names is not None:
serialized_metadata['item_metadata'] = just_item_names
if metadata.metrics:
serialized_metadata['metrics'] = metadata.metrics
if float_metrics:
serialized_metadata['performance_metrics'] = float_metrics
if metadata.init_timestamp_nsecs is not None:
serialized_metadata['init_timestamp_nsecs'] = metadata.init_timestamp_nsecs
if metadata.commit_timestamp_nsecs is not None:
serialized_metadata['commit_timestamp_nsecs'] = (
metadata.commit_timestamp_nsecs
)
if metadata.custom:
serialized_metadata['custom'] = metadata.custom

return serialized_metadata


def deserialize(
Expand All @@ -70,7 +81,7 @@ def deserialize(

utils.validate_field(metadata_dict, 'item_handlers', dict)
for k in metadata_dict.get('item_handlers', {}) or {}:
utils.validate_dict_entry(metadata_dict, 'item_handlers', k, str, str)
utils.validate_dict_entry(metadata_dict, 'item_handlers', k, str)
validated_metadata_dict['item_handlers'] = metadata_dict.get(
'item_handlers', {}
)
Expand Down

0 comments on commit 6a66183

Please sign in to comment.