Skip to content

Commit

Permalink
Wrap tree.utils.serialize_tree -> metadata.tree.serialize_tree to…
Browse files Browse the repository at this point in the history
… allow `NamedTuple` and `Tuple` node types in PyTree metadata.

PiperOrigin-RevId: 697719927
  • Loading branch information
niketkumar authored and Orbax Authors committed Nov 18, 2024
1 parent ffa7e69 commit 0e737f0
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 50 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ that contain utilities to perform de/serialization for `RootMetadata` and
- `ReplicaSlice`/`ReplicaSlices` construct to facilitate saving replicated
arrays.
- Added restoring with custom jax.experimental.layout.Layout support
- Add experimental `PyTreeMetadataOptions` to manage rich types in pytree
checkpoint metadata.

### Changed
- Refactor metadata/tree_test.py and move common test types to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ def __init__(
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
enable_post_merge_validation: bool = True,
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
):
"""Creates BasePyTreeCheckpointHandler.
Expand All @@ -308,6 +311,7 @@ def __init__(
enable_descriptor: If True, logs a Descriptor proto that contains lineage
enable_post_merge_validation: If True, enables validation of the
parameters after the finalize step.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
"""
self._save_concurrent_bytes = save_concurrent_bytes
self._restore_concurrent_bytes = restore_concurrent_bytes
Expand All @@ -316,6 +320,7 @@ def __init__(
self._primary_host = multiprocessing_options.primary_host
self._type_handler_registry = type_handler_registry
self._enable_post_merge_validation = enable_post_merge_validation
self._pytree_metadata_options = pytree_metadata_options


jax.monitoring.record_event(
Expand Down Expand Up @@ -686,8 +691,8 @@ class TrainState:
restore_args = _fill_missing_save_or_restore_args(
item, restore_args, mode='restore'
)
restore_args = tree_utils.serialize_tree(
restore_args, keep_empty_nodes=True
restore_args = tree_metadata.serialize_tree(
restore_args, self._pytree_metadata_options
)
param_infos = self._get_param_infos(
item=value_metadata_tree,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def _get_restore_parameters(
param_names: Optional[PyTree],
transforms: Optional[PyTree],
restore_args: Optional[PyTree],
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions,
byte_limiter: Optional[LimitInFlightBytes] = None,
transforms_default_to_original: bool = True,
use_zarr3: bool = False,
Expand Down Expand Up @@ -221,6 +222,7 @@ def _get_restore_parameters(
restore_args: User-provided restoration arguments. If None, they were not
provided. Otherwise, the tree has the same structure as the desired output
tree.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
byte_limiter: A LimitInFlightBytes object.
transforms_default_to_original: See transform_utils.apply_transformations.
use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2
Expand Down Expand Up @@ -268,9 +270,8 @@ def _get_param_info(
if transforms is None:
for key, meta in flat_structure.items():
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
restore_args = tree_utils.serialize_tree(
restore_args,
keep_empty_nodes=True,
restore_args = tree_metadata.serialize_tree(
restore_args, pytree_metadata_options
)
else:
if item is None:
Expand Down Expand Up @@ -465,6 +466,9 @@ def __init__(
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
handler_impl: Optional[BasePyTreeCheckpointHandler] = None,
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
):
"""Creates PyTreeCheckpointHandler.
Expand All @@ -485,9 +489,11 @@ def __init__(
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used.
handler_impl: Allows overriding the internal implementation.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
"""
self._aggregate_handler = MsgpackHandler(
primary_host=multiprocessing_options.primary_host
primary_host=multiprocessing_options.primary_host,
pytree_metadata_options=pytree_metadata_options,
)
if aggregate_filename is None:
aggregate_filename = _CHECKPOINT_FILE
Expand All @@ -505,7 +511,9 @@ def __init__(
use_zarr3=use_zarr3,
multiprocessing_options=multiprocessing_options,
type_handler_registry=type_handler_registry,
pytree_metadata_options=pytree_metadata_options,
)
self._pytree_metadata_options = pytree_metadata_options

async def async_save(
self,
Expand Down Expand Up @@ -784,6 +792,7 @@ class TrainState:
self._handler_impl.get_param_names(structure),
transforms,
restore_args,
self._pytree_metadata_options,
transforms_default_to_original=transforms_default_to_original,
use_zarr3=use_zarr3_metadata
if use_zarr3_metadata is not None
Expand Down
48 changes: 48 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,54 @@
KeyPath = tuple[KeyEntry, ...]


@dataclasses.dataclass(kw_only=True)
class PyTreeMetadataOptions:
"""Options for managing PyTree metadata.
Attributes:
support_rich_types: If True, supports NamedTuple and Tuple types in the
metadata.
"""

support_rich_types: bool


PYTREE_METADATA_OPTIONS = PyTreeMetadataOptions(support_rich_types=False)


def serialize_tree(
tree: PyTree, pytree_metadata_options: PyTreeMetadataOptions
) -> PyTree:
"""Transforms a PyTree to a serializable format.
IMPORTANT: If `pytree_metadata_options.support_rich_types` is false, the
returned tree replaces tuple container nodes with list nodes.
IMPORTANT: If `pytree_metadata_options.support_rich_types` is false, the
returned tree replaces NamedTuple container nodes with dict
nodes.
If `pytree_metadata_options.support_rich_types` is true, then the returned
tree is the same as the input tree retaining empty nodes as leafs.
Args:
tree: The tree to serialize.
pytree_metadata_options: `PyTreeMetadataOptions` for managing PyTree
metadata.
Returns:
The serialized PyTree.
"""
if pytree_metadata_options.support_rich_types:
return jax.tree_util.tree_map(
lambda x: x,
tree,
is_leaf=tree_utils.is_empty_or_leaf,
)

return tree_utils.serialize_tree(tree, keep_empty_nodes=True)


class KeyType(enum.Enum):
"""Enum representing PyTree key type."""

Expand Down
18 changes: 17 additions & 1 deletion checkpoint/orbax/checkpoint/_src/tree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ def _extend_list(ls, idx, nextvalue):
def from_flattened_with_keypath(
flat_with_keys: list[tuple[tuple[Any, ...], Any]],
) -> PyTree:
"""Reconstructs a tree given the a flat dict with keypaths."""
"""Returns a tree for the given list of (KeyPath, value) pairs.
IMPORTANT: The returned tree replaces tuple container nodes with list nodes,
even though the input KeyPath had originated from a tuple.
IMPORTANT: The returned tree replaces NamedTuple container nodes with dict
nodes, even though the input KeyPath had originated from a NamedTuple.
Args:
flat_with_keys: A list of pair of Keypath and values.
"""
if not flat_with_keys:
raise ValueError(
'Unable to uniquely reconstruct tree from empty flattened list '
Expand Down Expand Up @@ -152,6 +162,12 @@ def from_flattened_with_keypath(
def serialize_tree(tree: PyTree, keep_empty_nodes: bool = False) -> PyTree:
"""Transforms a PyTree to a serializable format.
IMPORTANT: The returned tree replaces tuple container nodes with list nodes.
IMPORTANT: The returned tree replaces NamedTuple container nodes with dict
nodes.
Args:
tree: The tree to serialize, if tree is empty and keep_empty_nodes is False,
an error is raised as there is no valid representation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,47 +30,13 @@ class EmptyNamedTuple(NamedTuple):
pass


class UtilsTest(parameterized.TestCase):
# TODO: b/365169723 - Add tests: PyTreeMetadataOptions.support_rich_types=True.
class SerializeTreeTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.directory = epath.Path(
self.create_tempdir(name='checkpointing_test').full_path
)

@parameterized.parameters(
({'a': 1, 'b': {'c': {}, 'd': 2}}, {('a',): 1, ('b', 'd'): 2}),
({'x': ['foo', 'bar']}, {('x', '0'): 'foo', ('x', '1'): 'bar'}),
)
def test_to_flat_dict(self, tree, expected):
self.assertDictEqual(expected, tree_utils.to_flat_dict(tree))

@parameterized.parameters(
({'a': 1, 'b': {'d': 2}}, {('a',): 1, ('b', 'd'): 2}),
({'x': ['foo', 'bar']}, {('x', '0'): 'foo', ('x', '1'): 'bar'}),
({'a': 1, 'b': 2}, {('b',): 2, ('a',): 1}),
)
def test_from_flat_dict(self, expected, flat_dict):
empty = jax.tree.map(lambda _: 0, expected)
self.assertDictEqual(
expected, tree_utils.from_flat_dict(flat_dict, target=empty)
)

@parameterized.parameters(
({'a': 1, 'b': {'d': 2}}, {('a',): 1, ('b', 'd'): 2}),
({'a': 1, 'b': 2}, {('b',): 2, ('a',): 1}),
)
def test_from_flat_dict_without_target(self, expected, flat_dict):
self.assertDictEqual(expected, tree_utils.from_flat_dict(flat_dict))

@parameterized.parameters(
({'a': 1, 'b': {'d': 2}}, {'a': 1, 'b/d': 2}),
({'a': 1, 'b': 2}, {'b': 2, 'a': 1}),
({'a': {'b': {'c': 1}}}, {'a/b/c': 1}),
)
def test_from_flat_dict_with_sep(self, expected, flat_dict):
self.assertDictEqual(
expected, tree_utils.from_flat_dict(flat_dict, sep='/')
self.create_tempdir(name='serialize_tree_test').full_path
)

def test_serialize(self):
Expand Down Expand Up @@ -164,6 +130,50 @@ class Foo(flax.struct.PyTreeNode):
}
self.assertDictEqual(expected, serialized)


class UtilsTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.directory = epath.Path(
self.create_tempdir(name='checkpointing_test').full_path
)

@parameterized.parameters(
({'a': 1, 'b': {'c': {}, 'd': 2}}, {('a',): 1, ('b', 'd'): 2}),
({'x': ['foo', 'bar']}, {('x', '0'): 'foo', ('x', '1'): 'bar'}),
)
def test_to_flat_dict(self, tree, expected):
self.assertDictEqual(expected, tree_utils.to_flat_dict(tree))

@parameterized.parameters(
({'a': 1, 'b': {'d': 2}}, {('a',): 1, ('b', 'd'): 2}),
({'x': ['foo', 'bar']}, {('x', '0'): 'foo', ('x', '1'): 'bar'}),
({'a': 1, 'b': 2}, {('b',): 2, ('a',): 1}),
)
def test_from_flat_dict(self, expected, flat_dict):
empty = jax.tree.map(lambda _: 0, expected)
self.assertDictEqual(
expected, tree_utils.from_flat_dict(flat_dict, target=empty)
)

@parameterized.parameters(
({'a': 1, 'b': {'d': 2}}, {('a',): 1, ('b', 'd'): 2}),
({'a': 1, 'b': 2}, {('b',): 2, ('a',): 1}),
)
def test_from_flat_dict_without_target(self, expected, flat_dict):
self.assertDictEqual(expected, tree_utils.from_flat_dict(flat_dict))

@parameterized.parameters(
({'a': 1, 'b': {'d': 2}}, {'a': 1, 'b/d': 2}),
({'a': 1, 'b': 2}, {'b': 2, 'a': 1}),
({'a': {'b': {'c': 1}}}, {'a/b/c': 1}),
)
def test_from_flat_dict_with_sep(self, expected, flat_dict):
self.assertDictEqual(
expected, tree_utils.from_flat_dict(flat_dict, sep='/')
)

@parameterized.parameters(
(1, True),
(dict(), True),
Expand Down
26 changes: 20 additions & 6 deletions checkpoint/orbax/checkpoint/aggregate_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from orbax.checkpoint import future as orbax_future
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import utils
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint._src.metadata import tree as tree_metadata

PyTree = Any

Expand Down Expand Up @@ -58,12 +58,18 @@ def close(self):


class MsgpackHandler(AggregateHandler):
"""An implementation of AggregateHandler that uses msgpack to store the tree.
"""

def __init__(self, primary_host: Optional[int] = 0):
"""An implementation of AggregateHandler that uses msgpack to store the tree."""

def __init__(
self,
primary_host: Optional[int] = 0,
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
):
self._executor = futures.ThreadPoolExecutor(max_workers=1)
self._primary_host = primary_host
self._pytree_metadata_options = pytree_metadata_options

async def serialize(
self, path: epath.Path, item: PyTree
Expand All @@ -72,7 +78,15 @@ async def serialize(

def _serialize_fn(x):
if utils.is_primary_host(self._primary_host):
serializable_dict = tree_utils.serialize_tree(x, keep_empty_nodes=True)
if self._pytree_metadata_options.support_rich_types:
raise NotImplementedError(
'Orbax does not support rich typed metadata in legacy msgpack'
' checkpoint format. Please set'
' PyTreeMetadataOptions.support_rich_types to False.'
)
serializable_dict = tree_metadata.serialize_tree(
x, self._pytree_metadata_options
)
msgpack = msgpack_utils.msgpack_serialize(serializable_dict)
# Explicit "copy" phase is not needed because msgpack only contains
# basic types and numpy arrays.
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

deserialize_tree = tree_utils.deserialize_tree
from_flat_dict = tree_utils.from_flat_dict
# TODO: b/365169723 - Remove public access to this function.
from_flattened_with_keypath = tree_utils.from_flattened_with_keypath
serialize_tree = tree_utils.serialize_tree
to_flat_dict = tree_utils.to_flat_dict
Expand Down

0 comments on commit 0e737f0

Please sign in to comment.