Skip to content

Commit

Permalink
Move multihost package to _src/multihost. Important symbols are still…
Browse files Browse the repository at this point in the history
… exported publicly.

PiperOrigin-RevId: 687313527
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Oct 18, 2024
1 parent 8cf8779 commit 178a654
Show file tree
Hide file tree
Showing 27 changed files with 119 additions and 122 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ customize this on a per-array level.

### Changed
- Rename `CheckpointManager._single_item` to `CheckpointManager._default_item`.
- Move `multihost` implementations to `_src`. Commonly used symbols are still exported in the same way.
- Use `Fragments` for serialization.
- Set `AsyncOptions.timeout_secs` default value to 10 minutes.

Expand Down
20 changes: 8 additions & 12 deletions checkpoint/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Defines exported symbols for the namespace package `orbax.checkpoint`."""

# pylint: disable=g-importing-member

import contextlib
import functools

Expand All @@ -23,7 +25,6 @@
from orbax.checkpoint import logging
from orbax.checkpoint import metadata
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import options
from orbax.checkpoint import path
from orbax.checkpoint import serialization
Expand All @@ -32,33 +33,28 @@
from orbax.checkpoint import tree
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint.path import step

# pylint: disable=g-importing-member, g-bad-import-order
from orbax.checkpoint import version
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import SaveArgs
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.abstract_checkpoint_manager import AbstractCheckpointManager
from orbax.checkpoint.abstract_checkpointer import AbstractCheckpointer

from orbax.checkpoint.async_checkpointer import AsyncCheckpointer
from orbax.checkpoint.checkpoint_manager import AsyncOptions
from orbax.checkpoint.checkpoint_manager import CheckpointManager
from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions
from orbax.checkpoint.checkpointer import Checkpointer
from orbax.checkpoint.future import Future
from orbax.checkpoint.handlers import *
from orbax.checkpoint.path import step
from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer
from orbax.checkpoint.standard_checkpointer import StandardCheckpointer
from orbax.checkpoint.transform_utils import apply_transformations
from orbax.checkpoint.transform_utils import merge_trees
from orbax.checkpoint.transform_utils import RestoreTransform
from orbax.checkpoint.transform_utils import Transform
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import SaveArgs

from orbax.checkpoint.handlers import *
# pylint: enable=g-importing-member, g-bad-import-order

# pylint: disable=g-bad-import-order, g-import-not-at-top
from orbax.checkpoint import version
# A new PyPI release will be pushed every time `__version__` is increased.
__version__ = version.__version__
del version
# pylint: enable=g-bad-import-order, g-import-not-at-top
20 changes: 8 additions & 12 deletions checkpoint/orbax/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Defines exported symbols for the namespace package `orbax.checkpoint`."""

# pylint: disable=g-importing-member

import contextlib
import functools

Expand All @@ -23,7 +25,6 @@
from orbax.checkpoint import logging
from orbax.checkpoint import metadata
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import options
from orbax.checkpoint import path
from orbax.checkpoint import serialization
Expand All @@ -32,33 +33,28 @@
from orbax.checkpoint import tree
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint.path import step

# pylint: disable=g-importing-member, g-bad-import-order
from orbax.checkpoint import version
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import SaveArgs
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.abstract_checkpoint_manager import AbstractCheckpointManager
from orbax.checkpoint.abstract_checkpointer import AbstractCheckpointer

from orbax.checkpoint.async_checkpointer import AsyncCheckpointer
from orbax.checkpoint.checkpoint_manager import AsyncOptions
from orbax.checkpoint.checkpoint_manager import CheckpointManager
from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions
from orbax.checkpoint.checkpointer import Checkpointer
from orbax.checkpoint.future import Future
from orbax.checkpoint.handlers import *
from orbax.checkpoint.path import step
from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer
from orbax.checkpoint.standard_checkpointer import StandardCheckpointer
from orbax.checkpoint.transform_utils import apply_transformations
from orbax.checkpoint.transform_utils import merge_trees
from orbax.checkpoint.transform_utils import RestoreTransform
from orbax.checkpoint.transform_utils import Transform
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import SaveArgs

from orbax.checkpoint.handlers import *
# pylint: enable=g-importing-member, g-bad-import-order

# pylint: disable=g-bad-import-order, g-import-not-at-top
from orbax.checkpoint import version
# A new PyPI release will be pushed every time `__version__` is increased.
__version__ = version.__version__
del version
# pylint: enable=g-bad-import-order, g-import-not-at-top
6 changes: 6 additions & 0 deletions checkpoint/orbax/checkpoint/_src/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# `_src` package

This package contains (or is intended to contain in the future) most of the
actual orbax-checkpoint implementations. Code from this directory should not be
directly relied upon by outside users. Instead, depend on symbols exported by
`orbax.checkpoint` or by one of its public subpackages.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@
import jax
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import serialization
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.metadata import tree as tree_metadata
from orbax.checkpoint.path import format_utils
import tensorstore as ts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for CompositeHandler."""

from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
from jax import numpy as jnp
from orbax.checkpoint import args as args_lib
from orbax.checkpoint import multihost
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.handlers import handler_registration
from orbax.checkpoint._src.handlers import json_checkpoint_handler
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
from orbax.checkpoint._src.handlers import standard_checkpoint_handler
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.metadata import value as value_metadata
from orbax.checkpoint.path import step

Expand Down Expand Up @@ -591,9 +589,7 @@ def test_no_restore_args_handler_registry(self):
CompositeArgs(),
)
with self.assertRaisesRegex(KeyError, 'could not be restored'):
restore_handler_without_registry.restore(
self.directory
)
restore_handler_without_registry.restore(self.directory)

restore_handler_with_registry = CompositeCheckpointHandler(
handler_registry=handler_registry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from jax import numpy as jnp
import numpy as np
import optax
from orbax.checkpoint import multihost
from orbax.checkpoint import test_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint._src.handlers import standard_checkpoint_handler
from orbax.checkpoint._src.multihost import multihost

PyTree = Any
SaveArgs = type_handlers.SaveArgs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Orbax utils related to multihost functionality."""
"""Orbax utils related to multihost_utils functionality."""

import threading
import time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multislice utils."""
"""Multislice utilities."""

import functools
from typing import Any, Optional, Set, Tuple, Union
Expand All @@ -21,7 +21,7 @@
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint.multihost import utils
from orbax.checkpoint._src.multihost import multihost

PyTree = Any

Expand All @@ -42,15 +42,15 @@ def process_slice_id(
device_slice = slice_devices(
global_mesh, replica_id=slice_id, replica_axis_index=replica_axis_index
)
if process_index in utils.unique_processes_from_devices(device_slice):
if process_index in multihost.unique_processes_from_devices(device_slice):
return slice_id
return -1


def _process_in_device_slice(
process_index: int, device_slice: np.ndarray
) -> bool:
return process_index in utils.unique_processes_from_devices(device_slice)
return process_index in multihost.unique_processes_from_devices(device_slice)


def slice_devices(
Expand All @@ -72,7 +72,7 @@ def local_slice_devices(
"""Get devices in the host-local slice."""
for replica_id in range(global_mesh.devices.shape[replica_axis_index]):
if in_slice(
utils.process_index(),
multihost.process_index(),
global_mesh,
replica_id=replica_id,
replica_axis_index=replica_axis_index,
Expand All @@ -83,7 +83,7 @@ def local_slice_devices(
replica_axis_index=replica_axis_index,
)
raise ValueError(
f'process_index {utils.process_index()} does not exist in provided'
f'process_index {multihost.process_index()} does not exist in provided'
' `global_mesh`'
)

Expand All @@ -100,7 +100,7 @@ def primary_process_in_slice(
replica_axis_index=replica_axis_index,
replica_id=replica_id,
)
processes = utils.unique_processes_from_devices(device_slice)
processes = multihost.unique_processes_from_devices(device_slice)
return next(iter(processes))


Expand Down
5 changes: 3 additions & 2 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpointer
from orbax.checkpoint import future as future_lib
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.multihost import counters
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.metadata import checkpoint
from orbax.checkpoint.path import async_utils
from orbax.checkpoint.path import atomicity
Expand Down Expand Up @@ -316,7 +317,7 @@ def __init__(
)

def _unique_operation_id(self) -> str:
return multihost.counters.async_save_counter()
return counters.async_save_counter()

async def _save(
self, directory: epath.PathLike, *args, force: bool = False, **kwargs
Expand Down
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@
from orbax.checkpoint import async_checkpointer
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpointer as checkpointer_lib
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.handlers import handler_registration
from orbax.checkpoint._src.handlers import json_checkpoint_handler
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import utils as path_utils
from orbax.checkpoint.logging import abstract_logger
from orbax.checkpoint.logging import standard_logger
from orbax.checkpoint.logging import step_statistics
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import multihost
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.metadata import value as value_metadata
from orbax.checkpoint.path import step as step_lib

Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
import jax
from orbax.checkpoint import abstract_checkpointer
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint.metadata import checkpoint
from orbax.checkpoint.path import atomicity
from typing_extensions import Self # for Python version < 3.11
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,16 @@
from orbax.checkpoint import abstract_checkpoint_manager
from orbax.checkpoint import args as args_lib
from orbax.checkpoint import checkpoint_manager
from orbax.checkpoint import multihost
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.multihost import counters
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.multihost import multislice
from orbax.checkpoint.experimental.emergency import multihost as emergency_multihost
from orbax.checkpoint.logging import abstract_logger
from orbax.checkpoint.logging import standard_logger
from orbax.checkpoint.logging import step_statistics
from orbax.checkpoint.multihost import multislice
from orbax.checkpoint.path import step as step_lib
from typing_extensions import Self # for Python version < 3.11

Expand All @@ -60,7 +61,7 @@
CheckpointHandler = checkpoint_manager.CheckpointHandler
P = jax.sharding.PartitionSpec
PyTreeCheckpointHandler = pytree_checkpoint_handler.PyTreeCheckpointHandler
unique_barrier_key = multihost.utils._unique_barrier_key # pylint: disable=protected-access
unique_barrier_key = multihost._unique_barrier_key # pylint: disable=protected-access

_PROCESS_METADATA_FOLDER = 'process_metadata'
_PROCESS_METADATA_FILE_NAME = 'process_metadata.json'
Expand Down Expand Up @@ -275,11 +276,11 @@ class _BarrierIdentifier(enum.Enum):

def get_counter(self) -> str:
if self.name == self.GLOBAL_MAX.name:
return multihost.counters.global_max_broadcast_counter()
return counters.global_max_broadcast_counter()
elif self.name == self.LOCAL_ALL_STEPS.name:
return multihost.counters.local_all_steps_broadcast_counter()
return counters.local_all_steps_broadcast_counter()
elif self.name == self.FIND_COMPLETE_SLICE.name:
return multihost.counters.find_complete_slice_broadcast_counter()
return counters.find_complete_slice_broadcast_counter()
else:
raise ValueError(f'Unknown barrier identifier: {self.name}')

Expand Down Expand Up @@ -361,7 +362,7 @@ def _process_local_to_global(
barrier_name = (
f'{barrier_id.name}_{slice_id}' if slice_id else barrier_id.name
)
client = multihost.utils._get_jax_distributed_client() # pylint: disable=protected-access
client = multihost._get_jax_distributed_client() # pylint: disable=protected-access
broadcast_dir_key = f'broadcast_{barrier_name}/{barrier_id.get_counter()}/'
broadcast_dir_key = unique_barrier_key(broadcast_dir_key) + '/'
broadcast_key = broadcast_dir_key + str(multihost.process_index())
Expand Down Expand Up @@ -1054,10 +1055,8 @@ def _get_single_slice_sharding(
restore_single_slice_shardings,
self._abstract_state,
)
restore_directory = (
self._local_checkpoint_manager._get_read_step_directory( # pylint: disable=protected-access
step, epath.Path(directory or self._local_directory)
)
restore_directory = self._local_checkpoint_manager._get_read_step_directory( # pylint: disable=protected-access
step, epath.Path(directory or self._local_directory)
)
step_stats.directory = str(restore_directory)

Expand Down
Loading

0 comments on commit 178a654

Please sign in to comment.