implement all_steps function for experimental emergency CheckpointManager. At the same time, only create sub CheckpointManagers on appropriate slices.

PiperOrigin-RevId: 626585449
Orbax Authors committed Apr 20, 2024
1 parent 4b50229 commit f06df1f
Showing 2 changed files with 181 additions and 55 deletions.
234 changes: 180 additions & 54 deletions checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
@@ -20,14 +20,16 @@

import collections
import dataclasses
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, Set
from etils import epath
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import abstract_checkpoint_manager
from orbax.checkpoint import args as args_lib
from orbax.checkpoint import checkpoint_manager
from orbax.checkpoint import multihost
from orbax.checkpoint import utils
from orbax.checkpoint.path import step as step_lib

@@ -124,8 +126,13 @@ def _in_primary_slice(
  return _in_slice(process_index, primary_slice)


-def _local_slice_devices(global_mesh: jax.sharding.Mesh) -> np.ndarray:
-  for device_slice in global_mesh.devices:
+def _unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
+  pid = np.vectorize(lambda d: d.process_index)
+  return set(pid(device_array).flat)
+
+
+def _local_slice_devices(devices_array: np.ndarray) -> np.ndarray:
+  for device_slice in devices_array:
    if _in_slice(jax.process_index(), device_slice):
      return device_slice
  raise ValueError(
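
For illustration, here is a minimal sketch (not part of the commit) of what _unique_processes_from_devices computes. FakeDevice is a hypothetical stand-in for jax.Device, which would need a live backend; np.vectorize maps process_index over a device array of any shape:

import numpy as np
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeDevice:  # hypothetical stand-in for jax.Device
  process_index: int

# A (2, 4) device array: two slices, each spanning two processes.
devices = np.array([
    [FakeDevice(0), FakeDevice(0), FakeDevice(1), FakeDevice(1)],
    [FakeDevice(2), FakeDevice(2), FakeDevice(3), FakeDevice(3)],
])

pid = np.vectorize(lambda d: d.process_index)
print(set(pid(devices).flat))     # {0, 1, 2, 3} -> all active processes
print(set(pid(devices[0]).flat))  # {0, 1} -> processes backing slice 0
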
@@ -138,13 +145,14 @@ def _pad_steps(steps, target):
  return steps + [-1] * (target - len(steps))


-class LocalCheckpointManager(checkpoint_manager.CheckpointManager):
+class _LocalCheckpointManager(checkpoint_manager.CheckpointManager):
"""A checkpoint manager that checkpoints to local storage.
Attributes:
global_mesh: a Mesh object representing the global mesh configuration,
importantly the first axis of the global_mesh is assumed to be the
direction of device slices across which the Data Parallelism is happening.
device_array: an ndarray representing all the devices running
LocalCheckpointManager in the same global jax Mesh, importantly the first
axis of the device_array is assumed to be the direction of device slices
across which the Data Parallelism is happening.
"""

# TODO: b/330585086 - Allow configuration of global mesh describing slices.
@@ -156,7 +164,7 @@ def __init__(
# to evaluate whether arbitrary items can be a good fit for local
# checkpointing, given restore+broadcast requirements.
state_handler: CheckpointHandler,
-      global_mesh: jax.sharding.Mesh,
+      device_array: np.ndarray,
*,
options: Optional[CheckpointManagerOptions] = None,
metadata: Optional[dict[str, Any]] = None,
@@ -171,11 +179,14 @@ def __init__(
cleanup_tmp_directories=options.cleanup_tmp_directories,
async_options=options.async_options,
multiprocessing_options=checkpoint_manager.MultiprocessingOptions(
-            primary_host=None
+            primary_host=None,
+            active_processes=_unique_processes_from_devices(device_array),
),
+        # TODO: b/331426277 - remove async false after barrier is done.
+        enable_async_checkpointing=False,
)
self._options = local_options
-    self._global_mesh = global_mesh
+    self._device_array = device_array

super().__init__(
directory,
@@ -216,8 +227,7 @@ def _common_steps_global(self, steps: Sequence[int]) -> np.ndarray:
Args:
steps: a list of steps known to all hosts on a slice
"""
-    devices = self._global_mesh.devices
-    unioned_steps = self._global_list_union(steps, devices)
+    unioned_steps = self._global_list_union(steps, self._device_array)

return np.asarray(list(set(unioned_steps)))

@@ -236,7 +246,7 @@ def _common_steps_within_slice(self, steps: Sequence[int]) -> np.ndarray:
steps: a list of known steps on host
"""

-    devices = _local_slice_devices(self._global_mesh)
+    devices = _local_slice_devices(self._device_array)
slice_process_count = devices.size // jax.local_device_count()
unioned_steps = self._global_list_union(steps, devices)

@@ -318,31 +328,58 @@ def __init__(
):
# TODO: b/330585086 - Fully support options.
options = options or CheckpointManagerOptions()
-    self._global_mesh = global_mesh
+    self._device_array = global_mesh.devices

-    self._local_checkpoint_manager = LocalCheckpointManager(
-        local_directory,
-        local_state_handler,
-        global_mesh=global_mesh,
-        options=options,
-        metadata=metadata,
-    )
-    # TODO: b/330585086 - Build options for persistent CheckpointManager.
-    persistent_options = checkpoint_manager.CheckpointManagerOptions(
-        save_interval_steps=options.persistent.save_interval_steps,
-        max_to_keep=options.persistent.max_to_keep,
-        step_name_format=options.step_name_format,
-        create=options.create,
-        cleanup_tmp_directories=options.cleanup_tmp_directories,
-        async_options=options.async_options,
-    )
-    self._persistent_checkpoint_manager = checkpoint_manager.CheckpointManager(
-        persistent_directory,
-        options=persistent_options,
-        metadata=metadata,
-        item_handlers=persistent_state_handler,
-    )
+    # TODO: b/330585086 - Use the appropriate MultiprocessingOptions.
+    self._persistent_primary_host = global_mesh.devices[0].flat[0].process_index
+    self._local_primary_host = (
+        global_mesh.devices[1].flat[0].process_index
+        if global_mesh.devices.shape[0] > 1
+        else None
+    )
+    if self._local_primary_host is None:
+      raise AssertionError(
+          'to use this CheckpointManager, at least 3 data-parallel slices are'
+          ' needed.'
+      )
+
+    self._in_primary_slice = _in_primary_slice(jax.process_index(), global_mesh)
+    self._persistent_max_to_keep = options.persistent.max_to_keep
+    self._local_max_to_keep = options.local.max_to_keep
+
+    if self._in_primary_slice:
+      # TODO: b/330585086 - Build options for persistent CheckpointManager.
+      persistent_options = checkpoint_manager.CheckpointManagerOptions(
+          save_interval_steps=options.persistent.save_interval_steps,
+          max_to_keep=self._persistent_max_to_keep,
+          step_name_format=options.step_name_format,
+          create=options.create,
+          cleanup_tmp_directories=options.cleanup_tmp_directories,
+          async_options=options.async_options,
+          multiprocessing_options=checkpoint_manager.MultiprocessingOptions(
+              primary_host=self._persistent_primary_host,
+              active_processes=_unique_processes_from_devices(
+                  global_mesh.devices[0]
+              ),
+          ),
+          # TODO: b/331426277 - remove async false after barrier is done.
+          enable_async_checkpointing=False,
+      )
+      self._persistent_checkpoint_manager = (
+          checkpoint_manager.CheckpointManager(
+              persistent_directory,
+              options=persistent_options,
+              metadata=metadata,
+              item_handlers=persistent_state_handler,
+          )
+      )
+    else:
+      self._local_checkpoint_manager = _LocalCheckpointManager(
+          local_directory,
+          local_state_handler,
+          device_array=global_mesh.devices[1:],
+          options=options,
+          metadata=metadata,
+      )

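A sketch of how the constructor above carves up the mesh, reusing the hypothetical FakeDevice stand-in: slice 0 (the first row of global_mesh.devices) backs the persistent CheckpointManager, the remaining slices back the _LocalCheckpointManager, and each sub-manager's primary host is the process owning the first device of its slice:

import numpy as np
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeDevice:  # hypothetical stand-in for jax.Device
  process_index: int

devices = np.array([  # shape (3, 2): three data-parallel slices
    [FakeDevice(0), FakeDevice(1)],
    [FakeDevice(2), FakeDevice(3)],
    [FakeDevice(4), FakeDevice(5)],
])

persistent_primary_host = devices[0].flat[0].process_index
local_primary_host = (
    devices[1].flat[0].process_index if devices.shape[0] > 1 else None
)
local_device_array = devices[1:]  # slices 1..n feed _LocalCheckpointManager
print(persistent_primary_host, local_primary_host, local_device_array.shape)
# 0 2 (2, 2)
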
@property
def directory(self) -> epath.Path:
Expand All @@ -360,8 +397,48 @@ def all_steps(self, read: bool = False) -> Sequence[int]:
Returns:
A sequence of steps (integers)
"""
-    # TODO: b/330585086 - Implement.
-    raise NotImplementedError('Implement: b/330585086.')
+    local_steps = [-1] * self._local_max_to_keep
+    persistent_steps = [-1] * self._persistent_max_to_keep
+    if self._in_primary_slice:
+      persistent_steps = list(
+          self._persistent_checkpoint_manager.all_steps(read)
+      )
+      if len(persistent_steps) > self._persistent_max_to_keep:
+        # TODO: b/330585086 - for now we assume that
+        # persistent_checkpoint_manager.all_steps returns an array with length
+        # smaller than max_to_keep
+        raise AssertionError(
+            f'persistent_step on host {jax.process_index()} exceeded'
+            f' `max_to_keep` {self._persistent_max_to_keep}'
+        )
+      persistent_steps = _pad_steps(
+          persistent_steps, self._persistent_max_to_keep
+      )
+    else:
+      local_steps = _pad_steps(
+          list(self._local_checkpoint_manager.all_steps(read)),
+          self._local_max_to_keep,
+      )
+
+    local_steps = np.asarray(
+        multihost.broadcast_one_to_all(
+            local_steps,
+            is_source=jax.process_index() == self._local_primary_host,
+        )
+    )
+
+    persistent_steps = np.asarray(
+        multihost.broadcast_one_to_all(
+            persistent_steps,
+            is_source=jax.process_index() == self._persistent_primary_host,
+        )
+    )
+
+    return [
+        x
+        for x in set(np.concatenate((local_steps, persistent_steps)))
+        if x != -1
+    ]

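A host-local sketch of the step-merging logic in all_steps, with the two multihost.broadcast_one_to_all calls replaced by the padded arrays they would deliver (after the broadcasts, every process holds both lists); -1 is the padding sentinel filtered out of the final union:

import numpy as np

def _pad_steps(steps, target):
  return steps + [-1] * (target - len(steps))

# As if broadcast from the local and persistent primary hosts.
local_steps = np.asarray(_pad_steps([8, 9], 4))
persistent_steps = np.asarray(_pad_steps([0, 5], 2))

merged = [
    x for x in set(np.concatenate((local_steps, persistent_steps))) if x != -1
]
print(sorted(merged))  # [0, 5, 8, 9]
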
def latest_step(self) -> Optional[int]:
"""Returns the latest step saved.
@@ -373,8 +450,16 @@ def latest_step(self) -> Optional[int]:
Returns:
A step (int) or None if no steps are present.
"""
-    # TODO: b/330585086 - Implement.
-    raise NotImplementedError('Implement: b/330585086.')
+    if self._in_primary_slice:
+      latest_step = self._persistent_checkpoint_manager.latest_step()
+    else:
+      latest_step = self._local_checkpoint_manager.latest_step()
+
+    if latest_step is None:
+      latest_step = -1
+
+    latest_step = self._global_max(latest_step)
+    return latest_step if latest_step != -1 else None

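The -1 sentinel lets a missing step ride through the max reduction; a tiny pure-Python sketch of that round-trip, where merge_latest is a hypothetical helper standing in for the per-host values plus _global_max:

def merge_latest(per_host_latest):
  # Steps are non-negative, so -1 is a safe "no checkpoint" sentinel.
  values = [s if s is not None else -1 for s in per_host_latest]
  m = max(values)  # stands in for self._global_max
  return m if m != -1 else None

print(merge_latest([None, 12, 10]))  # 12
print(merge_latest([None, None]))    # None
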
def best_step(self) -> Optional[int]:
"""Returns the best step saved, as defined by `options.best_fn`.
@@ -392,13 +477,36 @@ def best_step(self) -> Optional[int]:

def reload(self):
"""Performs disk reads to ensure internal properties are up to date."""
-    self._local_checkpoint_manager.reload()
-    self._persistent_checkpoint_manager.reload()
+    if self._in_primary_slice:
+      self._persistent_checkpoint_manager.reload()
+    else:
+      self._local_checkpoint_manager.reload()

def reached_preemption(self, step: int) -> bool:
"""Returns True if a preemption sync point has been reached."""
return utils.reached_preemption(step)

+  def _global_max(self, value: Any) -> Any:
+    """Returns the global max of a local value across all devices as a scalar."""
+    slice_mesh = jax.sharding.Mesh(
+        self._device_array.reshape(
+            self._device_array.size // jax.local_device_count(),
+            jax.local_device_count(),
+        ),
+        ['host', 'dev'],
+    )
+
+    g_arr = multihost_utils.host_local_array_to_global_array(
+        np.asarray([value]), slice_mesh, P('host')
+    )
+
+    result_arr = jax.jit(
+        jnp.max,
+        out_shardings=jax.sharding.NamedSharding(slice_mesh, P()),
+    )(g_arr)
+
+    return result_arr.addressable_data(0)

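_global_max can be exercised on a single host; the following sketch assumes a CPU-only, single-process JAX runtime (so the "global" max is just the local value) and mirrors the mesh construction and replicated-output reduction above:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

device_array = np.asarray(jax.devices())
mesh = Mesh(
    device_array.reshape(
        device_array.size // jax.local_device_count(),
        jax.local_device_count(),
    ),
    ['host', 'dev'],
)

# Each process contributes one value along the 'host' axis.
g_arr = multihost_utils.host_local_array_to_global_array(
    np.asarray([7]), mesh, P('host')
)
result = jax.jit(jnp.max, out_shardings=NamedSharding(mesh, P()))(g_arr)
print(result.addressable_data(0))  # 7
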
def should_save(self, step: int) -> bool:
"""Returns True if a checkpoint should be saved for the current step.
@@ -410,9 +518,11 @@ def should_save(self, step: int) -> bool:
Returns:
True if the checkpoint should be saved.
"""
-    return self._local_checkpoint_manager.should_save(
-        step
-    ) or self._persistent_checkpoint_manager.should_save(step)
+    if self._in_primary_slice:
+      should_save = self._persistent_checkpoint_manager.should_save(step)
+    else:
+      should_save = self._local_checkpoint_manager.should_save(step)
+    return self._global_max(should_save)

def delete(self, step: int):
"""Deletes a step checkpoint."""
@@ -427,9 +537,19 @@ def save(
metrics: Optional[PyTree] = None,
force: Optional[bool] = False,
) -> bool:
-    return self._local_checkpoint_manager.save(
-        step, args=args, metrics=metrics, force=force
-    )
+    # TODO: b/330608746 - implement save op on different slices
+
+    # This code is only for testing purposes for the all_steps function.
+    if self._in_primary_slice:
+      saved = self._persistent_checkpoint_manager.save(
+          step, args=args, metrics=metrics, force=force
+      )
+    else:
+      saved = self._local_checkpoint_manager.save(
+          step, args=args, metrics=metrics, force=force
+      )
+
+    return self._global_max(saved)

def restore(
self,
Expand Down Expand Up @@ -468,18 +588,24 @@ def wait_until_finished(self):
If some checkpointers are of type AsyncCheckpointer, however, this method
will wait until each of these checkpointers is finished.
"""
-    self._local_checkpoint_manager.wait_until_finished()
-    self._persistent_checkpoint_manager.wait_until_finished()
+    if self._in_primary_slice:
+      self._persistent_checkpoint_manager.wait_until_finished()
+    else:
+      self._local_checkpoint_manager.wait_until_finished()

def check_for_errors(self):
"""Checks for any outstanding errors in completed asynchronous save operations.
Delegates to underlying Checkpointer.
"""
-    self._local_checkpoint_manager.check_for_errors()
-    self._persistent_checkpoint_manager.check_for_errors()
+    if self._in_primary_slice:
+      self._persistent_checkpoint_manager.check_for_errors()
+    else:
+      self._local_checkpoint_manager.check_for_errors()

def close(self):
"""Waits for outstanding operations to finish and closes Checkpointers."""
-    self._local_checkpoint_manager.close()
-    self._persistent_checkpoint_manager.close()
+    if self._in_primary_slice:
+      self._persistent_checkpoint_manager.close()
+    else:
+      self._local_checkpoint_manager.close()
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/multihost/utils.py
@@ -16,7 +16,7 @@

import functools
import time
-from typing import Any, Set, Optional
+from typing import Any, Optional, Set
import zlib
from absl import logging
import jax
