Skip to content

Commit

Permalink
Refactor and restructure constructs from type_handlers.py and `meta…
Browse files Browse the repository at this point in the history
…data`

PiperOrigin-RevId: 698544328
  • Loading branch information
niketkumar authored and Orbax Authors committed Nov 20, 2024
1 parent 71b1ba0 commit 2a39f3a
Show file tree
Hide file tree
Showing 14 changed files with 720 additions and 592 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ utils into a separate file.
accept and return metadata as dictionaries.
- Move `Checkpointer` implementations to `_src`.
- Add/Update tests for `is_empty_or_leaf` and `is_empty_node`.
- Refactor and restructure constructs from `type_handlers.py` and `metadata`
package to avoid circular dependencies.

## [0.9.1] - 2024-11-11

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -187,7 +189,7 @@ def _group_value(
requested_restore_type = arg.restore_type or metadata_restore_type
# TODO(cpgaffney): Add a warning message if the requested_restore_type
# is not the same as the metadata_restore_type.
if type_handlers.is_empty_typestr(requested_restore_type):
if empty_values.is_empty_typestr(requested_restore_type):
# Skip deserialization of empty node using TypeHandler.
return
type_for_registry_lookup = requested_restore_type
Expand Down Expand Up @@ -384,7 +386,7 @@ def _param_info(name, value):
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
byte_limiter=byte_limiter,
ts_context=ts_context,
value_typestr=type_handlers.get_param_typestr(
value_typestr=types.get_param_typestr(
value, self._type_handler_registry
),
)
Expand Down Expand Up @@ -565,7 +567,7 @@ async def _maybe_deserialize(
# Add in empty nodes from the metadata tree.
for key in flat_metadata.keys():
if key not in flat_restored:
flat_restored[key] = type_handlers.get_empty_value_from_typestr(
flat_restored[key] = empty_values.get_empty_value_from_typestr(
flat_metadata[key].value_type
)
# Restore using `item` as the target structure. If there are any custom
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
Expand Down Expand Up @@ -248,7 +249,7 @@ def _get_param_info(
name: str,
meta_or_value: Union[Any, tree_metadata.ValueMetadataEntry],
) -> Union[ParamInfo, Any]:
if type_handlers.is_supported_empty_value(meta_or_value):
if empty_values.is_supported_empty_value(meta_or_value):
# Empty node, ParamInfo should not be returned.
return meta_or_value
elif not isinstance(meta_or_value, tree_metadata.ValueMetadataEntry):
Expand Down Expand Up @@ -818,7 +819,7 @@ def _maybe_set_default_restore_types(value_meta: Any, arg: RestoreArgs):
if (
isinstance(value_meta, tree_metadata.ValueMetadataEntry)
and not value_meta.skip_deserialize
and value_meta.value_type == type_handlers.RESTORE_TYPE_UNKNOWN
and value_meta.value_type == empty_values.RESTORE_TYPE_UNKNOWN
):
return dataclasses.replace(
value_meta, value_type=type_handlers.default_restore_type(arg)
Expand Down Expand Up @@ -915,23 +916,23 @@ def _get_internal_metadata(
flat_aggregate = None

def _is_empty_value(value):
return type_handlers.is_supported_empty_value(
return empty_values.is_supported_empty_value(
value
) or not utils.leaf_is_placeholder(value)

def _process_aggregate_leaf(value):
if _is_empty_value(value):
return value
return tree_metadata.ValueMetadataEntry(
value_type=type_handlers.RESTORE_TYPE_UNKNOWN,
value_type=empty_values.RESTORE_TYPE_UNKNOWN,
skip_deserialize=False,
)

def _process_metadata_and_aggregate_leaves(value_meta, value):
if _is_empty_value(value):
return value
if type_handlers.is_empty_typestr(value_meta.value_type):
return type_handlers.get_empty_value_from_typestr(value_meta.value_type)
if empty_values.is_empty_typestr(value_meta.value_type):
return empty_values.get_empty_value_from_typestr(value_meta.value_type)
return value_meta

# Handle cases of missing metadata and/or aggregate files.
Expand Down Expand Up @@ -963,11 +964,9 @@ def _process_metadata_and_aggregate_leaves(value_meta, value):
value_meta, flat_aggregate[tuple_key]
)
else:
if type_handlers.is_empty_typestr(value_meta.value_type):
if empty_values.is_empty_typestr(value_meta.value_type):
flat_structure[tuple_key] = (
type_handlers.get_empty_value_from_typestr(
value_meta.value_type
)
empty_values.get_empty_value_from_typestr(value_meta.value_type)
)
else:
flat_structure[tuple_key] = value_meta
Expand Down
75 changes: 75 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/empty_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles empty values in the checkpoint PyTree."""

from typing import Any, Mapping
from orbax.checkpoint._src.tree import utils as tree_utils

RESTORE_TYPE_NONE = 'None'
RESTORE_TYPE_DICT = 'Dict'
RESTORE_TYPE_LIST = 'List'
RESTORE_TYPE_TUPLE = 'Tuple'
RESTORE_TYPE_UNKNOWN = 'Unknown'
# TODO: b/365169723 - Handle empty NamedTuple.


# TODO: b/365169723 - Handle empty NamedTuple.
def is_supported_empty_value(value: Any) -> bool:
"""Determines if the *empty* `value` is supported without custom TypeHandler."""
# Check isinstance first to avoid `not` checks on jax.Arrays (raises error).
if tree_utils.isinstance_of_namedtuple(value):
return False
return (
isinstance(value, (dict, list, tuple, type(None), Mapping)) and not value
)


# TODO: b/365169723 - Handle empty NamedTuple.
def get_empty_value_typestr(value: Any) -> str:
"""Returns the typestr constant for the empty value."""
if not is_supported_empty_value(value):
raise ValueError(f'{value} is not a supported empty type.')
if isinstance(value, list):
return RESTORE_TYPE_LIST
if isinstance(value, tuple):
return RESTORE_TYPE_TUPLE
if isinstance(value, (dict, Mapping)):
return RESTORE_TYPE_DICT
if value is None:
return RESTORE_TYPE_NONE
raise ValueError(f'Unrecognized empty type: {value}.')


# TODO: b/365169723 - Handle empty NamedTuple.
def is_empty_typestr(typestr: str) -> bool:
return (
typestr == RESTORE_TYPE_LIST
or typestr == RESTORE_TYPE_TUPLE
or typestr == RESTORE_TYPE_DICT
or typestr == RESTORE_TYPE_NONE
)


# TODO: b/365169723 - Handle empty NamedTuple.
def get_empty_value_from_typestr(typestr: str) -> Any:
if typestr == RESTORE_TYPE_LIST:
return []
if typestr == RESTORE_TYPE_TUPLE:
return tuple()
if typestr == RESTORE_TYPE_DICT:
return {}
if typestr == RESTORE_TYPE_NONE:
return None
raise ValueError(f'Unrecognized typestr: {typestr}.')
44 changes: 44 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/empty_values_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for empty values as leafs in the checkpoint tree."""

from absl.testing import absltest
from absl.testing import parameterized
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.testing import test_tree_utils


class EmptyValuesTest(parameterized.TestCase):

@parameterized.parameters(
(1, False),
(dict(), True),
({}, True),
({"a": {}}, False),
([], True),
([[]], False),
(None, True),
((1, 2), False),
(test_tree_utils.EmptyNamedTuple(), False),
(test_tree_utils.MuNu(mu=None, nu=None), False),
(test_tree_utils.NamedTupleWithNestedAttributes(), False),
(test_tree_utils.NamedTupleWithNestedAttributes(nested_dict={}), False),
)
def test_is_supported_empty_value(self, value, expected):
self.assertEqual(expected, empty_values.is_supported_empty_value(value))


if __name__ == "__main__":
absltest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Basic constructs for PyTree Metadata handling."""

import dataclasses


@dataclasses.dataclass(kw_only=True)
class PyTreeMetadataOptions:
"""Options for managing PyTree metadata.
Attributes:
support_rich_types: [Experimental feature: subject to change without
notice.] If True, supports NamedTuple and Tuple node types in the
metadata. Otherwise, a NamedTuple node is converted to dict and Tuple node
to list.
"""

# TODO: b/365169723 - Support different namedtuple ser/deser strategies.

support_rich_types: bool


# Global default options.
PYTREE_METADATA_OPTIONS = PyTreeMetadataOptions(support_rich_types=False)
Loading

0 comments on commit 2a39f3a

Please sign in to comment.