.. currentmodule:: jax

jax package
===========

Subpackages
-----------

.. toctree::
    :maxdepth: 1

    jax.numpy
    jax.scipy
    jax.experimental
    jax.lax
    jax.ops
    jax.random
    jax.tree_util

Just-in-time compilation (jit)
------------------------------

.. autofunction:: jit
.. autofunction:: disable_jit
.. autofunction:: xla_computation
.. autofunction:: make_jaxpr
.. autofunction:: eval_shape
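
A minimal sketch, assuming a recent JAX release, of how :func:`jit`, :func:`make_jaxpr`, :func:`eval_shape`, and :func:`disable_jit` are commonly combined. The ``selu`` function and its constants are illustrative only and are not part of the JAX API. :func:`xla_computation` is not exercised here because it is not available in every JAX version.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   def selu(x, alpha=1.67, lmbda=1.05):
       # Illustrative elementwise function (scaled exponential linear unit).
       return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

   # jit traces selu for this input shape/dtype and compiles it with XLA.
   selu_jit = jax.jit(selu)
   x = jnp.arange(5.0)
   y = selu_jit(x)

   # make_jaxpr prints the traced intermediate representation.
   print(jax.make_jaxpr(selu)(x))

   # eval_shape computes only output shapes/dtypes, without doing the math.
   out_spec = jax.eval_shape(selu, jax.ShapeDtypeStruct((1000, 1000), jnp.float32))

   # Inside disable_jit, the jitted function runs op by op (useful for debugging).
   with jax.disable_jit():
       y_eager = selu_jit(x)

Because :func:`jit` specializes on input shapes and dtypes, calling ``selu_jit`` on an array of a different shape triggers a new trace and compilation.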

Automatic differentiation
-------------------------

.. autofunction:: grad
.. autofunction:: value_and_grad
.. autofunction:: jacfwd
.. autofunction:: jacrev
.. autofunction:: hessian
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
.. autofunction:: custom_transforms
.. autofunction:: defjvp
.. autofunction:: defjvp_all
.. autofunction:: defvjp
.. autofunction:: defvjp_all
.. autofunction:: custom_gradient
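
A minimal sketch, assuming a recent JAX release, that exercises the core differentiation transformations on the illustrative function ``f(x) = sum(x**3)``, whose gradient is ``3 * x**2`` and whose Hessian is ``diag(6 * x)``. The older custom-rule helpers listed above (:func:`custom_transforms`, :func:`defjvp`, :func:`defvjp`, and related) are not shown; newer releases provide ``jax.custom_jvp`` and ``jax.custom_vjp`` for that purpose.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   def f(x):
       # Illustrative scalar-valued function of a vector.
       return jnp.sum(x ** 3)

   x = jnp.array([1.0, 2.0, 3.0])
   v = jnp.ones_like(x)

   # grad / value_and_grad: reverse-mode gradient of a scalar output.
   g = jax.grad(f)(x)
   val, g2 = jax.value_and_grad(f)(x)

   # hessian: matrix of second derivatives.
   H = jax.hessian(f)(x)

   # jacfwd / jacrev: full Jacobian of a vector-valued function.
   J_fwd = jax.jacfwd(lambda x: x ** 2)(x)
   J_rev = jax.jacrev(lambda x: x ** 2)(x)

   # jvp: forward mode, pushes the tangent v through f.
   primal_out, tangent_out = jax.jvp(f, (x,), (v,))

   # linearize: evaluate f once, then reuse the linear map for many tangents.
   _, f_jvp = jax.linearize(f, x)
   lin_out = f_jvp(v)

   # vjp: reverse mode, pulls a cotangent of the output back to the input.
   primal_out2, f_vjp = jax.vjp(f, x)
   (cotangent_in,) = f_vjp(jnp.ones_like(primal_out2))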


Vectorization (vmap)
--------------------

.. autofunction:: vmap
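
A minimal sketch, assuming a recent JAX release: ``dot`` is written for single vectors, and :func:`vmap` lifts it to operate on a batch without an explicit Python loop. The names ``dot``, ``xs``, ``ys``, and ``w`` are illustrative.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   def dot(a, b):
       # Written for single vectors of shape (n,).
       return jnp.dot(a, b)

   xs = jnp.arange(12.0).reshape(4, 3)
   ys = jnp.ones((4, 3))

   # By default vmap maps over axis 0 of every argument -> shape (4,).
   batched = jax.vmap(dot)(xs, ys)

   # in_axes=(0, None): map over the rows of xs, reuse the same w each time.
   w = jnp.array([1.0, 0.0, -1.0])
   per_row = jax.vmap(dot, in_axes=(0, None))(xs, w)

Passing ``None`` in ``in_axes`` broadcasts that argument instead of slicing it, which avoids materializing a copy of ``w`` for every row.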


Parallelization (pmap)
----------------------

.. autofunction:: pmap
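
A minimal sketch, assuming a recent JAX release. :func:`pmap` requires the size of the mapped leading axis to be no larger than the number of available devices, so the input here is sized with ``jax.local_device_count()``. On a CPU-only machine that count is typically 1, in which case the code still runs but nothing executes in parallel. The axis name ``"i"`` is an arbitrary label.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   # One row of data per local device.
   n = jax.local_device_count()
   xs = jnp.arange(n * 3.0).reshape(n, 3)

   # Each device sums its own row independently.
   row_sums = jax.pmap(lambda x: jnp.sum(x))(xs)

   # lax.psum combines values across all devices along the named axis "i".
   totals = jax.pmap(
       lambda x: jax.lax.psum(jnp.sum(x), axis_name="i"), axis_name="i"
   )(xs)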