Skip to content

Commit

Permalink
Changed the state_checkpointer implementation to a class (cleaner and easier to use)
Browse files Browse the repository at this point in the history
  • Loading branch information
ecignoni committed Dec 21, 2023
1 parent 259a3c9 commit 31fdc52
Showing 1 changed file with 51 additions and 47 deletions.
98 changes: 51 additions & 47 deletions gpx/optimizers/scipy_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,55 +111,59 @@ def scipy_minimize_derivs(
# Callbacks
# ============================================================================

# we use a dictionary storing a separate counter for each optimization
# (we don't want the counter to be overridden)
_STATE_CHECKPOINT_COUNTER = {}

class StateCheckpointer:
"""state checkpointer callback
def state_checkpointer(
state: ModelState, chk_file: Optional[str] = "optim_chk.npz"
) -> Callable[ArrayLike]:
"""state checkpointer callable
Generates a function that can be passed as `callback` to scipy_minimize.
This callable reconstructs the ModelState and saves it to a .npz file.
You can then load back the model parameters into a model with `model.load`.
Args:
state: model state.
chk_file: name of the checkpoint file. Note that a checkpoint will be
saved for each step of scipy_minimize with a different postfix.
E.g., if chk_file='test.npz', you will obtain 'test.000.npz'
and so on.
Returns:
callback: function that can be passed as `callback` argument to
scipy_minimize
This class can be passed as a 'callback' to scipy_minimize.
When called, it reconstructs the ModelState and saves it to a .npz file.
You can then load back the model parameters into a model with `model.load`
to check the model along the optimization.
You can also use the saved model to perform a restart if something goes
wrong during a long optimization.
"""
# use the hash code of state as key for the checkpoint counter
# WARNING: if you start from the exact same state you override
# the counter

global _STATE_CHECKPOINT_COUNTER

# create counter
hash_code = hash(state)
_STATE_CHECKPOINT_COUNTER[hash_code] = 0

# get the fmt string in order to save the state with a counter
name, ext = os.path.splitext(chk_file)
chk_file = name + ".{:03d}" + ext

def callback(x):
global _STATE_CHECKPOINT_COUNTER
counter = _STATE_CHECKPOINT_COUNTER[hash_code]

def __init__(
self, state: ModelState, chk_file: Optional[str] = "optim_chk.npz"
) -> None:
"""
Args:
state: model state.
chk_file: name of the checkpoint file. Note that a checkpoint will
be saved for each step of scipy_minimize with a
different postfix.
E.g., if chk_file='test.npz', you will obtain
- 'test.000.npz'
- 'test.001.npz'
- ...
and so on.
"""
# store the state
self.state = state
# counter for how many times the callback is called
self.counter = 0
# create the formatted path for saving
self.fmt_chk_file = chk_file
# get the tree definition and the function to unravel the arrays
_, tdef, unravel_fn = ravel_backward_trainables(state.params)
unravel_forward = unravel_forward_trainables(unravel_fn, tdef, state.params)
params = unravel_forward(x)
ustate = state.update(dict(params=params))
ustate.save(chk_file.format(counter))

# update counter
_STATE_CHECKPOINT_COUNTER[hash_code] += 1

return callback
# get the function to reconstruct the parameters from x
self.unravel_forward = unravel_forward_trainables(
unravel_fn, tdef, state.params
)

@property
def fmt_chk_file(self) -> str:
    """Return the checkpoint filename template.

    The template contains a single ``{:03d}`` placeholder, placed between
    the base name and the extension, that is meant to be filled with the
    step counter via ``str.format``.
    """
    return self._fmt_chk_file

@fmt_chk_file.setter
def fmt_chk_file(self, value: str) -> None:
    """Build the filename template from a plain checkpoint path.

    Args:
        value: checkpoint path, e.g. ``'test.npz'``, which is stored as
            the template ``'test.{:03d}.npz'``.
    """
    name, ext = os.path.splitext(value)
    # The result is later fed to str.format, so literal braces in the
    # user-supplied path must be doubled; otherwise a path such as
    # 'run{a}.npz' would raise (or be silently rewritten) when formatted.
    name, ext = (
        part.replace("{", "{{").replace("}", "}}") for part in (name, ext)
    )
    self._fmt_chk_file = name + ".{:03d}" + ext

def __call__(self, x: ArrayLike) -> None:
    """Write one checkpoint for the parameter vector ``x``.

    Passes ``x`` through ``self.unravel_forward``, updates the stored state
    with the resulting ``params``, saves the updated state to the path
    obtained by formatting ``self.fmt_chk_file`` with the current counter,
    and then increments the counter.

    Args:
        x: parameter array handed to the callback.
    """
    updated_state = self.state.update({"params": self.unravel_forward(x)})
    target_path = self.fmt_chk_file.format(self.counter)
    updated_state.save(target_path)
    # Advance only after a successful save so the next call gets a new name.
    self.counter += 1

0 comments on commit 31fdc52

Please sign in to comment.