Commit 72e3f8c

Bug fixes:
 - Callbacks were lost when using _TupleFunc/_ReverseFunc
 - Erroneous warnings in adjoint mode (don't raise for internally added null callback)
slishak committed Jul 27, 2022
1 parent fa18529 · commit 72e3f8c
Showing 1 changed file: torchdiffeq/_impl/misc.py (21 additions, 17 deletions)
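
For context before the diff: torchdiffeq looks up step callbacks (callback_step, callback_accept_step, callback_reject_step) as attributes of the function passed to odeint. Below is a minimal sketch of the first bug, using a hypothetical Dynamics module; the ODE itself is illustrative, and callback support still depends on the chosen solver.

import torch
from torchdiffeq import odeint

class Dynamics(torch.nn.Module):
    def forward(self, t, y):
        # Tuple-valued state: _check_inputs wraps this func in _TupleFunc.
        y1, y2 = y
        return -y1, -y2

    def callback_step(self, t0, y0, dt):
        # Invoked by the solver before each step; all callbacks currently
        # take the arguments (t0, y0, dt).
        print("step: t0={:.3f}, dt={:.3f}".format(float(t0), float(dt)))

y0 = (torch.tensor([1.0]), torch.tensor([2.0]))
t = torch.linspace(0.0, 1.0, 5)

# Before this commit, the getattr lookup ran against the _TupleFunc (or
# _ReverseFunc) wrapper rather than the original func, raised
# AttributeError, and silently installed _null_callback instead, so
# callback_step never fired. Keeping a reference to original_func, as the
# diff below does, restores the lookup.
ys = odeint(Dynamics(), y0, t, method='dopri5')
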
@@ -202,6 +202,9 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS):
     # Combine event functions if the output is multivariate.
     event_fn = combine_event_functions(event_fn, t[0], y0)
 
+    # Keep reference to original func as passed in
+    original_func = func
+
     # Normalise to tensor (non-tupled) input
     shapes = None
     is_tuple = not isinstance(y0, torch.Tensor)
@@ -302,40 +305,41 @@ def _norm(tensor):
     # ~Backward compatibility
 
     # Add perturb argument to func.
-    wrapped_func = _PerturbFunc(func)
+    func = _PerturbFunc(func)
 
     # Add callbacks to wrapped_func
     callback_names = set()
     for callback_name in _all_callback_names:
         try:
-            callback = getattr(func, callback_name)
+            callback = getattr(original_func, callback_name)
         except AttributeError:
-            setattr(wrapped_func, callback_name, _null_callback)
+            setattr(func, callback_name, _null_callback)
         else:
-            callback_names.add(callback_name)
-            # At the moment all callbacks have the arguments (t0, y0, dt).
-            # These will need adjusting on a per-callback basis if that changes in the future.
-            if is_tuple:
-                def callback(t0, y0, dt, _callback=callback):
-                    y0 = _flat_to_shape(y0, (), shapes)
-                    return _callback(t0, y0, dt)
-            if t_is_reversed:
-                def callback(t0, y0, dt, _callback=callback):
-                    return _callback(-t0, y0, dt)
-            setattr(wrapped_func, callback_name, callback)
+            if callback is not _null_callback:
+                callback_names.add(callback_name)
+                # At the moment all callbacks have the arguments (t0, y0, dt).
+                # These will need adjusting on a per-callback basis if that changes in the future.
+                if is_tuple:
+                    def callback(t0, y0, dt, _callback=callback):
+                        y0 = _flat_to_shape(y0, (), shapes)
+                        return _callback(t0, y0, dt)
+                if t_is_reversed:
+                    def callback(t0, y0, dt, _callback=callback):
+                        return _callback(-t0, y0, dt)
+            setattr(func, callback_name, callback)
     for callback_name in _all_adjoint_callback_names:
         try:
-            callback = getattr(func, callback_name)
+            callback = getattr(original_func, callback_name)
         except AttributeError:
             pass
         else:
-            setattr(wrapped_func, callback_name, callback)
+            setattr(func, callback_name, callback)
 
     invalid_callbacks = callback_names - SOLVERS[method].valid_callbacks()
     if len(invalid_callbacks) > 0:
         warnings.warn("Solver '{}' does not support callbacks {}".format(method, invalid_callbacks))
 
-    return shapes, wrapped_func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
+    return shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed
 
 
 class _StitchGradient(torch.autograd.Function):
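
The second fix is the "callback is not _null_callback" guard. As the commit message and diff suggest, odeint_adjoint's internals leave _null_callback placeholders on the function, and before this commit those placeholders were registered in callback_names, so the backward solve warned about callbacks the user never defined. A rough sketch of the user-visible symptom, assuming a hypothetical Linear module; after the fix this runs without the spurious warning:

import torch
from torchdiffeq import odeint_adjoint

class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(-0.5))

    def forward(self, t, y):
        return self.a * y

func = Linear()
y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 1.0, 5)

# The backward pass re-enters odeint with an internally wrapped func.
# Internally added _null_callbacks are now skipped when collecting
# callback_names, so no "Solver does not support callbacks" warning is
# emitted for callbacks the user never defined.
ys = odeint_adjoint(func, y0, t)
ys[-1].sum().backward()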
