Skip to content

Commit

Permalink
pybamm-team#1082 make event calculations faster
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 16, 2021
1 parent e71a4ee commit 15411c9
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 135 deletions.
6 changes: 3 additions & 3 deletions examples/scripts/DFN.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
pybamm.set_logging_level("INFO")

# load model
model = pybamm.lithium_ion.DFN() # {"operating mode": "power"})
model = pybamm.lithium_ion.SPM() # {"operating mode": "power"})
# create geometry
geometry = model.default_geometry

# load parameter values and process model and geometry
param = model.default_parameter_values
# param.update({"Power function [W]": 3.5}, check_already_exists=False)
param["Current function [A]"] /= 10
# param["Current function [A]"] /= 10
param.process_geometry(geometry)
param.process_model(model)

Expand All @@ -29,7 +29,7 @@
disc.process_model(model)

# solve model
t_eval = np.linspace(0, 5000 * 10, 100)
t_eval = np.linspace(0, 4000, 100)
solver = pybamm.CasadiSolver(mode="safe", atol=1e-6, rtol=1e-6)
solution1 = solver.solve(model, t_eval)
solver = pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6)
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/experimental_protocols/cccv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
"Discharge at C/5 for 10 hours or until 3.3 V",
"Rest for 1 hour",
"Charge at 1 A until 4.1 V",
# "Hold at 4.1 V until 50 mA",
# "Rest for 1 hour",
"Hold at 4.1 V until 50 mA",
"Rest for 1 hour",
),
]
# * 3
* 3
)
model = pybamm.lithium_ion.SPM()
sim = pybamm.Simulation(
Expand Down
1 change: 1 addition & 0 deletions pybamm/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class EventType(Enum):
TERMINATION = 0
DISCONTINUITY = 1
INTERPOLANT_EXTRAPOLATION = 2
SWITCH = 3


class Event:
Expand Down
17 changes: 17 additions & 0 deletions pybamm/models/full_battery_models/base_battery_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,23 @@ def set_voltage_variables(self):
)
)

# Cut-off open-circuit voltage (for event switch with casadi 'fast with events'
# mode)
self.events.append(
pybamm.Event(
"Minimum open circuit voltage",
ocv - self.param.voltage_low_cut,
pybamm.EventType.SWITCH,
)
)
self.events.append(
pybamm.Event(
"Maximum open circuit voltage",
ocv - self.param.voltage_high_cut,
pybamm.EventType.SWITCH,
)
)

# Power
I_dim = self.variables["Current [A]"]
self.variables.update({"Terminal power [W]": I_dim * V_dim})
Expand Down
38 changes: 19 additions & 19 deletions pybamm/models/submodels/particle/base_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,22 @@ def _get_standard_flux_variables(self, N_s, N_s_xav):

return variables

def set_events(self, variables):
c_s_surf = variables[self.domain + " particle surface concentration"]
tol = 1e-4

self.events.append(
pybamm.Event(
"Minumum " + self.domain.lower() + " particle surface concentration",
pybamm.min(c_s_surf) - tol,
pybamm.EventType.TERMINATION,
)
)

self.events.append(
pybamm.Event(
"Maximum " + self.domain.lower() + " particle surface concentration",
(1 - tol) - pybamm.max(c_s_surf),
pybamm.EventType.TERMINATION,
)
)
# def set_events(self, variables):
# c_s_surf = variables[self.domain + " particle surface concentration"]
# tol = 1e-4

# self.events.append(
# pybamm.Event(
# "Minumum " + self.domain.lower() + " particle surface concentration",
# pybamm.min(c_s_surf) - tol,
# pybamm.EventType.TERMINATION,
# )
# )

# self.events.append(
# pybamm.Event(
# "Maximum " + self.domain.lower() + " particle surface concentration",
# (1 - tol) - pybamm.max(c_s_surf),
# pybamm.EventType.TERMINATION,
# )
# )
4 changes: 2 additions & 2 deletions pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ def set_up_experiment(self, model, experiment):
if op[1] in ["A", "C"]:
# Current control: max simulation time: 3 * max simulation time
# based on C-rate
dt = 3 / abs(Crate)
dt = 3 / abs(Crate) * 3600 # seconds
else:
# max simulation time: 1 week
dt = 7 * 24 * 3600
dt = 7 * 24 * 3600 # seconds
self._experiment_times.append(dt)

# Set up model for experiment
Expand Down
30 changes: 14 additions & 16 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,26 +374,24 @@ def report(string):
if event.event_type == pybamm.EventType.TERMINATION:

terminate_events_eval.append(event_eval)
elif event.event_type == pybamm.EventType.SWITCH:
# Save some events to casadi_terminate_events for the 'fast with
# events' mode of the casadi solver
# see #1082
k = 20
if "voltage" in event.name.lower():
init_sign = float(
np.sign(event_eval(0, model.y0, inputs_stacked))
)
# We create a sigmoid for each event which will multiply the
# rhs. Doing * 2 - 1 ensures that when the event is crossed,
# the sigmoid is zero. Hence the rhs is zero and the solution
# stays constant for the rest of the simulation period
# We can then cut off the part after the event was crossed
event_sigmoid = (
pybamm.sigmoid(0, init_sign * event.expression, k) * 2 - 1
)
event_casadi = process(
event_sigmoid, "event", use_jacobian=False
)[0]
casadi_terminate_events.append(event_casadi)
init_sign = float(np.sign(event_eval(0, model.y0, inputs_stacked)))
# We create a sigmoid for each event which will multiply the
# rhs. Doing * 2 - 1 ensures that when the event is crossed,
# the sigmoid is zero. Hence the rhs is zero and the solution
# stays constant for the rest of the simulation period
# We can then cut off the part after the event was crossed
event_sigmoid = (
pybamm.sigmoid(0, init_sign * event.expression, k) * 2 - 1
)
event_casadi = process(event_sigmoid, "event", use_jacobian=False)[
0
]
casadi_terminate_events.append(event_casadi)
elif event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
interpolant_extrapolation_events_eval.append(event_eval)

Expand Down
108 changes: 74 additions & 34 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pybamm
import numpy as np
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from scipy.optimize import root_scalar


class CasadiSolver(pybamm.BaseSolver):
Expand Down Expand Up @@ -308,7 +308,7 @@ def _solve_for_event(self, coarse_solution, init_event_signs):
"these bounds.".format(extrap_event_names)
)

def find_t_event(sol, interpolant_kind):
def find_t_event(sol, typ):

# Check most recent y to see if any events have been crossed
if model.terminate_events_eval:
Expand All @@ -332,56 +332,94 @@ def find_t_event(sol, interpolant_kind):
event_ind = np.where(new_event_signs != init_event_signs)[0]
active_events = [model.terminate_events_eval[i] for i in event_ind]

# create interpolant to evaluate y in the current integration
# window
y_sol = interp1d(sol.t, sol.y, kind=interpolant_kind)

# loop over events to compute the time at which they were triggered
t_events = [None] * len(active_events)
for i, event in enumerate(active_events):
# Implement our own bisection algorithm for speed
# This is used to find the time range in which the event is triggered
# Evaluations of the "event" function are (relatively) expensive
init_event_sign = init_event_signs[event_ind[i]][0]

def event_fun(t):
# We take away 1e-5 to deal with the case where the event sits
# exactly on zero, as can happen when the event switch is used
# (fast with events mode)
return init_event_sign * event(t, y_sol(t), inputs) - 1e-5
f_eval = {}

if np.isnan(event_fun(sol.t[-1])[0]):
# bracketed search fails if f(a) or f(b) is NaN, so we
# need to find the times for which we can evaluate the event
times = [t for t in sol.t if event_fun(t)[0] == event_fun(t)[0]]
else:
times = sol.t
# skip if sign hasn't changed
if np.sign(event_fun(times[0])) != np.sign(event_fun(times[-1])):
t_events[i] = brentq(lambda t: event_fun(t), times[0], times[-1])
else:
t_events[i] = np.nan
def f(idx):
idx = int(idx)
try:
return f_eval[idx]
except KeyError:
# We take away 1e-5 to deal with the case where the event sits
# exactly on zero, as can happen when the event switch is used
# (fast with events mode)
f_eval[idx] = (
init_event_sign * event(sol.t[idx], sol.y[:, idx], inputs)
- 1e-5
)
return f_eval[idx]

def integer_bisect():
a_n = 0
b_n = len(sol.t) - 1
for _ in range(len(sol.t)):
if a_n + 1 == b_n:
assert f(a_n) > 0 and f(b_n) < 0
return (a_n, b_n)
m_n = (a_n + b_n) // 2
f_m_n = f(m_n)
if np.isnan(f_m_n):
a_n = a_n
b_n = m_n
elif f_m_n < 0:
a_n = a_n
b_n = m_n
elif f_m_n > 0:
a_n = m_n
b_n = b_n

event_idx_lower, event_idx_upper = integer_bisect()
if typ == "window":
return (event_idx_lower, event_idx_upper), None
elif typ == "exact":
# Linear interpolation between the two indices to find the root time
# We could do cubic interpolation here instead but it would be
# slower
t_lower = sol.t[event_idx_lower]
t_upper = sol.t[event_idx_upper]
event_lower = abs(f(event_idx_lower))
event_upper = abs(f(event_idx_upper))

t_events[i] = (event_lower * t_upper + event_upper * t_lower) / (
event_lower + event_upper
)

# t_event is the earliest event triggered
t_event = np.nanmin(t_events)
y_event = y_sol(t_event)
# t_event is the earliest event triggered
t_event = np.nanmin(t_events)
# create interpolant to evaluate y in the current integration
# window
y_sol = interp1d(sol.t, sol.y, kind="linear")
y_event = y_sol(t_event)

return t_event, y_event
return t_event, y_event

# Find the interval in which the event was triggered
t_event_coarse, _ = find_t_event(coarse_solution, "linear")
idx_window_event, _ = find_t_event(coarse_solution, "window")

# Return the existing solution if no events have been triggered
if t_event_coarse is None:
if idx_window_event is None:
# Flag "final time" for termination
coarse_solution.termination = "final time"
return coarse_solution

# If events have been triggered, we solve for a dense window in the interval
# where the event was triggered, then find the precise location of the event
event_idx = np.where(coarse_solution.t > t_event_coarse)[0][0] - 1
t_window_event = coarse_solution.t[event_idx : event_idx + 2]
event_idx = idx_window_event[0]

# Solve again with a more dense t_window, starting from the start of the
# Solve again with a more dense idx_window, starting from the start of the
# window where the event was triggered
t_window_event_dense = np.linspace(t_window_event[-2], t_window_event[-1], 100)
t_window_event_dense = np.linspace(
coarse_solution.t[idx_window_event[0]],
coarse_solution.t[idx_window_event[1]],
100,
)
if self.mode in ["safe", "fast with events"]:
self.create_integrator(model, inputs, t_window_event_dense)

Expand All @@ -391,7 +429,7 @@ def event_fun(t):
)

# Find the exact time at which the event was triggered
t_event, y_event = find_t_event(dense_step_sol, "cubic")
t_event, y_event = find_t_event(dense_step_sol, "exact")

# Return solution truncated at the first coarse event time
# Also assign t_event
Expand All @@ -406,7 +444,9 @@ def event_fun(t):
y_event[:, np.newaxis],
"event",
)
solution.integration_time = coarse_solution.integration_time
solution.integration_time = (
coarse_solution.integration_time + dense_step_sol.integration_time
)

# Flag "True" for termination
return solution
Expand Down
58 changes: 0 additions & 58 deletions test.py

This file was deleted.

0 comments on commit 15411c9

Please sign in to comment.