Skip to content

Commit

Permalink
Add Layout support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696893888
  • Loading branch information
ChromeHearts authored and Orbax Authors committed Nov 15, 2024
1 parent 577ebf2 commit 863452b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 12 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ that contain utilities to perform de/serialization for `RootMetadata` and
`StepMetadata`.
- `ReplicaSlice`/`ReplicaSlices` construct to facilitate saving replicated
arrays.
- Added restoring with custom jax.experimental.layout.Layout support

### Changed
- Refactor metadata/tree_test.py and move common test types to
Expand Down
25 changes: 21 additions & 4 deletions checkpoint/orbax/checkpoint/_src/serialization/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from absl import logging
import humanize
import jax
from jax.experimental import layout
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint._src.arrays import fragments
Expand All @@ -48,8 +49,9 @@
]


Shape = types.Shape
Index = types.Index
Layout = layout.Layout
Shape = types.Shape


async def create_async_array_from_callback(
Expand Down Expand Up @@ -426,6 +428,7 @@ async def _read_and_device_put_shard(
dtype: jnp.dtype,
requested_domain: ts.IndexDomain,
restricted_domain: ts.IndexDomain,
dll: Optional[layout.DeviceLocalLayout],
) -> jax.Array:
"""Reads a single shard from TensorStore and places it on device."""
# This maybe needed because the shape the array was saved with is smaller
Expand All @@ -445,7 +448,9 @@ async def _read_and_device_put_shard(
# make this work.
if out.dtype == jnp.int4:
out = jnp.asarray(out) # type: ignore
return jax.device_put(out, jax.sharding.SingleDeviceSharding(device))
return jax.device_put(
out, Layout(dll, jax.sharding.SingleDeviceSharding(device))
)


async def _read_array_index_callback(
Expand All @@ -457,6 +462,7 @@ async def _read_array_index_callback(
dtype: jnp.dtype,
byte_limiter: ByteLimiter,
strict: bool,
ddl: Optional[layout.DeviceLocalLayout],
) -> jax.Array:
"""Callback that reads an array index and places on device."""
if strict and t.shape != shape:
Expand All @@ -479,12 +485,13 @@ async def _read_array_index_callback(
dtype,
requested_domain,
restricted_domain,
ddl,
)
return result


async def async_deserialize(
user_in_sharding: jax.sharding.Sharding,
user_in_sharding: jax.sharding.Sharding | Layout,
tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
global_shape: Optional[Sequence[int]] = None,
dtype: Optional[jnp.dtype] = None,
Expand All @@ -496,11 +503,20 @@ async def async_deserialize(
"""Reads an array using TensorStore."""
byte_limiter = byte_limiter or get_byte_limiter()
context = context or ts_utils.get_ts_context(use_ocdbt=False)
in_sharding = user_in_sharding
in_sharding = (
user_in_sharding.sharding
if isinstance(user_in_sharding, Layout)
else user_in_sharding
)
if not isinstance(in_sharding, jax.sharding.Sharding):
raise ValueError(
'sharding passed to deserialization should be specified, concrete and'
f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
dll = (
user_in_sharding.device_local_layout
if isinstance(user_in_sharding, Layout)
else None
)
t = await ts.open(
tensorstore_spec,
open=True,
Expand All @@ -520,5 +536,6 @@ async def async_deserialize(
dtype=dtype,
byte_limiter=byte_limiter,
strict=strict,
ddl=dll,
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from absl.testing import parameterized
import jax
from jax import dtypes as _dtypes
from jax.experimental import layout
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import future
Expand All @@ -38,6 +39,8 @@
GSPMDSharding = jax.sharding.GSPMDSharding
NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec
DLL = layout.DeviceLocalLayout
Layout = layout.Layout

jax.config.update('jax_enable_x64', True)

Expand Down Expand Up @@ -598,6 +601,40 @@ def test_odd_resharding(self):
for i, shard in enumerate(restored.addressable_shards):
self.assertArraysEqual(np.asarray(shard.data), np.arange(4) + (i * 4))

def test_load_with_layout(self):
mesh = create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(32).reshape(8, 4)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

out_layout = (
jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO))
.lower(arr)
.compile()
.output_layouts
)
self.assertEqual(
arr.layout.device_local_layout.major_to_minor,
out_layout.device_local_layout.major_to_minor[::-1],
)

ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
tspecs = jax.tree.map(serialization.get_tensorstore_spec, [ckpt_path])

serialize(
[arr],
tspecs,
)

(out,) = deserialize([out_layout], tspecs)

self.assertEqual(out.layout, out_layout)
self.assertIsInstance(out, jax.Array)
self.assertArraysEqual(out, np_inp)
for s in out.addressable_shards:
self.assertArraysEqual(s.data, np_inp[s.index])


if __name__ == '__main__':
absltest.main()
19 changes: 11 additions & 8 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from absl import logging
from etils import epath
import jax
from jax.experimental import layout
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import future
Expand All @@ -49,6 +50,7 @@
import tensorstore as ts


Layout = layout.Layout
Shape = types.Shape
Scalar = Union[int, float, np.number]
Metadata = value_metadata.Metadata
Expand Down Expand Up @@ -1041,13 +1043,12 @@ class ArrayRestoreArgs(RestoreArgs):
mesh_axes:
The mesh_axes that the array should be restored as. Cannot be None.
sharding:
`jax.sharding.Sharding` object which takes precedence over mesh and
mesh_axes if provided. Otherwise, mesh and mesh_axes will be used to
construct a NamedSharding object OR `ShardingMetadata` which is an orbax
representation of `jax.sharding.Sharding` that stores the same properties
but does not require accessing real devices.
global_shape:
The global shape that the array should be restored into. If not
`jax.sharding.Sharding`, `ShardingMetadata`, or `Layout` object which takes
precedence over mesh and mesh_axes if provided. Otherwise, mesh and mesh_axes
will be used to construct a NamedSharding object OR `ShardingMetadata` which
is an orbax representation of `jax.sharding.Sharding` that stores the same
properties but does not require accessing real devices.
global_shape: The global shape that the array should be restored into. If not
provided, the shape will be restored as written. Presently, arbitrary shape
transformations are not supported (for example, reshaping to different
dimensions). Padding and truncating are supported. When the global_shape is
Expand All @@ -1064,7 +1065,9 @@ class ArrayRestoreArgs(RestoreArgs):
restore_type: Optional[Any] = jax.Array
mesh: Optional[jax.sharding.Mesh] = None
mesh_axes: Optional[jax.sharding.PartitionSpec] = None
sharding: Optional[Union[jax.sharding.Sharding, ShardingMetadata]] = None
sharding: Optional[Union[jax.sharding.Sharding, ShardingMetadata, Layout]] = (
None
)
global_shape: Optional[Tuple[int, ...]] = None
strict: bool = True

Expand Down

0 comments on commit 863452b

Please sign in to comment.