Skip to content

Commit

Permalink
3937 events with idaklu output variables - solver edit (pybamm-team#4300
Browse files Browse the repository at this point in the history
)

* Add test that fails

* Remove duplicate attribute assignment

* IDAKLU solver returns additional y_term variable containing the final state vector slice

* Add final state vector slice to python-idaklu, tests pass

* Reduce memory load so y_term is only filled if output variables are specified.
Otherwise empty array.

* Edit changelog

---------

Co-authored-by: Martin Robinson <[email protected]>
  • Loading branch information
pipliggins and martinjrobins authored Jul 30, 2024
1 parent 5408ce3 commit f255c38
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

- Added additional user-configurable options to the (`IDAKLUSolver`) and adjusted the default values to improve performance. ([#4282](https://github.com/pybamm-team/PyBaMM/pull/4282))

## Bug Fixes

- Fixed bug where IDAKLU solver failed when `output variables` were specified and an event triggered. ([#4300](https://github.com/pybamm-team/PyBaMM/pull/4300))

# [v24.5](https://github.com/pybamm-team/PyBaMM/tree/v24.5) - 2024-07-26

## Features
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,6 @@ PYBIND11_MODULE(idaklu, m)
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("y_term", &Solution::y_term)
.def_readwrite("flag", &Solution::flag);
}
23 changes: 21 additions & 2 deletions pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(

// set return vectors
int length_of_return_vector = 0;
int length_of_final_sv_slice = 0;
size_t max_res_size = 0; // maximum result size (for common result buffer)
size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0;
if (functions->var_fcns.size() > 0) {
Expand All @@ -414,6 +415,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
for (auto& dvar_fcn : functions->dvar_dp_fcns) {
max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn->out_shape(0)));
}
length_of_final_sv_slice = number_of_states;
}
} else {
// Return full y state-vector
Expand All @@ -425,6 +427,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
realtype *yS_return = new realtype[number_of_parameters *
number_of_timesteps *
length_of_return_vector];
realtype *yterm_return = new realtype[length_of_final_sv_slice];

res.resize(max_res_size);
res_dvar_dy.resize(max_res_dvar_dy);
Expand All @@ -451,6 +454,13 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
delete[] vect;
}
);
py::capsule free_yterm_when_done(
yterm_return,
[](void *f) {
realtype *vect = reinterpret_cast<realtype *>(f);
delete[] vect;
}
);

// Initial state (t_i=0)
int t_i = 0;
Expand Down Expand Up @@ -518,6 +528,10 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
t_i += 1;

if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN) {
if (functions->var_fcns.size() > 0) {
// store final state slice if outout variables are specified
yterm_return = yval;
}
break;
}
}
Expand All @@ -532,7 +546,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
&y_return[0],
free_y_when_done
);
// Note: Ordering of vector is differnet if computing variables vs returning
// Note: Ordering of vector is different if computing variables vs returning
// the complete state vector
np_array yS_ret;
if (functions->var_fcns.size() > 0) {
Expand All @@ -556,8 +570,13 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
free_yS_when_done
);
}
np_array y_term = np_array(
length_of_final_sv_slice,
&yterm_return[0],
free_yterm_when_done
);

Solution sol(retval, t_ret, y_ret, yS_ret);
Solution sol(retval, t_ret, y_ret, yS_ret, y_term);

if (solver_opts.print_stats) {
long nsteps, nrevals, nlinsetups, netfails;
Expand Down
5 changes: 3 additions & 2 deletions pybamm/solvers/c_solvers/idaklu/Solution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ class Solution
/**
* @brief Constructor
*/
Solution(int retval, np_array t_np, np_array y_np, np_array yS_np)
: flag(retval), t(t_np), y(y_np), yS(yS_np)
Solution(int retval, np_array t_np, np_array y_np, np_array yS_np, np_array y_term_np)
: flag(retval), t(t_np), y(y_np), yS(yS_np), y_term(y_term_np)
{
}

int flag;
np_array t;
np_array y;
np_array yS;
np_array y_term;
};

#endif // PYBAMM_IDAKLU_COMMON_HPP
3 changes: 2 additions & 1 deletion pybamm/solvers/c_solvers/idaklu/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,9 @@ Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np,
std::vector<ptrdiff_t> {number_of_parameters, number_of_timesteps, number_of_states},
&yS_return[0]
);
np_array yterm_ret = np_array(0);

Solution sol(retval, t_ret, y_ret, yS_ret);
Solution sol(retval, t_ret, y_ret, yS_ret, yterm_ret);

return sol;
}
5 changes: 4 additions & 1 deletion pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ def _make_iree_function(self, fcn, *args, sparse_index=False):

def fcn(*args):
return fcn_inner(*args)[coo.row, coo.col]

elif coo.nnz != iree_fcn.numel:
iree_fcn.nnz = iree_fcn.numel
iree_fcn.col = list(range(iree_fcn.numel))
Expand Down Expand Up @@ -969,8 +970,10 @@ def _integrate(self, model, t_eval, inputs_dict=None):
if self.output_variables:
# Substitute empty vectors for state vector 'y'
y_out = np.zeros((number_of_timesteps * number_of_states, 0))
y_event = sol.y_term
else:
y_out = sol.y.reshape((number_of_timesteps, number_of_states))
y_event = y_out[-1]

# return sensitivity solution, we need to flatten yS to
# (#timesteps * #states (where t is changing the quickest),)
Expand Down Expand Up @@ -1000,7 +1003,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
model,
inputs_dict,
np.array([t[-1]]),
np.transpose(y_out[-1])[:, np.newaxis],
np.transpose(y_event)[:, np.newaxis],
termination,
sensitivities=yS_out,
)
Expand Down
4 changes: 0 additions & 4 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ def __init__(

self.sensitivities = sensitivities

self._t_event = t_event
self._y_event = y_event
self._termination = termination

# Check no ys are too large
if check_solution:
self.check_ys_are_not_too_large()
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,36 @@ def test_bad_jax_evaluator_output_variables(self):
output_variables=["Terminal voltage [V]"],
)

def test_with_output_variables_and_event_termination(self):
model = pybamm.lithium_ion.DFN()
parameter_values = pybamm.ParameterValues("Chen2020")

sim = pybamm.Simulation(
model,
parameter_values=parameter_values,
solver=pybamm.IDAKLUSolver(output_variables=["Terminal voltage [V]"]),
)
sol = sim.solve(np.linspace(0, 3600, 1000))
self.assertEqual(sol.termination, "event: Minimum voltage [V]")

# create an event that doesn't require the state vector
eps_p = model.variables["Positive electrode porosity"]
model.events.append(
pybamm.Event(
"Zero positive electrode porosity cut-off",
pybamm.min(eps_p),
pybamm.EventType.TERMINATION,
)
)

sim3 = pybamm.Simulation(
model,
parameter_values=parameter_values,
solver=pybamm.IDAKLUSolver(output_variables=["Terminal voltage [V]"]),
)
sol3 = sim3.solve(np.linspace(0, 3600, 1000))
self.assertEqual(sol3.termination, "event: Minimum voltage [V]")


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit f255c38

Please sign in to comment.