Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing bugs 323 and 324 #327

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions diffrax/global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,15 @@ def _linear_interpolation(
if replace_nans_at_start is None:
y0 = ys[0]
else:
y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)
y0 = jnp.broadcast_to(
jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape
)
ys = ys.at[0].set(y0)

_, (next_ts, next_ys) = lax.scan(
_interpolation_reverse, (ts[-1], ys[-1]), (ts, ys), reverse=True
)

if fill_forward_nans_at_end:
next_ys = fill_forward(next_ys)
_, ys = lax.scan(
Expand Down Expand Up @@ -657,18 +661,26 @@ def _backward_hermite_coefficients(
]:
ts = left_broadcast_to(ts, ys.shape)

if replace_nans_at_start is None:
y0 = ys[0]
else:
y0 = jnp.broadcast_to(
jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape
)
ys = ys.at[0].set(y0)

_, (next_ts, next_ys) = lax.scan(
_interpolation_reverse, (ts[-1], ys[-1]), (ts[1:], ys[1:]), reverse=True
_interpolation_reverse, (ts[-1], ys[-1]), (ts, ys), reverse=True
)

if fill_forward_nans_at_end:
next_ys = fill_forward(next_ys)

next_ts = next_ts[1:]
next_ys = next_ys[1:]

t0 = ts[0]
if replace_nans_at_start is None:
y0 = ys[0]
else:
y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)

if deriv0 is None:
deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
else:
Expand Down
4 changes: 3 additions & 1 deletion diffrax/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def fill_forward(
if replace_nans_at_start is None:
y0 = ys[0]
else:
y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)
y0 = jnp.broadcast_to(
jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape
)
_, ys = lax.scan(_fill_forward, y0, ys)
return ys

Expand Down
2 changes: 1 addition & 1 deletion diffrax/step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def adapt_step_size(
# ε_n = atol + norm(y) * rtol with y on the nth step
# r_n = norm(y_error) with y_error on the nth step
# δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth
# step and y on the mth step
# step and y on the mth step
# β_1 = pcoeff + icoeff + dcoeff
# β_2 = -(pcoeff + 2 * dcoeff)
# β_3 = dcoeff
Expand Down
46 changes: 46 additions & 0 deletions test/test_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,49 @@ def test_dense_interpolation_vmap(solver, getkey):
diffrax.Ralston: 1e-3,
}.get(type(solver), 1e-6)
assert shaped_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol)


@pytest.mark.parametrize("mode", ["linear", "rectilinear", "cubic"])
@pytest.mark.parametrize(
"nans, expected",
[
(
jnp.array([0, 3, 4, 6, 9]),
jnp.array([20.0, 1.0, 2.0, 23.0, 24.0, 5.0, 26.0, 7.0, 8.0, 29.0]),
),
(jnp.arange(0, 10, 1), jnp.arange(20, 30, 1)),
],
)
@pytest.mark.parametrize("init_nan", [True, False])
def test_replace_nans_at_start(mode, nans, expected, init_nan):
ts = jnp.linspace(0, 1, 15)
if init_nan:
ys = jnp.full((15, 10), jnp.nan)
else:
ys = jrandom.normal(jrandom.PRNGKey(0), (15, 10))
ys = ys.at[0, :].set(jnp.arange(0, 10, 1))
nan_ys = ys.at[0, nans].set(jnp.nan)
replace_nans_at_start = jnp.arange(20, 30, 1)

if mode == "cubic":
coeffs = diffrax.backward_hermite_coefficients(
ts,
nan_ys,
replace_nans_at_start=replace_nans_at_start,
fill_forward_nans_at_end=True,
)
interp = diffrax.CubicInterpolation(ts, coeffs)
elif mode == "linear":
interp = diffrax.linear_interpolation(
ts,
nan_ys,
replace_nans_at_start=replace_nans_at_start,
fill_forward_nans_at_end=True,
)
interp = diffrax.LinearInterpolation(ts, interp)
elif mode == "rectilinear":
ts, coeffs = diffrax.rectilinear_interpolation(
ts, nan_ys, replace_nans_at_start=replace_nans_at_start
)
interp = diffrax.LinearInterpolation(ts, coeffs)
assert shaped_allclose(interp.evaluate(0), expected)
Loading