Skip to content

Commit

Permalink
Merge pull request jax-ml#11155 from jakevdp:x64-ode-test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 455714054
  • Loading branch information
jax authors committed Jun 17, 2022
2 parents 5318df6 + 5e97011 commit 5236140
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
26 changes: 15 additions & 11 deletions jax/experimental/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from jax import core
from jax import custom_derivatives
from jax import lax
from jax._src.numpy.util import _promote_dtypes_inexact
from jax._src.util import safe_map, safe_zip
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves, tree_map
Expand All @@ -59,10 +60,11 @@ def interp_fit_dopri(y0, y1, k, dt):
6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
-2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
-1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2], dtype=y0.dtype)
y_mid = y0 + dt * jnp.dot(dps_c_mid, k)
y_mid = y0 + dt.astype(y0.dtype) * jnp.dot(dps_c_mid, k)
return jnp.asarray(fit_4th_order_polynomial(y0, y1, y_mid, k[0], k[-1], dt))

def fit_4th_order_polynomial(y0, y1, y_mid, dy0, dy1, dt):
dt = dt.astype(y0.dtype)
a = -2.*dt*dy0 + 2.*dt*dy1 - 8.*y0 - 8.*y1 + 16.*y_mid
b = 5.*dt*dy0 - 3.*dt*dy1 + 18.*y0 + 14.*y1 - 32.*y_mid
c = -4.*dt*dy0 + dt*dy1 - 11.*y0 - 5.*y1 + 16.*y_mid
Expand All @@ -74,15 +76,17 @@ def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
# Algorithm from:
# E. Hairer, S. P. Norsett G. Wanner,
# Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
y0, f0 = _promote_dtypes_inexact(y0, f0)
dtype = y0.dtype

scale = atol + jnp.abs(y0) * rtol
d0 = jnp.linalg.norm(y0 / scale)
d1 = jnp.linalg.norm(f0 / scale)
d0 = jnp.linalg.norm(y0 / scale.astype(dtype))
d1 = jnp.linalg.norm(f0 / scale.astype(dtype))

h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)

y1 = y0 + h0 * f0
y1 = y0 + h0.astype(dtype) * f0
f1 = fun(y1, t0 + h0)
d2 = jnp.linalg.norm((f1 - f0) / scale) / h0
d2 = jnp.linalg.norm((f1 - f0) / scale.astype(dtype)) / h0

h1 = jnp.where((d1 <= 1e-15) & (d2 <= 1e-15),
jnp.maximum(1e-6, h0 * 1e-3),
Expand Down Expand Up @@ -110,15 +114,15 @@ def runge_kutta_step(func, y0, f0, t0, dt):

def body_fun(i, k):
ti = t0 + dt * alpha[i-1]
yi = y0 + dt * jnp.dot(beta[i-1, :], k)
yi = y0 + dt.astype(f0.dtype) * jnp.dot(beta[i-1, :], k)
ft = func(yi, ti)
return k.at[i, :].set(ft)

k = jnp.zeros((7, f0.shape[0]), f0.dtype).at[0, :].set(f0)
k = lax.fori_loop(1, 7, body_fun, k)

y1 = dt * jnp.dot(c_sol, k) + y0
y1_error = dt * jnp.dot(c_error, k)
y1 = dt.astype(f0.dtype) * jnp.dot(c_sol, k) + y0
y1_error = dt.astype(f0.dtype) * jnp.dot(c_error, k)
f1 = k[-1]
return y1, f1, y1_error, k

Expand All @@ -130,7 +134,7 @@ def abs2(x):

def mean_error_ratio(error_estimate, rtol, atol, y0, y1):
err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
err_ratio = error_estimate / err_tol
err_ratio = error_estimate / err_tol.astype(error_estimate.dtype)
return jnp.sqrt(jnp.mean(abs2(err_ratio)))

def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
Expand Down Expand Up @@ -205,7 +209,7 @@ def body_fun(state):
_, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
_, _, t, _, last_t, interp_coeff = carry
relative_output_time = (target_t - last_t) / (t - last_t)
y_target = jnp.polyval(interp_coeff, relative_output_time)
y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype))
return carry, y_target

f0 = func_(y0, ts[0])
Expand Down
7 changes: 5 additions & 2 deletions tests/ode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_complex_odeint(self):
# https://github.com/google/jax/issues/8757

def dy_dt(y, t, alpha):
return alpha * y * jnp.exp(-t)
return alpha * y * jnp.exp(-t).astype(y.dtype)

def f(y0, ts, alpha):
return odeint(dy_dt, y0, ts, alpha).real
Expand All @@ -248,7 +248,10 @@ def f(y0, ts, alpha):
ts = jnp.linspace(0., 1., 11)
tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3

jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2, atol=tol, rtol=tol)
# During the backward pass, this ravels all parameters into a single array
# such that dtype promotion is unavoidable.
with jax.numpy_dtype_promotion('standard'):
jtu.check_grads(f, (y0, ts, alpha), modes=["rev"], order=2, atol=tol, rtol=tol)


if __name__ == '__main__':
Expand Down

0 comments on commit 5236140

Please sign in to comment.