Skip to content

Commit

Permalink
Add docs on working with PyTrees/arrays. Modify error messages for wh…
Browse files Browse the repository at this point in the history
…en sharding is not provided.

PiperOrigin-RevId: 623537816
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Apr 10, 2024
1 parent 74adaef commit 470c74c
Show file tree
Hide file tree
Showing 5 changed files with 670 additions and 8 deletions.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Added
- Add path package to export symbols, also add Step rst docs.
- Create composite step `NameFormat`.
- Docs on working with PyTrees/arrays.

## Changed
- Error messages when sharding is not specified.

### Changed
- Improve step lookup error message by adding expected names to it.
Expand Down
7 changes: 7 additions & 0 deletions checkpoint/orbax/checkpoint/standard_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dataclasses
from typing import Any, List, Optional

from absl import logging
from etils import epath
import jax
from orbax.checkpoint import checkpoint_args
Expand Down Expand Up @@ -162,6 +163,12 @@ def restore(
self._validate_restore_state(args.item)
restore_args = checkpoint_utils.construct_restore_args(args.item)
else:
logging.warning(
'`StandardCheckpointHandler` expects a target tree to be provided for'
' restore. Not doing so is generally UNSAFE unless you know the'
' present topology to be the same one as the checkpoint was saved'
' under.'
)
restore_args = checkpoint_utils.construct_restore_args(
self.metadata(directory)
)
Expand Down
4 changes: 3 additions & 1 deletion checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,7 +1476,9 @@ async def deserialize(
"Couldn't find sharding info under RestoreArgs. Populating sharding"
' info from sharding file. Please note restoration time will be'
' slightly increased due to reading from file instead of directly'
' from RestoreArgs.'
' from RestoreArgs. Note also that this option is unsafe when'
' restoring on a different topology than the checkpoint was saved'
' with.'
)
if info.name:
tspec_sharding = get_sharding_tensorstore_spec(
Expand Down
Loading

0 comments on commit 470c74c

Please sign in to comment.