Skip to content

Commit

Permalink
Remove OCDBT coordinator-related code.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627823270
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Apr 24, 2024
1 parent e8b144e commit de1a0dc
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 106 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
- `ocdbt_merge` option and unused `restore_with_serialized_types` option from
`PyTreeCheckpointHandler`.
- OCDBT coordinator code. These functions are no longer needed.

## [0.5.10] - 2024-04-22

Expand Down
4 changes: 1 addition & 3 deletions checkpoint/orbax/checkpoint/array_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(self, checkpoint_name: Optional[str] = None):
checkpoint_name = 'checkpoint'
self._checkpoint_name = checkpoint_name
self._aggregate_handler = aggregate_handlers.MsgpackHandler()
type_handlers.start_coordinator_server_and_create_context()

def _is_supported_type(self, item: ArrayType) -> bool:
return isinstance(item, (np.ndarray, jax.Array)) or utils.is_scalar(item)
Expand Down Expand Up @@ -93,8 +92,7 @@ async def async_save(
name=self._checkpoint_name,
path=directory / self._checkpoint_name,
parent_dir=directory,
is_ocdbt_checkpoint=True,
ocdbt_merge=False,
is_ocdbt_checkpoint=False,
)
type_handler = type_handlers.get_type_handler(type(item))
futures = await type_handler.serialize([item], [info], args=[save_args])
Expand Down
109 changes: 7 additions & 102 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from orbax.checkpoint import value_metadata
import tensorstore as ts


Scalar = Union[int, float, np.number]
Metadata = value_metadata.Metadata
NamedSharding = jax.sharding.NamedSharding
Expand All @@ -47,9 +46,7 @@
StringMetadata = value_metadata.StringMetadata
ShardingMetadata = sharding_metadata.ShardingMetadata
_OCDBT_MANIFEST_FILE = 'manifest.ocdbt'
_COORDINATOR_SETUP_TIMEOUT_SECS = 300
_OCDBT_TS_CONTEXT = None
_OCDBT_COORDINATOR_SERVER = None
_DEFAULT_OCDBT_TS_CONTEXT = ts.Context(
{
# Provide cache pool for B-tree nodes to avoid repeated reads.
Expand All @@ -73,84 +70,6 @@
ZARR_VER3 = 'zarr3'


def _get_coordinator_address_without_port(coordinator_address: str) -> str:
"""Returns JAX coordinator address stripped of port number."""
return coordinator_address.split(':')[0]


def create_coordinator_server_and_context() -> Tuple[None, None]:
# TODO(b/293331479) remove this once OCDBT is enabled by default.
warnings.warn('This function has been deprecated. Do not use.')
return (None, None)


def start_coordinator_server_and_create_context() -> None:
"""Start a OCDBT coordinator and create a Tensorstore context.
This function is only for Orbax internal use.
The following function starts a coordinator_server and update type handlers
with enable_ocdbt() defined.
The context and server will be stored as global variables in _OCDBT_TS_CONTEXT
and _OCDBT_COORDINATOR_SERVER. They will be preserved for the life of the
program. Succeeding calls to this function will not try to start the
coordinator server again.
For testing purpose, if one needs to restart the coordinator server, set
_OCDBT_TS_CONTEXT and _OCDBT_COORDINATOR_SERVER to None and call this function
again.
Returns:
None
"""
global _OCDBT_TS_CONTEXT, _OCDBT_COORDINATOR_SERVER

if _OCDBT_TS_CONTEXT is not None:
# OCDBT ts_context is already set, return
return

ts_context = {
# Provide cache pool for B-tree nodes to avoid repeated reads.
# 100MB limit.
'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
}

jax_global_state = jax._src.distributed.global_state # pylint: disable=protected-access
if (
jax_global_state.coordinator_address
and jax_global_state.num_processes > 1
):
ocdbt_address = _get_coordinator_address_without_port(
jax_global_state.coordinator_address
)

if jax_global_state.process_id == 0:
bind_address = f'{ocdbt_address}:0'
_OCDBT_COORDINATOR_SERVER = ts.ocdbt.DistributedCoordinatorServer({
'bind_addresses': [bind_address],
})
ocdbt_coordinator = f'{ocdbt_address}:{_OCDBT_COORDINATOR_SERVER.port}'
logging.info(
'Started OCDBT DistributedCoordinatorServer at: %s', ocdbt_coordinator
)
jax_global_state.client.key_value_set(
'ocdbt_coordinator', ocdbt_coordinator
)

ocdbt_address = jax_global_state.client.blocking_key_value_get(
'ocdbt_coordinator', _COORDINATOR_SETUP_TIMEOUT_SECS * 1000
)

# add ocdbt_coordinator spec into ts_context
ts_context['ocdbt_coordinator'] = {
'address': ocdbt_address,
}

_OCDBT_TS_CONTEXT = ts.Context(ts_context, parent=serialization.TS_CONTEXT)
logging.info('OCDBT is initialized successfully.')


async def _assert_parameter_files_exist(
param_dir: epath.Path, metadata_key: Optional[str], use_zarr3: bool = False
):
Expand Down Expand Up @@ -254,7 +173,6 @@ class ParamInfo:
skip_deserialize: Optional[bool] = None
byte_limiter: Optional[serialization._LimitInFlightBytes] = None # pylint: disable=protected-access
is_ocdbt_checkpoint: Optional[bool] = None
ocdbt_merge: Optional[bool] = True
use_zarr3: Optional[bool] = False
ocdbt_target_data_file_size: Optional[int] = None

Expand Down Expand Up @@ -800,28 +718,19 @@ def get_tensorstore_spec(

def get_process_index_for_subdir(
use_ocdbt: bool,
ocdbt_merge: bool,
) -> Optional[int]:
"""If OCDBT + merge feature is in use, returns a process index."""
if use_ocdbt and ocdbt_merge:
if use_ocdbt:
return jax.process_index()
else:
return None


def get_ts_context(use_ocdbt: bool, ocdbt_merge: bool = True) -> ts.Context:
def get_ts_context(use_ocdbt: bool) -> ts.Context:
"""Returns a shared global TensorStore Context instance to use."""
if not use_ocdbt:
return serialization.TS_CONTEXT

if ocdbt_merge:
return _DEFAULT_OCDBT_TS_CONTEXT
if _OCDBT_TS_CONTEXT is None:
raise ValueError(
'Coordinator-based TensorStore context should be configured'
' if OCDBT merging is not enabled.'
)
return _OCDBT_TS_CONTEXT
return _DEFAULT_OCDBT_TS_CONTEXT


def _get_cast_tspec_serialize(tspec, value, args):
Expand Down Expand Up @@ -996,9 +905,7 @@ async def serialize(
info,
value,
use_ocdbt=info.is_ocdbt_checkpoint,
process_index=get_process_index_for_subdir(
info.is_ocdbt_checkpoint, info.ocdbt_merge
),
process_index=get_process_index_for_subdir(info.is_ocdbt_checkpoint),
arg=arg,
)
tspec = _get_cast_tspec_serialize(tspec, value, arg)
Expand All @@ -1007,7 +914,7 @@ async def serialize(
logging.debug('infos = %s', info)
logging.debug('args = %s', arg)
if jax.process_index() == 0:
ts_context = get_ts_context(info.is_ocdbt_checkpoint, info.ocdbt_merge)
ts_context = get_ts_context(info.is_ocdbt_checkpoint)
# Open once to create metadata and allow the operation to happen
# asynchronously.
open_future = ts.open(
Expand Down Expand Up @@ -1354,13 +1261,11 @@ async def serialize(
info,
value,
use_ocdbt=info.is_ocdbt_checkpoint,
process_index=get_process_index_for_subdir(
info.is_ocdbt_checkpoint, info.ocdbt_merge
),
process_index=get_process_index_for_subdir(info.is_ocdbt_checkpoint),
arg=arg,
)
tspec = _get_cast_tspec_serialize(tspec, value, arg)
ts_context = get_ts_context(info.is_ocdbt_checkpoint, info.ocdbt_merge)
ts_context = get_ts_context(info.is_ocdbt_checkpoint)
if self._replica_id is None:
replica_id = value.addressable_shards[0].replica_id
else:
Expand Down
1 change: 0 additions & 1 deletion docs/api_reference/checkpoint.type_handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ StringHandler

OCDBT functions
------------------------
.. autofunction:: start_coordinator_server_and_create_context
.. autofunction:: is_ocdbt_checkpoint

TypeHandler registry
Expand Down

0 comments on commit de1a0dc

Please sign in to comment.