-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor and restructure constructs from
type_handlers.py
and `meta…
…data` PiperOrigin-RevId: 698544328
- Loading branch information
1 parent
71b1ba0
commit 2a39f3a
Showing
14 changed files
with
720 additions
and
592 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Handles empty values in the checkpoint PyTree.""" | ||
|
||
from typing import Any, Mapping | ||
from orbax.checkpoint._src.tree import utils as tree_utils | ||
|
||
RESTORE_TYPE_NONE = 'None' | ||
RESTORE_TYPE_DICT = 'Dict' | ||
RESTORE_TYPE_LIST = 'List' | ||
RESTORE_TYPE_TUPLE = 'Tuple' | ||
RESTORE_TYPE_UNKNOWN = 'Unknown' | ||
# TODO: b/365169723 - Handle empty NamedTuple. | ||
|
||
|
||
# TODO: b/365169723 - Handle empty NamedTuple. | ||
def is_supported_empty_value(value: Any) -> bool: | ||
"""Determines if the *empty* `value` is supported without custom TypeHandler.""" | ||
# Check isinstance first to avoid `not` checks on jax.Arrays (raises error). | ||
if tree_utils.isinstance_of_namedtuple(value): | ||
return False | ||
return ( | ||
isinstance(value, (dict, list, tuple, type(None), Mapping)) and not value | ||
) | ||
|
||
|
||
# TODO: b/365169723 - Handle empty NamedTuple. | ||
def get_empty_value_typestr(value: Any) -> str: | ||
"""Returns the typestr constant for the empty value.""" | ||
if not is_supported_empty_value(value): | ||
raise ValueError(f'{value} is not a supported empty type.') | ||
if isinstance(value, list): | ||
return RESTORE_TYPE_LIST | ||
if isinstance(value, tuple): | ||
return RESTORE_TYPE_TUPLE | ||
if isinstance(value, (dict, Mapping)): | ||
return RESTORE_TYPE_DICT | ||
if value is None: | ||
return RESTORE_TYPE_NONE | ||
raise ValueError(f'Unrecognized empty type: {value}.') | ||
|
||
|
||
# TODO: b/365169723 - Handle empty NamedTuple. | ||
def is_empty_typestr(typestr: str) -> bool: | ||
return ( | ||
typestr == RESTORE_TYPE_LIST | ||
or typestr == RESTORE_TYPE_TUPLE | ||
or typestr == RESTORE_TYPE_DICT | ||
or typestr == RESTORE_TYPE_NONE | ||
) | ||
|
||
|
||
# TODO: b/365169723 - Handle empty NamedTuple. | ||
def get_empty_value_from_typestr(typestr: str) -> Any: | ||
if typestr == RESTORE_TYPE_LIST: | ||
return [] | ||
if typestr == RESTORE_TYPE_TUPLE: | ||
return tuple() | ||
if typestr == RESTORE_TYPE_DICT: | ||
return {} | ||
if typestr == RESTORE_TYPE_NONE: | ||
return None | ||
raise ValueError(f'Unrecognized typestr: {typestr}.') |
44 changes: 44 additions & 0 deletions
44
checkpoint/orbax/checkpoint/_src/metadata/empty_values_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for empty values as leafs in the checkpoint tree.""" | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from orbax.checkpoint._src.metadata import empty_values | ||
from orbax.checkpoint._src.testing import test_tree_utils | ||
|
||
|
||
class EmptyValuesTest(parameterized.TestCase): | ||
|
||
@parameterized.parameters( | ||
(1, False), | ||
(dict(), True), | ||
({}, True), | ||
({"a": {}}, False), | ||
([], True), | ||
([[]], False), | ||
(None, True), | ||
((1, 2), False), | ||
(test_tree_utils.EmptyNamedTuple(), False), | ||
(test_tree_utils.MuNu(mu=None, nu=None), False), | ||
(test_tree_utils.NamedTupleWithNestedAttributes(), False), | ||
(test_tree_utils.NamedTupleWithNestedAttributes(nested_dict={}), False), | ||
) | ||
def test_is_supported_empty_value(self, value, expected): | ||
self.assertEqual(expected, empty_values.is_supported_empty_value(value)) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
37 changes: 37 additions & 0 deletions
37
checkpoint/orbax/checkpoint/_src/metadata/pytree_metadata_options.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Basic constructs for PyTree Metadata handling.""" | ||
|
||
import dataclasses | ||
|
||
|
||
@dataclasses.dataclass(kw_only=True) | ||
class PyTreeMetadataOptions: | ||
"""Options for managing PyTree metadata. | ||
Attributes: | ||
support_rich_types: [Experimental feature: subject to change without | ||
notice.] If True, supports NamedTuple and Tuple node types in the | ||
metadata. Otherwise, a NamedTuple node is converted to dict and Tuple node | ||
to list. | ||
""" | ||
|
||
# TODO: b/365169723 - Support different namedtuple ser/deser strategies. | ||
|
||
support_rich_types: bool | ||
|
||
|
||
# Global default options. | ||
PYTREE_METADATA_OPTIONS = PyTreeMetadataOptions(support_rich_types=False) |
Oops, something went wrong.