Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
JAXopt can be installed with pip directly from GitHub, using the following command:
$ pip install git+https://github.com/google/jaxopt
Alternatively, it can be installed from source by cloning the repository and running the following command from its root directory:
$ python setup.py install
Our implicit differentiation framework is described in the paper "Efficient and Modular Implicit Differentiation" (arXiv:2105.15183). To cite it:
@article{jaxopt_implicit_diff,
title={Efficient and Modular Implicit Differentiation},
author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
journal={arXiv preprint arXiv:2105.15183},
year={2021}
}
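The core idea behind implicit differentiation can be sketched in plain JAX. If the solution x*(θ) of an inner problem satisfies an optimality condition F(x*(θ), θ) = 0, the implicit function theorem gives dx*/dθ = -[∂F/∂x]⁻¹ ∂F/∂θ, evaluated at (x*(θ), θ), so the solver itself never needs to be differentiated. The sketch below is not JAXopt's API. It assumes a ridge-regression inner problem with F = ∇ₓf, and the names `inner_objective`, `optimality_fn`, `solver`, and `implicit_jacobian` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Illustrative ridge-regression problem (assumed data, not from the paper).
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (20, 3))
y = jax.random.normal(key_y, (20,))

def inner_objective(x, theta):
    # f(x, theta) = 0.5 * ||X x - y||^2 + 0.5 * theta * ||x||^2
    r = X @ x - y
    return 0.5 * jnp.sum(r ** 2) + 0.5 * theta * jnp.sum(x ** 2)

# Optimality condition F(x, theta) = grad_x f(x, theta), zero at x*(theta).
optimality_fn = jax.grad(inner_objective, argnums=0)

def solver(theta):
    # Treated as a black box below: only its output x*(theta) is used.
    # A closed form is used here for brevity; any convergent solver would do.
    A = X.T @ X + theta * jnp.eye(X.shape[1])
    return jnp.linalg.solve(A, X.T @ y)

def implicit_jacobian(theta):
    # Implicit function theorem: dx*/dtheta = -[dF/dx]^{-1} dF/dtheta,
    # evaluated at (x*(theta), theta). The solver is not differentiated.
    x_star = solver(theta)
    dF_dx = jax.jacobian(optimality_fn, argnums=0)(x_star, theta)      # shape (d, d)
    dF_dtheta = jax.jacobian(optimality_fn, argnums=1)(x_star, theta)  # shape (d,)
    return -jnp.linalg.solve(dF_dx, dF_dtheta)

theta = 1.0
print(implicit_jacobian(theta))
# For comparison only: differentiating the closed-form solver directly.
print(jax.jacobian(solver)(theta))
```

Both printed vectors should agree up to floating-point error. The implicit version only requires x*(θ) and derivatives of F, which is what makes the approach modular with respect to the choice of solver.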
JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.