-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[replica-parallel] Add replica slices concept
- Loading branch information
Showing
4 changed files
with
393 additions
and
159 deletions.
There are no files selected for viewing
216 changes: 216 additions & 0 deletions
216
checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import dataclasses | ||
from typing import Callable, Optional, Sequence | ||
|
||
from absl import logging | ||
import jax | ||
import jax.numpy as jnp | ||
import math | ||
import numpy as np | ||
from orbax.checkpoint._src.arrays import fragments | ||
from orbax.checkpoint._src.arrays import numpy_utils | ||
from orbax.checkpoint._src.arrays import types | ||
from orbax.checkpoint._src.multihost import multihost | ||
|
||
|
||
Shape = types.Shape | ||
Index = types.Index | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ReplicaSlice: | ||
""" | ||
ReplicaSlice represents the part of a jax.Shard that a replica is uniquely | ||
responsible for. A replica slice can be either on-device (backed by a slice of | ||
a single-sharding array) or on-host (backed by a numpy ndarray). | ||
With single-replica checkpointing the entirety of each jax.Shard is owned by | ||
exactly one replica. (Currently the only option.) | ||
""" | ||
|
||
replica_id: int | ||
index: Index | ||
data: jax.Array | np.ndarray | ||
|
||
@property | ||
def is_on_host(self): | ||
return isinstance(self.data, np.ndarray) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ReplicaSlices: | ||
""" | ||
ReplicaSlices groups all the sliced data of one jax.Array that a replica is | ||
uniquely responsible for. Slices are either all on-device or all on-host. | ||
""" | ||
|
||
global_shape: Shape | ||
local_shape: Shape | ||
sharding: jax.sharding.Sharding | ||
dtype: np.dtype | ||
is_on_host: bool | ||
replica_slices: list[ReplicaSlice] | ||
|
||
def __post_init__(self): | ||
assert all( | ||
rslice.is_on_host == self.is_on_host | ||
for rslice in self.replica_slices | ||
), f'inconsistent is_on_host in {self!r}' | ||
|
||
@property | ||
def nbytes(self) -> int: | ||
slice_nbytes = math.prod(self.local_shape) * self.dtype.itemsize | ||
return slice_nbytes * len(self.replica_slices) | ||
|
||
def to_fragments(self) -> fragments.Fragments: | ||
assert self.is_on_host | ||
result = fragments.Fragments( | ||
shape=self.global_shape, | ||
dtype=self.dtype, | ||
fragments=[ | ||
fragments.Fragment( | ||
index=numpy_utils.resolve_slice( | ||
rslice.index, | ||
self.global_shape | ||
), | ||
value=rslice.data, | ||
) | ||
for rslice in self.replica_slices | ||
], | ||
) | ||
if result.fragments: | ||
fragments.validate_fragments_can_be_stacked(result) | ||
if not result.is_degenerate(): | ||
assert self.local_shape == result.fragments[0].shape | ||
return result | ||
|
||
|
||
def get_replica_slices( | ||
arr: jax.Array, | ||
replica_id: Optional[int], | ||
) -> ReplicaSlices: | ||
"""Returns the replica slices a given replica is responsible for. | ||
Does not transfer allocate or transfer any data.""" | ||
Result = tuple[list[ReplicaSlice], Shape] | ||
shard0 = arr.addressable_shards[0] | ||
|
||
# single-replica: a single replica saves an entire shard. | ||
def pick_single_replica() -> Result: | ||
# Omitting the replica id just picks the first addressable shard's replica | ||
# id so that the process writes each of its addressable shards exactly | ||
# once. (This is the desired behavior for local checkpointing.) | ||
target_replica_id = replica_id or shard0.replica_id | ||
rslices = [ | ||
ReplicaSlice( | ||
replica_id=shard.replica_id, | ||
index=shard.index, | ||
data=shard.data, | ||
) | ||
for shard in arr.addressable_shards | ||
if shard.replica_id == target_replica_id | ||
] | ||
local_shape = shard0.data.shape | ||
return rslices, local_shape | ||
|
||
shards_info = ', '.join( | ||
[ | ||
f'Shard(index={shard.index}, replica_id={shard.replica_id})' | ||
for shard in arr.addressable_shards | ||
] | ||
) | ||
logging.vlog( | ||
1, | ||
'[process=%d] get_replica_slices: replica_id=%d, shards=[%s]', | ||
multihost.process_index(), | ||
replica_id, | ||
shards_info, | ||
) | ||
|
||
# In order for all processes to agree on the right serialization metadata | ||
# we want to compute the correct local shape regardless of whether there | ||
# are any replica slices to save locally. | ||
rslices, local_shape = pick_single_replica() | ||
return ReplicaSlices( | ||
global_shape=arr.shape, | ||
local_shape=local_shape, | ||
sharding=arr.sharding, | ||
dtype=arr.dtype, | ||
is_on_host=False, | ||
replica_slices=rslices, | ||
) | ||
|
||
|
||
def transfer_arrays_to_host( | ||
arrays: Sequence[jax.Array], | ||
replica_id: Optional[int], | ||
*, | ||
enable_pinned_host_transfer: bool = True, | ||
) -> Sequence[ReplicaSlices]: | ||
""" | ||
Transfers jax.Arrays to host memory and returns all the fragments to be | ||
serialized by the given replica, along with local shape. Blocks until | ||
completion. | ||
""" | ||
|
||
def use_pinned_host_transfer(device): | ||
has_pinned_host = any( | ||
m.kind == 'pinned_host' for m in device.addressable_memories() | ||
) | ||
return ( | ||
enable_pinned_host_transfer | ||
and has_pinned_host | ||
and jax._src.config.enable_memories.value # pylint: disable=protected-access | ||
) | ||
|
||
def async_transfer_slice(rslice: ReplicaSlice) -> tuple[ReplicaSlice, jax.Array]: | ||
assert not rslice.is_on_host | ||
index = rslice.index | ||
data = rslice.data | ||
device = data.device | ||
# Start the asynchronous device-to-host copy | ||
if use_pinned_host_transfer(device): | ||
# If available, transfer to pinned host memory | ||
data = jax.device_put( | ||
data, | ||
jax.sharding.SingleDeviceSharding(device, memory_kind='pinned_host'), | ||
) | ||
else: | ||
data.copy_to_host_async() | ||
return rslice, data | ||
|
||
# Gather the replica slices to be saved for each array. | ||
rslices_per_array = [get_replica_slices(arr, replica_id) for arr in arrays] | ||
# Kick off transfers for all replica slices to be saved. | ||
transfers_per_array = [ | ||
[async_transfer_slice(rslice) for rslice in rslices.replica_slices] | ||
for rslices in rslices_per_array | ||
] | ||
# Wait for all the transferred data to be ready. | ||
return [ | ||
dataclasses.replace( | ||
rslices, | ||
is_on_host=True, | ||
replica_slices=[ | ||
dataclasses.replace( | ||
rslice_on_device, | ||
# Conversion to numpy arrays forces block_until_ready. | ||
data=np.asarray(data), | ||
) | ||
for rslice_on_device, data in transfers | ||
], | ||
) | ||
for rslices, transfers in zip(rslices_per_array, transfers_per_array) | ||
] |
123 changes: 123 additions & 0 deletions
123
checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
import jax | ||
import numpy as np | ||
from orbax.checkpoint._src.serialization import replica_slices | ||
|
||
|
||
def is_pow_of_two(n): | ||
while n > 1: | ||
n, rem = divmod(n, 2) | ||
if rem == 1: | ||
return False | ||
return True | ||
|
||
|
||
def make_multi_device_array(shape, partitioned): | ||
"""Creates a partially- or fully-replicated array.""" | ||
num_devices = len(jax.devices()) | ||
assert num_devices >= 4 | ||
assert is_pow_of_two(num_devices) | ||
mesh = jax.sharding.Mesh( | ||
np.asarray(jax.devices()).reshape((2, num_devices // 2)), | ||
('x', 'y'), | ||
) | ||
if partitioned: | ||
# partially-replicated (partitioned dimension 0 along mesh axis x) | ||
spec = jax.sharding.PartitionSpec('x') | ||
num_partitions = 2 | ||
num_replicas = num_devices // 2 | ||
else: | ||
# fully-replicated | ||
spec = jax.sharding.PartitionSpec() | ||
num_partitions = 1 | ||
num_replicas = num_devices | ||
sharding = jax.sharding.NamedSharding(mesh, spec) | ||
|
||
key = jax.random.PRNGKey(0) | ||
x = jax.random.normal(jax.random.PRNGKey(0), shape) | ||
data = jax.device_put(x, sharding) | ||
|
||
return data, num_partitions, num_replicas | ||
|
||
|
||
@parameterized.product(partitioned=[False, True]) | ||
class ReplicaSlicesTest(parameterized.TestCase): | ||
|
||
def test_get_replica_slices_single_replica(self, partitioned): | ||
arr, num_partitions, num_replicas = make_multi_device_array( | ||
(64, 64), | ||
partitioned=partitioned, | ||
) | ||
|
||
# Using an addressable replica_id yields that replica. | ||
for replica_id in range(num_replicas): | ||
rslices = replica_slices.get_replica_slices( | ||
arr, | ||
replica_id=replica_id | ||
).replica_slices | ||
self.assertEqual(len(rslices), num_partitions) | ||
for rslice in rslices: | ||
self.assertEqual(rslice.replica_id, replica_id) | ||
|
||
# Omitting replica_id yields _some_ replica. | ||
rslices = replica_slices.get_replica_slices( | ||
arr, | ||
replica_id=None | ||
).replica_slices | ||
self.assertEqual(len(rslices), num_partitions) | ||
for rslice in rslices: | ||
self.assertEqual(rslice.replica_id, rslices[0].replica_id) | ||
|
||
# Using an unaddressable replica_id yields nothing. | ||
rslices = replica_slices.get_replica_slices( | ||
arr, | ||
replica_id=-1, | ||
).replica_slices | ||
self.assertEqual(len(rslices), 0) | ||
|
||
def test_transfer(self, partitioned): | ||
arr, num_partitions, num_replicas = make_multi_device_array( | ||
(64, 64), | ||
partitioned=partitioned, | ||
) | ||
replica0_shards = [ | ||
shard | ||
for shard in arr.addressable_shards | ||
if shard.replica_id == 0 | ||
] | ||
|
||
rslices = replica_slices.transfer_arrays_to_host( | ||
[arr], | ||
replica_id=0 | ||
)[0].replica_slices | ||
self.assertEqual(len(rslices), num_partitions) | ||
self.assertEqual(len(rslices), len(replica0_shards)) | ||
|
||
index_start = lambda x: x.index[0].start or 0 | ||
rslices = sorted(rslices, key=index_start) | ||
replica0_shards = sorted(replica0_shards, key=index_start) | ||
|
||
for rslice, replica0_shard in zip(rslices, replica0_shards): | ||
self.assertTrue(rslice.is_on_host) | ||
self.assertIsInstance(rslice.data, np.ndarray) | ||
self.assertEqual(rslice.index, replica0_shard.index) | ||
np.testing.assert_array_equal(rslice.data, replica0_shard.data) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
Oops, something went wrong.