Skip to content

Commit

Permalink
add ODE benchmarks (pymc-devs#3730)
Browse files Browse the repository at this point in the history
* add ODE benchmarks

* set cores=2 and return ess per second

* change timeout to 10 minutes
  • Loading branch information
Dpananos authored and ColCarroll committed Dec 17, 2019
1 parent 5d60f8c commit 436a6c4
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions benchmarks/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,58 @@ def track_glm_hierarchical_ess(self, step):


CompareMetropolisNUTSSuite.track_glm_hierarchical_ess.unit = 'Effective samples per second'


class DifferentialEquationSuite:
"""Implements ode examples to keep up with benchmarking them."""

timeout = 600
timer = timeit.default_timer

def track_1var_2par_ode_ess(self):
def freefall(y, t, p):
return 2.0 * p[1] - p[0] * y[0]

# Times for observation
times = np.arange(0, 10, 0.5)
y = np.array([
-2.01,
9.49,
15.58,
16.57,
27.58,
32.26,
35.13,
38.07,
37.36,
38.83,
44.86,
43.58,
44.59,
42.75,
46.9,
49.32,
44.06,
49.86,
46.48,
48.18
]).reshape(-1, 1)

ode_model = pm.ode.DifferentialEquation(func=freefall, times=times, n_states=1, n_theta=2, t0=0)
with pm.Model() as model:
# Specify prior distributions for some of our model parameters
sigma = pm.HalfCauchy("sigma", 1)
gamma = pm.Lognormal("gamma", 0, 1)
# If we know one of the parameter values, we can simply pass the value.
ode_solution = ode_model(y0=[0], theta=[gamma, 9.8])
# The ode_solution has a shape of (n_times, n_states)
Y = pm.Normal("Y", mu=ode_solution, sd=sigma, observed=y)

t0 = time.time()
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
tot = time.time() - t0
ess = pm.ess(trace)
return np.mean([ess.sigma, ess.gamma]) / tot


DifferentialEquationSuite.track_1var_2par_ode_ess.unit = 'Effective samples per second'

0 comments on commit 436a6c4

Please sign in to comment.