Description
Something I noticed while using diffrax was that the adaptive solvers are much slower when pmap-ing the integration under the new runtime on CPUs (pmap is used instead of sharding for the reason described in #26586). I adapted the code from that issue to show that pmap-ing is also slower on the new runtime.
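The adapted snippet itself is not shown here; a minimal sketch of the pattern being benchmarked (a tiny pmap-ed step timed on CPU, standing in for one substep of an adaptive solver loop) might look like the following. The step function and sizes are illustrative, not the original code:

```python
import os
# Pretend we have 4 CPU devices so pmap has something to map over;
# this must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import time
import jax
import jax.numpy as jnp

@jax.pmap
def step(x):
    # Trivial per-device update standing in for one adaptive-solver substep.
    return x + 0.1 * jnp.sin(x)

x = jnp.zeros((jax.device_count(), 8))
step(x).block_until_ready()  # warm up: compile once before timing

start = time.perf_counter()
for _ in range(100):
    x = step(x)
x.block_until_ready()
print(f"Elapsed time: {(time.perf_counter() - start) / 100:.6f} seconds per step")
```

The per-step dispatch overhead dominates here because the computation is tiny, which is exactly the regime where the runtime change shows up.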
with 0.4.31: Elapsed time: 0.002367 seconds
with 0.4.33: Elapsed time: 0.015590 seconds
with 0.5.0: Elapsed time: 0.018911 seconds
This example is trivial, but it represents the core subroutine of adaptive SDE solvers. For now this can be worked around by disabling the new CPU thunk runtime, but I'm reporting it in the hope that it can be fixed in the future :).
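For reference, the workaround can be applied via `XLA_FLAGS`; the flag name below (`--xla_cpu_use_thunk_runtime`) is an assumption that should be checked against your jaxlib version, and XLA only reads the flag at backend initialization, so it must be set before jax is imported:

```python
import os

# Assumed flag name: opt out of the new CPU thunk runtime. Appending keeps
# any flags already set in the environment.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
).strip()

import jax  # must be imported after the flag is set
```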
System info (python version, jaxlib version, accelerator, etc.)
multiple JAX versions, CPU; tested on macOS and Colab