Skip to content

Commit

Permalink
Streamline base_solver process()
Browse files Browse the repository at this point in the history
  • Loading branch information
jsbrittain committed Sep 18, 2023
1 parent f381e31 commit 174bec4
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,11 +1421,7 @@ def _set_up_model_inputs(self, model, inputs):


def process(
symbol,
name,
vars_for_processing,
use_jacobian=None,
return_jacp_stacked=None
symbol, name, vars_for_processing, use_jacobian=None, return_jacp_stacked=None
):
"""
Parameters
Expand Down Expand Up @@ -1615,17 +1611,28 @@ def jacp(*args, **kwargs):
"CasADi"
)
)
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_and_S, p_casadi_stacked],
[
casadi.densify(casadi.jacobian(casadi_expression, p_casadi[pname]))
for pname in model.calculate_sensitivities
],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
if return_jacp_stacked:
jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_and_S, p_casadi_stacked],
[
casadi.densify(
casadi.jacobian(casadi_expression, p_casadi[pname])
)
for pname in model.calculate_sensitivities
],
)

if use_jacobian:
report(f"Calculating jacobian for {name} using CasADi")
Expand All @@ -1648,14 +1655,6 @@ def jacp(*args, **kwargs):
[t_casadi, y_and_S, p_casadi_stacked, v],
[jac_action_casadi],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
if return_jacp_stacked:
jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
jac = None
jac_action = None
Expand Down

0 comments on commit 174bec4

Please sign in to comment.