
Commit

Add validation to prevent loading an array index that was never written to.

PiperOrigin-RevId: 698945828
cpgaffney1 authored and Orbax Authors committed Nov 22, 2024
1 parent 04e4ae7 commit 646f1e6
Showing 9 changed files with 132 additions and 39 deletions.
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.10.0] - 2024-11-22

### Added
- Add `RootMetadata` and `StepMetadata` classes as ways for the user to
interface with checkpoint metadata at various levels.
@@ -24,6 +26,7 @@ arrays.
saved cooperatively by multiple hosts.
- [Experimental Feature] Support `NamedTuple` and `Tuple` nodes in PyTree
metadata.
- Add validation to prevent loading an array index that was never written to.

### Changed
- Refactor metadata/tree_test.py and move common test types to
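The CHANGELOG entry above corresponds to the failure mode exercised by the new test_incomplete_write test added later in this commit. A minimal sketch of that scenario, assuming the ts_utils import path shown below and a hypothetical checkpoint directory (the real test goes through Orbax's deserialize; this sketch reads back directly through TensorStore):

import numpy as np
import tensorstore as ts
# Assumed import path for the spec helper used by the new test.
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils

data = np.arange(8)
chunk_len = 4

# Build a spec with 4-element chunks, then write only the first chunk.
tspec = ts_utils.ArrayWriteSpec(
    '/tmp/ckpt_sketch',  # hypothetical directory
    'a',
    global_shape=data.shape,
    write_shape=(chunk_len,),
    dtype=data.dtype,
    use_ocdbt=False,
).json
t = ts.open(ts.Spec(tspec), create=True, open=True).result()
t[:chunk_len].write(data[:chunk_len]).result()  # indices 4..7 never written

# With `fill_missing_data_reads` disabled, reading the unwritten chunk raises
# instead of silently returning the fill value (typically zeros).
tspec['fill_missing_data_reads'] = False
t2 = ts.open(ts.Spec(tspec), open=True).result()
t2.read().result()  # raises: the chunk covering [4:8] was never stored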
@@ -347,6 +347,7 @@ def _get_param_infos(
ocdbt_target_data_file_size: Optional[int] = None,
enable_pinned_host_transfer: bool = True,
byte_limiter: Optional[serialization.ByteLimiter] = None,
raise_array_data_missing_error: bool = True,
) -> PyTree:
"""Returns parameter information for elements in `item`.
@@ -362,6 +363,7 @@
OCDBT data file.
enable_pinned_host_transfer: See ParamInfo docs.
byte_limiter: ByteLimiter object.
raise_array_data_missing_error: See documentation in ParamInfo.
Returns:
A PyTree matching `item` of ParamInfo.
@@ -389,6 +391,7 @@ def _param_info(name, value):
value_typestr=types.get_param_typestr(
value, self._type_handler_registry
),
raise_array_data_missing_error=raise_array_data_missing_error,
)

return jax.tree.map(
@@ -686,6 +689,9 @@ class TrainState:
if internal_tree_metadata.use_zarr3 is not None
else self._use_zarr3
)
raise_array_data_missing_error = (
internal_tree_metadata.store_array_data_equal_to_fill_value
)
del internal_tree_metadata
# Prep for restore.
if item is None:
@@ -701,6 +707,7 @@
directory=directory,
use_ocdbt=type_handlers.is_ocdbt_checkpoint(directory),
use_zarr3=use_zarr3,
raise_array_data_missing_error=raise_array_data_missing_error,
)
# Begin restore.
tree_memory_size, restored_item = asyncio_utils.run_sync(
15 changes: 12 additions & 3 deletions checkpoint/orbax/checkpoint/_src/metadata/tree.py
@@ -53,6 +53,7 @@
_KEY_METADATA_KEY = 'key_metadata'
_VALUE_METADATA_KEY = 'value_metadata'
_USE_ZARR3 = 'use_zarr3'
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE = 'store_array_data_equal_to_fill_value'
_VALUE_METADATA_TREE = 'value_metadata_tree'


@@ -197,6 +198,7 @@ class InternalTreeMetadata:

tree_metadata_entries: List[InternalTreeMetadataEntry]
use_zarr3: bool
store_array_data_equal_to_fill_value: bool
pytree_metadata_options: PyTreeMetadataOptions
value_metadata_tree: PyTree | None = None

@@ -248,6 +250,7 @@ def build(
return InternalTreeMetadata(
tree_metadata_entries,
use_zarr3,
ts_utils.STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE,
pytree_metadata_options,
value_metadata_tree,
)
@@ -274,6 +277,7 @@ def to_json(self) -> Dict[str, Any]:
...
},
_USE_ZARR3: True/False,
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: True,
_VALUE_METADATA_TREE: '{
"mu_nu": {
"category": "namedtuple",
@@ -329,6 +333,9 @@ def to_json(self) -> Dict[str, Any]:
{},
),
_USE_ZARR3: self.use_zarr3,
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: (
self.store_array_data_equal_to_fill_value
),
}
# TODO: b/365169723 - Support versioned evolution of metadata storage.
if (
@@ -351,9 +358,10 @@ def from_json(
),
) -> InternalTreeMetadata:
"""Returns an InternalTreeMetadata instance from its JSON representation."""
use_zarr3 = False
if _USE_ZARR3 in json_dict:
use_zarr3 = json_dict[_USE_ZARR3]
use_zarr3 = json_dict.get(_USE_ZARR3, False)
store_array_data_equal_to_fill_value = json_dict.get(
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE, False
)

tree_metadata_entries = []
for keypath, json_tree_metadata_entry in json_dict[
@@ -376,6 +384,7 @@
use_zarr3=use_zarr3,
pytree_metadata_options=pytree_metadata_options,
value_metadata_tree=value_metadata_tree,
store_array_data_equal_to_fill_value=store_array_data_equal_to_fill_value,
)

def as_nested_tree(self) -> Dict[str, Any]:
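One practical consequence of the .get(..., False) defaults in from_json above, sketched with hypothetical minimal JSON dicts (real metadata also carries the per-key tree entries):

# Hypothetical minimal metadata dicts, for illustration only.
new_json = {'use_zarr3': True, 'store_array_data_equal_to_fill_value': True}
old_json = {'use_zarr3': True}  # written before this change; key absent

# New checkpoints opt in to missing-data validation on restore; older
# checkpoints default to False and keep the previous fill-value reads.
assert new_json.get('store_array_data_equal_to_fill_value', False)
assert not old_json.get('store_array_data_equal_to_fill_value', False)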
88 changes: 55 additions & 33 deletions checkpoint/orbax/checkpoint/_src/serialization/serialization.py
@@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Array serialization and deserialization.
TODO(b/348434669): De-fork when possible.
"""
"""Array serialization and deserialization."""

import asyncio
from collections.abc import Awaitable
from collections.abc import Awaitable, Mapping
import contextlib
import functools
import os
@@ -54,19 +51,23 @@
Shape = types.Shape


def _get_device_to_index_map(
global_shape: Shape, sharding: jax.sharding.Sharding
) -> Mapping[jax.Device, Index]:
return sharding.devices_indices_map(global_shape)


async def create_async_array_from_callback(
global_shape: Shape,
inp_sharding: jax.sharding.Sharding,
sharding: jax.sharding.Sharding,
data_callback: Callable[[Index, jax.Device], Awaitable[jax.Array]],
) -> jax.Array:
device_to_index_map = inp_sharding.devices_indices_map(global_shape)
addressable_da = inp_sharding._addressable_device_assignment # pylint: disable=protected-access
device_to_index_map = _get_device_to_index_map(global_shape, sharding)
addressable_da = sharding._addressable_device_assignment # pylint: disable=protected-access
future_arrays = [data_callback(device_to_index_map[d], d)
for d in addressable_da]
dbs = await asyncio.gather(*future_arrays)
return jax.make_array_from_single_device_arrays(
global_shape, inp_sharding, dbs
)
return jax.make_array_from_single_device_arrays(global_shape, sharding, dbs)


def _get_metadata(arr: jax.Array, local_shape: Shape):
@@ -462,44 +463,65 @@ async def _read_array_index_callback(
index: Index,
device: jax.Device,
t: ts.TensorStore,
shape: Sequence[int],
new_shard_shape: Sequence[int],
shape: Shape,
new_shard_shape: Shape,
dtype: jnp.dtype,
byte_limiter: ByteLimiter,
strict: bool,
ddl: Optional[layout.DeviceLocalLayout],
) -> jax.Array:
"""Callback that reads an array index and places on device."""
if strict and t.shape != shape:
raise ValueError(
f'Requested shape: {shape} is not compatible with the stored shape:'
f' {t.shape}. Truncating/padding is disabled by setting of'
' `strict=True`. When using standard Orbax APIs, this behavior can be'
' modified by specifying `strict=False` in `ArrayRestoreArgs` for any'
' array in which padding/truncation is desired.'
)
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
for sl in index:
if sl.step is not None and sl.step != 1:
raise ValueError(
f'Non-contiguous domain for index: {index} not supported. Found:'
f' {sl.step}'
)

if strict:
if t.shape == shape:
domain = ts.IndexDomain(shape=shape)[ts.d[:][index]]
requested_domain = domain
restricted_domain = domain
else:
raise ValueError(
f'Requested shape: {shape} is not compatible with the stored shape:'
f' {t.shape}. Truncating/padding is disabled by setting of'
' `strict=True`. When using standard Orbax APIs, this behavior can be'
' modified by specifying `strict=False` in `ArrayRestoreArgs` for any'
' array in which padding/truncation is desired.'
)
else:
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)

requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
# Limit the bytes read for every shard.
async with reserved_bytes(byte_limiter, requested_bytes):
result = await _read_and_device_put_shard(
device,
t,
new_shard_shape,
dtype,
requested_domain,
restricted_domain,
ddl,
)
try:
result = await _read_and_device_put_shard(
device,
t,
new_shard_shape,
dtype,
requested_domain,
restricted_domain,
ddl,
)
except BaseException as e:
raise Exception( # pylint: disable=broad-exception-raised
f'Encountered error while reading array index: {index}. See full'
f' TensorStore details: {t.spec}.'
) from e
return result


async def async_deserialize(
user_in_sharding: jax.sharding.Sharding | Layout,
tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
global_shape: Optional[Sequence[int]] = None,
global_shape: Optional[Shape] = None,
dtype: Optional[jnp.dtype] = None,
*,
byte_limiter: Optional[ByteLimiter] = None,
context: Optional[ts.Context] = None,
assume_metadata: bool = False,
@@ -635,6 +635,33 @@ def test_load_with_layout(self):
for s in out.addressable_shards:
self.assertArraysEqual(s.data, np_inp[s.index])

def test_incomplete_write(self):
data = np.arange(8)
chunk_len = 4
global_mesh = create_global_mesh((8,), 'x')
sharding = NamedSharding(global_mesh, P(None))
tspec = ts_utils.ArrayWriteSpec(
self.ckpt_dir.as_posix(),
'a',
global_shape=data.shape,
write_shape=(chunk_len,),
dtype=data.dtype,
use_ocdbt=False,
).json
t = ts.open(
ts.Spec(tspec),
create=True,
open=True,
).result()
t[:chunk_len].write(data[:chunk_len]).result()

# Enable raising error for incomplete chunk.
tspec['fill_missing_data_reads'] = False
with self.assertRaisesRegex(
Exception, 'Encountered error while reading array index'
):
deserialize([sharding], [tspec])


if __name__ == '__main__':
absltest.main()
@@ -39,6 +39,14 @@

_GCS_PATH_RE = r'^gs://([^/]*)/(.*)$'

# Even if the data is equal to the fill value, we still want to write it
# to the checkpoint. This results in unnecessary writes in some edge
# cases, but it allows us to verify that data was actually written when
# later restoring.
# Must match `store_data_equal_to_fill_value` property in Orbax
# metadata.
STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE = True


JsonSpec: TypeAlias = dict[str, Any]
Shape: TypeAlias = types.Shape
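The comment above pairs with a read-side option added to the read specs later in this diff (see type_handlers.py below). A rough sketch of how the two TensorStore options relate, with the driver name as an assumption and only the keys introduced by this change shown:

# Write side: always materialize chunk data, even when it equals the fill
# value, so a present chunk is proof the data was actually written.
write_tspec = {
    'driver': 'zarr',  # assumed; real specs are built by ArrayWriteSpec
    # Mirrors STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE defined above.
    'store_data_equal_to_fill_value': True,
}

# Read side: `raise_array_data_missing_error` is taken from the tree metadata
# (metadata/tree.py above) and inverted into the TensorStore read option.
raise_array_data_missing_error = True
read_tspec = {
    'driver': 'zarr',  # assumed
    'fill_missing_data_reads': not raise_array_data_missing_error,
}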
@@ -344,6 +352,7 @@ def __init__(
'kvstore': kvstore_tspec,
'recheck_cached_data': False,
'recheck_cached_metadata': False,
'store_data_equal_to_fill_value': STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE,
}
if metadata_key is not None:
tspec['metadata_key'] = metadata_key
16 changes: 14 additions & 2 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
@@ -103,8 +103,10 @@ async def _assert_parameter_files_exist(
def _get_json_tspec(
info: types.ParamInfo,
use_ocdbt: bool,
*,
process_index: Optional[Union[int, str]] = None,
metadata_key: Optional[str] = None,
raise_array_data_missing_error: bool = True,
) -> Dict[str, Any]:
"""Gets Tensorstore spec in JSON format."""
if info.path is None:
@@ -124,6 +126,8 @@ def _get_json_tspec(
'kvstore': kvstore_tspec,
'recheck_cached_data': False,
'recheck_cached_metadata': False,
# Raise error if data is missing.
'fill_missing_data_reads': not raise_array_data_missing_error,
}
if metadata_key is not None:
tspec['metadata_key'] = metadata_key
Expand All @@ -136,12 +140,14 @@ def get_json_tspec_read(
info: types.ParamInfo,
use_ocdbt: bool,
metadata_key: Optional[str] = None,
raise_array_data_missing_error: bool = True,
):
"""Gets Tensorstore spec for reading."""
return _get_json_tspec(
info,
use_ocdbt=use_ocdbt,
metadata_key=metadata_key,
raise_array_data_missing_error=raise_array_data_missing_error,
)


@@ -561,7 +567,10 @@ def _get_json_tspec_read(
) -> Dict[str, Any]:
"""Gets Tensorstore spec for reading."""
return get_json_tspec_read(
info, use_ocdbt=use_ocdbt, metadata_key=self._metadata_key
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)

def typestr(self) -> str:
@@ -884,7 +893,10 @@ def _get_json_tspec_read(
) -> Dict[str, Any]:
"""Gets Tensorstore spec for reading."""
return get_json_tspec_read(
info, use_ocdbt=use_ocdbt, metadata_key=self._metadata_key
info,
use_ocdbt=use_ocdbt,
metadata_key=self._metadata_key,
raise_array_data_missing_error=info.raise_array_data_missing_error,
)

def typestr(self) -> str:
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/_src/serialization/types.py
@@ -99,6 +99,9 @@ class ParamInfo:
enable_pinned_host_transfer:
True by default. If False, disables transfer to pinned host when copying
from device to host, regardless of the presence of pinned host memory.
raise_array_data_missing_error:
Only used for restoring. See documentation in `tensorstore_utils.py`. Comes
from tree metadata and should be the same across all parameters.
"""

name: Optional[str] = None
@@ -112,6 +115,7 @@
ts_context: Optional[ts.Context] = None
value_typestr: Optional[str] = None
enable_pinned_host_transfer: bool = True
raise_array_data_missing_error: bool = True


@dataclasses.dataclass
2 changes: 1 addition & 1 deletion checkpoint/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
'jax >= 0.4.34',
'numpy',
'pyyaml',
'tensorstore >= 0.1.60',
'tensorstore >= 0.1.68',
'nest_asyncio',
'protobuf',
'humanize',