[replica-parallel] Add replica slices concept
gspschmid committed Nov 13, 2024
1 parent 68729cc commit 5f533d5
Showing 4 changed files with 393 additions and 159 deletions.
216 changes: 216 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
@@ -0,0 +1,216 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import math
from typing import Optional, Sequence

from absl import logging
import jax
import numpy as np
from orbax.checkpoint._src.arrays import fragments
from orbax.checkpoint._src.arrays import numpy_utils
from orbax.checkpoint._src.arrays import types
from orbax.checkpoint._src.multihost import multihost


Shape = types.Shape
Index = types.Index


@dataclasses.dataclass(frozen=True)
class ReplicaSlice:
  """The part of a jax.Shard that a replica is uniquely responsible for.

  A replica slice can be either on-device (backed by a slice of a
  single-sharding array) or on-host (backed by a numpy ndarray).

  With single-replica checkpointing the entirety of each jax.Shard is owned
  by exactly one replica. (Currently the only option.)
  """

  replica_id: int
  index: Index
  data: jax.Array | np.ndarray

  @property
  def is_on_host(self):
    return isinstance(self.data, np.ndarray)
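
# Illustrative example (hypothetical values): for a (4, 8) array partitioned
# in half along axis 0, the replica owning the first partition would hold
# something like
#
#   ReplicaSlice(
#       replica_id=0,
#       index=(slice(0, 2), slice(None)),
#       data=<jax.Array of shape (2, 8)>,  # on-device until transferred
#   )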


@dataclasses.dataclass(frozen=True)
class ReplicaSlices:
  """Groups all the sliced data of one jax.Array that a replica is uniquely
  responsible for. Slices are either all on-device or all on-host.
  """

  global_shape: Shape
  local_shape: Shape
  sharding: jax.sharding.Sharding
  dtype: np.dtype
  is_on_host: bool
  replica_slices: list[ReplicaSlice]

  def __post_init__(self):
    assert all(
        rslice.is_on_host == self.is_on_host
        for rslice in self.replica_slices
    ), f'inconsistent is_on_host in {self!r}'

  @property
  def nbytes(self) -> int:
    slice_nbytes = math.prod(self.local_shape) * self.dtype.itemsize
    return slice_nbytes * len(self.replica_slices)

  def to_fragments(self) -> fragments.Fragments:
    assert self.is_on_host
    result = fragments.Fragments(
        shape=self.global_shape,
        dtype=self.dtype,
        fragments=[
            fragments.Fragment(
                index=numpy_utils.resolve_slice(
                    rslice.index, self.global_shape
                ),
                value=rslice.data,
            )
            for rslice in self.replica_slices
        ],
    )
    if result.fragments:
      fragments.validate_fragments_can_be_stacked(result)
    if not result.is_degenerate():
      assert self.local_shape == result.fragments[0].shape
    return result
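
# Sketch of intended use (assuming `rslices` is already on-host, e.g. the
# result of transfer_arrays_to_host below):
#
#   frags = rslices.to_fragments()
#   assert frags.shape == rslices.global_shape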


def get_replica_slices(
    arr: jax.Array,
    replica_id: Optional[int],
) -> ReplicaSlices:
  """Returns the replica slices a given replica is responsible for.

  Does not allocate or transfer any data.
  """
  Result = tuple[list[ReplicaSlice], Shape]
  shard0 = arr.addressable_shards[0]

  # Single-replica: a single replica saves an entire shard.
  def pick_single_replica() -> Result:
    # Omitting the replica id just picks the first addressable shard's replica
    # id so that the process writes each of its addressable shards exactly
    # once. (This is the desired behavior for local checkpointing.)
    # Compare against None explicitly so that an explicit replica_id of 0 is
    # respected rather than falling through to shard0's replica id.
    target_replica_id = (
        replica_id if replica_id is not None else shard0.replica_id
    )
    rslices = [
        ReplicaSlice(
            replica_id=shard.replica_id,
            index=shard.index,
            data=shard.data,
        )
        for shard in arr.addressable_shards
        if shard.replica_id == target_replica_id
    ]
    local_shape = shard0.data.shape
    return rslices, local_shape

  shards_info = ', '.join(
      [
          f'Shard(index={shard.index}, replica_id={shard.replica_id})'
          for shard in arr.addressable_shards
      ]
  )
  logging.vlog(
      1,
      '[process=%d] get_replica_slices: replica_id=%s, shards=[%s]',
      multihost.process_index(),
      replica_id,
      shards_info,
  )

  # In order for all processes to agree on the right serialization metadata
  # we want to compute the correct local shape regardless of whether there
  # are any replica slices to save locally.
  rslices, local_shape = pick_single_replica()
  return ReplicaSlices(
      global_shape=arr.shape,
      local_shape=local_shape,
      sharding=arr.sharding,
      dtype=arr.dtype,
      is_on_host=False,
      replica_slices=rslices,
  )
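
# Example (illustrative; assumes `arr` has at least one replica with id 0):
# picking replica 0 keeps exactly one copy of each distinct shard, still
# referencing device memory.
#
#   rslices = get_replica_slices(arr, replica_id=0)
#   assert not rslices.is_on_host
#   assert all(rs.replica_id == 0 for rs in rslices.replica_slices)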


def transfer_arrays_to_host(
    arrays: Sequence[jax.Array],
    replica_id: Optional[int],
    *,
    enable_pinned_host_transfer: bool = True,
) -> Sequence[ReplicaSlices]:
  """Transfers jax.Arrays to host memory.

  Returns, for each array, the ReplicaSlices (including the local shape) to
  be serialized by the given replica. Blocks until completion.
  """

  def use_pinned_host_transfer(device):
    has_pinned_host = any(
        m.kind == 'pinned_host' for m in device.addressable_memories()
    )
    return (
        enable_pinned_host_transfer
        and has_pinned_host
        and jax._src.config.enable_memories.value  # pylint: disable=protected-access
    )

  def async_transfer_slice(
      rslice: ReplicaSlice,
  ) -> tuple[ReplicaSlice, jax.Array]:
    assert not rslice.is_on_host
    data = rslice.data
    device = data.device
    # Start the asynchronous device-to-host copy.
    if use_pinned_host_transfer(device):
      # If available, transfer to pinned host memory.
      data = jax.device_put(
          data,
          jax.sharding.SingleDeviceSharding(device, memory_kind='pinned_host'),
      )
    else:
      data.copy_to_host_async()
    return rslice, data

  # Gather the replica slices to be saved for each array.
  rslices_per_array = [get_replica_slices(arr, replica_id) for arr in arrays]
  # Kick off transfers for all replica slices to be saved.
  transfers_per_array = [
      [async_transfer_slice(rslice) for rslice in rslices.replica_slices]
      for rslices in rslices_per_array
  ]
  # Wait for all the transferred data to be ready.
  return [
      dataclasses.replace(
          rslices,
          is_on_host=True,
          replica_slices=[
              dataclasses.replace(
                  rslice_on_device,
                  # Conversion to numpy arrays forces block_until_ready.
                  data=np.asarray(data),
              )
              for rslice_on_device, data in transfers
          ],
      )
      for rslices, transfers in zip(rslices_per_array, transfers_per_array)
  ]
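
# End-to-end sketch (hypothetical `arr`): transfer replica 0's slices of one
# array to host; the result mirrors get_replica_slices but with numpy data.
#
#   host_rslices = transfer_arrays_to_host([arr], replica_id=0)[0]
#   assert host_rslices.is_on_host
#   assert all(
#       isinstance(rs.data, np.ndarray) for rs in host_rslices.replica_slices
#   )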
123 changes: 123 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py
@@ -0,0 +1,123 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np
from orbax.checkpoint._src.serialization import replica_slices


def is_pow_of_two(n):
  while n > 1:
    n, rem = divmod(n, 2)
    if rem == 1:
      return False
  return True


def make_multi_device_array(shape, partitioned):
  """Creates a partially- or fully-replicated array."""
  num_devices = len(jax.devices())
  assert num_devices >= 4
  assert is_pow_of_two(num_devices)
  mesh = jax.sharding.Mesh(
      np.asarray(jax.devices()).reshape((2, num_devices // 2)),
      ('x', 'y'),
  )
  if partitioned:
    # Partially-replicated (dimension 0 partitioned along mesh axis x).
    spec = jax.sharding.PartitionSpec('x')
    num_partitions = 2
    num_replicas = num_devices // 2
  else:
    # Fully-replicated.
    spec = jax.sharding.PartitionSpec()
    num_partitions = 1
    num_replicas = num_devices
  sharding = jax.sharding.NamedSharding(mesh, spec)

  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, shape)
  data = jax.device_put(x, sharding)

  return data, num_partitions, num_replicas
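
# For example, with 8 devices the mesh is 2x4: partitioned=True yields
# num_partitions=2 and num_replicas=4, while partitioned=False yields
# num_partitions=1 and num_replicas=8.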


@parameterized.product(partitioned=[False, True])
class ReplicaSlicesTest(parameterized.TestCase):

  def test_get_replica_slices_single_replica(self, partitioned):
    arr, num_partitions, num_replicas = make_multi_device_array(
        (64, 64),
        partitioned=partitioned,
    )

    # Using an addressable replica_id yields that replica.
    for replica_id in range(num_replicas):
      rslices = replica_slices.get_replica_slices(
          arr, replica_id=replica_id
      ).replica_slices
      self.assertEqual(len(rslices), num_partitions)
      for rslice in rslices:
        self.assertEqual(rslice.replica_id, replica_id)

    # Omitting replica_id yields _some_ replica.
    rslices = replica_slices.get_replica_slices(
        arr, replica_id=None
    ).replica_slices
    self.assertEqual(len(rslices), num_partitions)
    for rslice in rslices:
      self.assertEqual(rslice.replica_id, rslices[0].replica_id)

    # Using an unaddressable replica_id yields nothing.
    rslices = replica_slices.get_replica_slices(
        arr, replica_id=-1
    ).replica_slices
    self.assertEqual(len(rslices), 0)

  def test_transfer(self, partitioned):
    arr, num_partitions, num_replicas = make_multi_device_array(
        (64, 64),
        partitioned=partitioned,
    )
    replica0_shards = [
        shard for shard in arr.addressable_shards if shard.replica_id == 0
    ]

    rslices = replica_slices.transfer_arrays_to_host(
        [arr], replica_id=0
    )[0].replica_slices
    self.assertEqual(len(rslices), num_partitions)
    self.assertEqual(len(rslices), len(replica0_shards))

    index_start = lambda x: x.index[0].start or 0
    rslices = sorted(rslices, key=index_start)
    replica0_shards = sorted(replica0_shards, key=index_start)

    for rslice, replica0_shard in zip(rslices, replica0_shards):
      self.assertTrue(rslice.is_on_host)
      self.assertIsInstance(rslice.data, np.ndarray)
      self.assertEqual(rslice.index, replica0_shard.index)
      np.testing.assert_array_equal(rslice.data, replica0_shard.data)


if __name__ == '__main__':
  absltest.main()
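
# Note: on a CPU-only machine the >= 4 device requirement can be met by
# simulating devices, e.g. (assumption, not part of this change):
#
#   XLA_FLAGS=--xla_force_host_platform_device_count=8 \
#       python replica_slices_test.py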