Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Allow for Alternative and Custom ODE Solvers. #748

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
ENH: Allow for Alternative and Custom ODE Solvers.
  • Loading branch information
phmbressan committed Dec 7, 2024
commit 26ceab9dd04f5676dd357e134e9dc88bd96c4c5c
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@
"pytest",
"pytz",
"quantile",
"Radau",
"Rdot",
"referece",
"relativetoground",
Expand Down
63 changes: 52 additions & 11 deletions rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import simplekml
from scipy import integrate
from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau

from ..mathutils.function import Function, funcify_method
from ..mathutils.vector_matrix import Matrix, Vector
Expand All @@ -24,8 +24,19 @@
quaternions_to_spin,
)

ODE_SOLVER_MAP = {
'RK23': RK23,
'RK45': RK45,
'DOP853': DOP853,
'Radau': Radau,
'BDF': BDF,
'LSODA': LSODA,
}
phmbressan marked this conversation as resolved.
Show resolved Hide resolved

class Flight: # pylint: disable=too-many-public-methods

# pylint: disable=too-many-public-methods
# pylint: disable=too-many-instance-attributes
class Flight:
"""Keeps all flight information and has a method to simulate flight.

Attributes
Expand Down Expand Up @@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
verbose=False,
name="Flight",
equations_of_motion="standard",
ode_solver="LSODA",
):
"""Run a trajectory simulation.

Expand Down Expand Up @@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
more restricted set of equations of motion that only works for
solid propulsion rockets. Such equations were used in RocketPy v0
and are kept here for backwards compatibility.
ode_solver : str, ``scipy.integrate.OdeSolver``, optional
Integration method to use to solve the equations of motion ODE.
Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
'LSODA' from ``scipy.integrate.solve_ivp``.
Default is 'LSODA', which is recommended for most flights.
A custom ``scipy.integrate.OdeSolver`` can be passed as well.
For more information on the integration methods, see the scipy
documentation [1]_.


Returns
-------
None

References
----------
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
phmbressan marked this conversation as resolved.
Show resolved Hide resolved
"""
# Save arguments
self.env = environment
Expand All @@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
self.terminate_on_apogee = terminate_on_apogee
self.name = name
self.equations_of_motion = equations_of_motion
self.ode_solver = ode_solver

# Controller initialization
self.__init_controllers()
Expand Down Expand Up @@ -651,15 +677,16 @@ def __simulate(self, verbose):

# Create solver for this flight phase # TODO: allow different integrators
self.function_evaluations.append(0)
phase.solver = integrate.LSODA(

phase.solver = self._solver(
phase.derivative,
t0=phase.t,
y0=self.y_sol,
t_bound=phase.time_bound,
min_step=self.min_time_step,
max_step=self.max_time_step,
rtol=self.rtol,
atol=self.atol,
max_step=self.max_time_step,
min_step=self.min_time_step,
)

# Initialize phase time nodes
Expand Down Expand Up @@ -691,13 +718,14 @@ def __simulate(self, verbose):
for node_index, node in self.time_iterator(phase.time_nodes):
# Determine time bound for this time node
node.time_bound = phase.time_nodes[node_index + 1].t
# NOTE: Setting the time bound and status for the phase solver,
# and updating its internal state for the next integration step.
phase.solver.t_bound = node.time_bound
phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
if self.__is_lsoda:
phase.solver._lsoda_solver._integrator.rwork[0] = (
phase.solver.t_bound
)
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
phase.solver.status = "running"

# Feed required parachute and discrete controller triggers
Expand Down Expand Up @@ -1185,6 +1213,19 @@ def __init_solver_monitors(self):
self.t = self.solution[-1][0]
self.y_sol = self.solution[-1][1:]

if isinstance(self.ode_solver, OdeSolver):
self._solver = self.ode_solver
else:
try:
self._solver = ODE_SOLVER_MAP[self.ode_solver]
except KeyError as e:
raise ValueError(
f"Invalid ``ode_solver`` input: {self.ode_solver}. "
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
) from e

self.__is_lsoda = hasattr(self._solver, "_lsoda_solver")

phmbressan marked this conversation as resolved.
Show resolved Hide resolved
def __init_equations_of_motion(self):
"""Initialize equations of motion."""
if self.equations_of_motion == "solid_propulsion":
Expand Down