From 174bec41f3152b95065e8ca598cc751ca524d57c Mon Sep 17 00:00:00 2001 From: John Brittain Date: Mon, 18 Sep 2023 09:31:38 +0100 Subject: [PATCH] Streamline base_solver process() --- pybamm/solvers/base_solver.py | 47 +++++++++++++++++------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index c0ee4c32fc..0a412f6a06 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1421,11 +1421,7 @@ def _set_up_model_inputs(self, model, inputs): def process( - symbol, - name, - vars_for_processing, - use_jacobian=None, - return_jacp_stacked=None + symbol, name, vars_for_processing, use_jacobian=None, return_jacp_stacked=None ): """ Parameters @@ -1615,17 +1611,28 @@ def jacp(*args, **kwargs): "CasADi" ) ) - # WARNING, jacp for convert_to_format=casadi does not return a dict - # instead it returns multiple return values, one for each param - # TODO: would it be faster to do the jacobian wrt pS_casadi_stacked? - jacp = casadi.Function( - name + "_jacp", - [t_casadi, y_and_S, p_casadi_stacked], - [ - casadi.densify(casadi.jacobian(casadi_expression, p_casadi[pname])) - for pname in model.calculate_sensitivities - ], - ) + # Compute derivate wrt p-stacked (can be passed to solver to + # compute sensitivities online) + if return_jacp_stacked: + jacp = casadi.Function( + f"d{name}_dp", + [t_casadi, y_casadi, p_casadi_stacked], + [casadi.jacobian(casadi_expression, p_casadi_stacked)], + ) + else: + # WARNING, jacp for convert_to_format=casadi does not return a dict + # instead it returns multiple return values, one for each param + # TODO: would it be faster to do the jacobian wrt pS_casadi_stacked? + jacp = casadi.Function( + name + "_jacp", + [t_casadi, y_and_S, p_casadi_stacked], + [ + casadi.densify( + casadi.jacobian(casadi_expression, p_casadi[pname]) + ) + for pname in model.calculate_sensitivities + ], + ) if use_jacobian: report(f"Calculating jacobian for {name} using CasADi") @@ -1648,14 +1655,6 @@ def jacp(*args, **kwargs): [t_casadi, y_and_S, p_casadi_stacked, v], [jac_action_casadi], ) - # Compute derivate wrt p-stacked (can be passed to solver to - # compute sensitivities online) - if return_jacp_stacked: - jacp = casadi.Function( - f"d{name}_dp", - [t_casadi, y_casadi, p_casadi_stacked], - [casadi.jacobian(casadi_expression, p_casadi_stacked)], - ) else: jac = None jac_action = None