.. currentmodule:: jax

jax package
===========

Subpackages
-----------

.. toctree::
    :maxdepth: 1

    jax.numpy
    jax.scipy
    jax.experimental
    jax.lax
    jax.ops
    jax.random
    jax.tree_util

Just-in-time compilation (jit)
------------------------------

.. autofunction:: jit
.. autofunction:: disable_jit
.. autofunction:: xla_computation
.. autofunction:: make_jaxpr
.. autofunction:: eval_shape

Automatic differentiation
-------------------------

.. autofunction:: grad
.. autofunction:: value_and_grad
.. autofunction:: jacfwd
.. autofunction:: jacrev
.. autofunction:: hessian
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
.. autofunction:: custom_transforms
.. autofunction:: defjvp
.. autofunction:: defjvp_all
.. autofunction:: defvjp
.. autofunction:: defvjp_all
.. autofunction:: custom_gradient

Vectorization (vmap)
--------------------

.. autofunction:: vmap

Parallelization (pmap)
----------------------

.. autofunction:: pmap
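A minimal usage sketch of :func:`pmap`, assuming a recent JAX installation. It maps over however many local devices are available; on a machine with a single CPU device the mapped axis simply has size 1. The variable names ``xs``, ``doubled`` and ``summed`` and the axis name ``"devices"`` are illustrative only.

```python
import jax
import jax.numpy as jnp

# One row per local device: pmap maps over the leading axis, whose size
# must not exceed jax.local_device_count().
n = jax.local_device_count()
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# Each device doubles its own row independently.
doubled = jax.pmap(lambda x: x * 2.0)(xs)

# Naming the mapped axis lets collectives such as lax.psum reduce across
# devices; every device receives the elementwise sum of all rows.
summed = jax.pmap(lambda x: jax.lax.psum(x, axis_name="devices"),
                  axis_name="devices")(xs)
```

Here ``summed[i]`` equals ``xs.sum(axis=0)`` for every device index ``i``, while ``doubled`` has the same shape as ``xs``.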