-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
conv_gru and conv_gru+ode training loop
- Loading branch information
Sudesh S Shetye
committed
Dec 4, 2023
1 parent
f6fdd7a
commit 92b79e1
Showing
104 changed files
with
6,631 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# YAML 1.2 | ||
--- | ||
abstract: | | ||
"This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. We also allow terminating an ODE solution based on an event function, with exact gradient computed. | ||
As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU." | ||
authors: | ||
- | ||
family-names: Chen | ||
given-names: "Ricky T. Q." | ||
cff-version: "1.1.0" | ||
date-released: 2021-06-02 | ||
license: MIT | ||
message: "PyTorch Implementation of Differentiable ODE Solvers" | ||
repository-code: "https://github.com/rtqichen/torchdiffeq" | ||
title: torchdiffeq | ||
version: "0.2.2" | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Frequently Asked Questions (FAQ) | ||
|
||
**What are good resources to understand how ODEs can be solved?**<br> | ||
*Solving Ordinary Differential Equations I Nonstiff Problems* by Hairer et al.<br> | ||
[ODE solver selection in MatLab](https://blogs.mathworks.com/loren/2015/09/23/ode-solver-selection-in-matlab/)<br> | ||
|
||
**What are the ODE solvers available in this repo?**<br> | ||
|
||
- Adaptive-step: | ||
- `dopri8` Runge-Kutta 7(8) of Dormand-Prince-Shampine | ||
- `dopri5` Runge-Kutta 4(5) of Dormand-Prince **[default]**. | ||
- `bosh3` Runge-Kutta 2(3) of Bogacki-Shampine | ||
- `adaptive_heun` Runge-Kutta 1(2) | ||
|
||
- Fixed-step: | ||
- `euler` Euler method. | ||
- `midpoint` Midpoint method. | ||
- `rk4` Fourth-order Runge-Kutta with 3/8 rule. | ||
- `explicit_adams` Explicit Adams. | ||
- `implicit_adams` Implicit Adams. | ||
|
||
- `scipy_solver`: Wraps a SciPy solver. | ||
|
||
|
||
**What are `NFE-F` and `NFE-B`?**<br> | ||
Number of function evaluations for forward and backward pass. | ||
|
||
**What are `rtol` and `atol`?**<br> | ||
They refer to relative `rtol` and absolute `atol` error tolerance. | ||
|
||
**What is the role of error tolerance in adaptive solvers?**<br> | ||
The basic idea is each adaptive solver can produce an error estimate of the current step, and if the error is greater than some tolerance, then the step is redone with a smaller step size, and this repeats until the error is smaller than the provided tolerance.<br> | ||
[Error Tolerances for Variable-Step Solvers](https://www.mathworks.com/help/simulink/ug/types-of-solvers.html#f11-44943) | ||
|
||
**How is the error tolerance calculated?**<br> | ||
The error tolerance is [calculated]((https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L74)) as `atol + rtol * norm of current state`, where the norm being used is a mixed L-infinity/RMS norm. | ||
|
||
**Where is the code that computes the error tolerance?**<br> | ||
It is computed [here.](https://github.com/rtqichen/torchdiffeq/blob/c4c9c61c939c630b9b88267aa56ddaaec319cb16/torchdiffeq/_impl/misc.py#L94) | ||
|
||
**How many states must a Neural ODE solver store during a forward pass with the adjoint method?**<br> | ||
The number of states required to be stored in memory during a forward pass is solver dependent. For example, `dopri5` requires 6 intermediate states to be stored. | ||
|
||
**How many function evaluations are there per ODE step on adaptive solvers?**<br> | ||
|
||
- `dopri5`<br> | ||
The `dopri5` ODE solver stores at least 6 evaluations of the ODE, then takes a step using a linear combination of them. The diagram below illustrates it: the evaluations marked with `o` are on the estimated path, the others with `x` are not. The first two are for selecting the initial step size. | ||
|
||
``` | ||
0 1 | 2 3 4 5 6 7 | 8 9 10 12 13 14 | ||
o x | x x x x x o | x x x x x o | ||
``` | ||
|
||
|
||
**How do I obtain evaluations on the estimated path when using an adaptive solver?**<br> | ||
The argument `t` of `odeint` specifies what times should the ODE solver output.<br> | ||
```odeint(func, x0, t=torch.linspace(0, 1, 50))``` | ||
|
||
Note that the ODE solver will always integrate from `min t(0)` to `max t(1)`, and the intermediate values of `t` have no effect on how the ODE the solved. Intermediate values are computed using polynomial interpolation and have very small cost. | ||
|
||
**What non-linearities should I use in my Neural ODE?**<br> | ||
Avoid non-smooth non-linearities such as ReLU and LeakyReLU.<br> | ||
Prefer non-linearities with a theoretically unique adjoint/gradient such as Softplus. | ||
|
||
**Where is backpropagation for the Neural ODE defined?**<br> | ||
It's defined [here](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py) if you use the adjoint method `odeint_adjoint`. | ||
|
||
**What are Tableaus?**<br> | ||
Tableaus are ways to describe coefficients for [RK methods](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods). The particular set of coefficients used on this repo was taken from [here](https://www.ams.org/journals/mcom/1986-46-173/S0025-5718-1986-0815836-3/). | ||
|
||
**How do I install the repo on Windows?**<br> | ||
Try downloading the code directly and just running python setup.py install. | ||
https://stackoverflow.com/questions/52528955/installing-a-python-module-from-github-in-windows-10 | ||
|
||
**What is the most memory-expensive operation during training?**<br> | ||
The most memory-expensive operation is the single [backward call](https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L75) made to the network. | ||
|
||
**My Neural ODE's numerical solution is farther away from the target than the initial value**<br> | ||
Most tricks for initializing residual nets (like zeroing the weights of the last layer) should help for ODEs as well. This will initialize the ODE as an identity. | ||
|
||
|
||
**My Neural ODE takes too long to train**<br> | ||
This might be because you're running on CPU. Being extremely slow on CPU is expected, as training requires evaluating a neural net multiple times. | ||
|
||
|
||
**My Neural ODE produces underflow in dt when using adaptive solvers like `dopri5`**<br> | ||
This is a problem of the ODE becoming stiff, essentially acting too erratic in a region and the step size becomes so close to zero that no progress can be made in the solver. We were able to avoid this with regularization such as weight decay and using "nice" activation functions, but YMMV. Other potential options are just to accept a larger error by increasing `atol`, `rtol`, or by switching to a fixed solver. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Further documentation | ||
|
||
## Solver options | ||
|
||
Adaptive and fixed solvers all support several options. Also shown are their default values. | ||
|
||
**Adaptive solvers (dopri8, dopri5, bosh3, adaptive_heun):**<br> | ||
For these solvers, `rtol` and `atol` correspond to the tolerances for accepting/rejecting an adaptive step. | ||
|
||
- `first_step=None`: What size the first step of the solver should be; by default this is selected empirically. | ||
|
||
- `safety=0.9, ifactor=10.0, dfactor=0.2`: How the next optimal step size is calculated, see E. Hairer, S. P. Norsett G. Wanner, *Solving Ordinary Differential Equations I: Nonstiff Problems*, Sec. II.4. Roughly speaking, `safety` will try to shrink the step size slightly by this amount, `ifactor` is the most that the step size can grow by, and `dfactor` is the most that it can shrink by. | ||
|
||
- `max_num_steps=2 ** 31 - 1`: The maximum number of steps the solver is allowed to take. | ||
|
||
- `dtype=torch.float64`: what dtype to use for timelike quantities. Setting this to `torch.float32` will improve speed but may produce underflow errors more easily. | ||
|
||
- `step_t=None`: Times that a step must me made to. In particular this is useful when `func` has kinks (derivative discontinuities) at these times, as the solver then does not need to (slowly) discover these for itself. If passed this should be a `torch.Tensor`. | ||
|
||
- `jump_t=None`: Times that a step must be made to, and `func` re-evaluated at. In particular this is useful when `func` has discontinuites at these times, as then the solver knows that the final function evaluation of the previous step is not equal to the first function evaluation of this step. (i.e. the FSAL property does not hold at this point.) If passed this should be a `torch.Tensor`. Note that this may not be efficient when using PyTorch 1.6.0 or earlier. | ||
|
||
- `norm`: What norm to compute the accept/reject criterion with respect to. Given tensor input, this defaults to an RMS norm. Given tupled input, this defaults to computing an RMS norm over each tensor, and then taking a max over the tuple, producing a mixed L-infinity/RMS norm. If passed this should be a function consuming a tensor/tuple with the same shape as `y0`, and return a scalar corresponding to its norm. When passed as part of `adjoint_options`, then the special value `"seminorm"` may be used to zero out the contribution from the parameters, as per the ["Hey, that's not an ODE"](https://arxiv.org/abs/2009.09457) paper. | ||
|
||
**Fixed solvers (euler, midpoint, rk4, explicit_adams, implicit_adams):**<br> | ||
|
||
- `step_size=None`: How large each discrete step should be. If not passed then this defaults to stepping between the values of `t`. Note that if using `t` just to specify the start and end of the regions of integration, then it is very important to specify this argument! It is mutually exclusive with the `grid_constructor` argument, below. | ||
|
||
- `grid_constructor=None`: A more fine-grained way of setting the steps, by setting these particular locations as the locations of the steps. Should be a callable `func, y0, t -> grid`, transforming the arguments `func, y0, t` of `odeint` into the desired grid (which should be a one dimensional tensor). | ||
|
||
- `perturb`: Defaults to False. If True, then automatically add small perturbations to the start and end of each step, so that stepping to discontinuities works. Note that this this may not be efficient when using PyTorch 1.6.0 or earlier. | ||
|
||
Individual solvers also offer certain options. | ||
|
||
**explicit_adams:**<br> | ||
For this solver, `rtol` and `atol` are ignored. This solver also supports: | ||
|
||
- `max_order`: The maximum order of the Adams-Bashforth predictor. | ||
|
||
**implicit_adams:**<br> | ||
For this solver, `rtol` and `atol` correspond to the tolerance for convergence of the Adams-Moulton corrector. This solver also supports: | ||
|
||
- `max_order`: The maximum order of the Adams-Bashforth-Moulton predictor-corrector. | ||
|
||
- `max_iters`: The maximum number of iterations to run the Adams-Moulton corrector for. | ||
|
||
**scipy_solver:**<br> | ||
- `solver`: which SciPy solver to use; corresponds to the `'method'` argument of `scipy.integrate.solve_ivp`. | ||
|
||
## Adjoint options | ||
|
||
The function `odeint_adjoint` offers some adjoint-specific options. | ||
|
||
- `adjoint_rtol`,<br>`adjoint_atol`,<br>`adjoint_method`,<br>`adjoint_options`:<br>The `rtol, atol, method, options` to use for the backward pass. Defaults to the values used for the forward pass. | ||
|
||
- `adjoint_options` has the special key-value pair `{"norm": "seminorm"}` that provides a potentially more efficient adjoint solve when using adaptive step solvers, as described in the ["Hey, that's not an ODE"](https://arxiv.org/abs/2009.09457) paper. | ||
|
||
- `adjoint_params`: The parameters to compute gradients with respect to in the backward pass. Should be a tuple of tensors. Defaults to `tuple(func.parameters())`. | ||
- If passed then `func` does not have to be a `torch.nn.Module`. | ||
- If `func` has no parameters, `adjoint_params=()` must be specified. | ||
|
||
|
||
## Callbacks | ||
|
||
Callbacks can be triggered during the solve. Callbacks should be specified as methods of the `func` argument to `odeint` and `odeint_adjoint`. | ||
|
||
At the moment support for this is minimal: let us know if you'd find additional callbacks useful. | ||
|
||
**callback_step(self, t0, y0, dt):**<br> | ||
This is called immediately before taking a step of size `dt`, at time `t0`, with current solution value `y0`. This is supported by every solver except `scipy_solver`. | ||
|
||
**callback_accept_step(self, t0, y0, dt):**<br> | ||
This is called when accepting a step of size `dt` at time `t0`, with current solution value `y0`. This is supported by the adaptive solvers (dopri8, dopri5, bosh3, adaptive_heun). | ||
|
||
**callback_reject_step(self, t0, y0, dt):**<br> | ||
As `callback_accept_step`, except called when rejecting steps. | ||
|
||
In addition, callbacks can be triggered during the adjoint pass by adding `_adjoint` to the name of any one of the supported callbacks, e.g. `callback_step_adjoint`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2018 Ricky Tian Qi Chen | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Oops, something went wrong.