Randomised and deterministic matrix-free methods for trace estimation, matrix functions, and/or matrix factorisations. Builds on JAX.
- ⚡ A stand-alone implementation of stochastic Lanczos quadrature
- ⚡ Stochastic trace estimation including batching, control variates, and uncertainty quantification
- ⚡ Matrix-free matrix decompositions for large sparse eigenvalue problems
and many other things. Everything is natively compatible with JAX's feature set: JIT compilation, automatic differentiation, vectorisation, and pytrees. Let us know what you think about matfree!
To install the package, run
pip install matfree
Important: This assumes you already have a working installation of JAX.
To install JAX, follow the official JAX installation instructions.
To combine matfree with a CPU version of JAX, run
pip install matfree[cpu]
which is equivalent to combining pip install jax[cpu] with pip install matfree.
(But matfree is not limited to CPUs: it runs wherever JAX runs, including on GPUs.)
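To check which backend JAX picked up after installation, a quick sanity check in plain JAX may help. It uses only standard JAX calls; the printed backend and device list depend on your machine and on which JAX variant you installed.

```python
import jax
import jax.numpy as jnp

# The platform JAX uses by default, e.g. "cpu" after pip install matfree[cpu].
backend = jax.default_backend()
# All devices JAX can see on this machine.
devices = jax.devices()
# A tiny computation to confirm that JAX works end to end.
total = jnp.sum(jnp.arange(4.0))

print(backend, devices, total)
```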
Import matfree and JAX, and set up a test problem.
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq
>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
>>> def matvec(x):
...     return A.T @ (A @ x)
...
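The matvec above is all a matrix-free method needs: it evaluates (A^T A) x through two thin products, without ever forming A^T A. For small test problems, one way to check such a function (a sketch in plain JAX, not part of matfree) is to materialise the implicit matrix by applying matvec to the columns of the identity matrix:

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    # Computes (A^T A) x as two thin products; A^T A itself is never formed.
    return A.T @ (A @ x)


# Only sensible for small problems: map matvec over the columns of the
# identity (in_axes=1) and stack the results as columns (out_axes=1).
identity = jnp.eye(2)
dense = jax.vmap(matvec, in_axes=1, out_axes=1)(identity)

# dense should agree with A.T @ A, and its trace with the exact value used below.
print(dense)
print(jnp.trace(dense))
```

For large problems, this is exactly the step one avoids: the estimators only ever call matvec on a handful of vectors.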
Estimate the trace of the matrix A.T @ A with Hutchinson's estimator:
>>> key = jax.random.PRNGKey(1)
>>> normal = montecarlo.normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
514.0
>>>
>>> # for comparison:
>>> print(jnp.round(jnp.trace(A.T @ A)))
506.0
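The estimate is random and therefore differs from the exact trace. Hutchinson's estimator relies on E[v^T M v] = tr(M) whenever the probes satisfy E[v v^T] = I, which holds for standard-normal probes. A minimal plain-JAX sketch of that idea (hutchinson_trace is a hypothetical helper, not matfree's implementation) might look like this:

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    return A.T @ (A @ x)


def hutchinson_trace(matvec, key, dim, num_samples):
    # Standard-normal probes v, one per row, with E[v v^T] = I.
    probes = jax.random.normal(key, (num_samples, dim))
    # Apply the matrix-free operator to all probes at once.
    products = jax.vmap(matvec)(probes)
    # Quadratic forms v^T M v; their expectation is tr(M).
    quadratic_forms = jnp.sum(probes * products, axis=1)
    return jnp.mean(quadratic_forms)


key = jax.random.PRNGKey(1)
estimate = hutchinson_trace(matvec, key, dim=2, num_samples=100_000)
# Close to, but generally not equal to, jnp.trace(A.T @ A) = 506.
print(jnp.round(estimate))
```

The standard error shrinks like one over the square root of num_samples, so each extra digit of accuracy costs roughly a hundred times more matrix-vector products.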
Adjust the batch size to improve performance:
- More, smaller batches reduce memory but increase the runtime.
- Fewer, larger batches increase memory but reduce the runtime.
Change the number of batches as follows:
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal, num_batches=10)
>>> print(jnp.round(trace))
508.0
>>>
>>> # for comparison:
>>> print(jnp.round(jnp.trace(A.T @ A)))
506.0
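The memory/runtime trade-off comes from how batches are evaluated. A sketch of one possible scheme in plain JAX, assuming batches are processed sequentially with jax.lax.map and probes within a batch in parallel with jax.vmap (hutchinson_trace_batched is a hypothetical helper, not necessarily how matfree implements num_batches):

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    return A.T @ (A @ x)


def hutchinson_trace_batched(matvec, key, dim, num_batches, batch_size):
    """Hypothetical helper: average Hutchinson estimates over sequential batches."""

    def one_batch(batch_key):
        # Within a batch, all probes are processed in parallel via vmap.
        probes = jax.random.normal(batch_key, (batch_size, dim))
        products = jax.vmap(matvec)(probes)
        return jnp.mean(jnp.sum(probes * products, axis=1))

    # lax.map runs the batches one after another (it is built on lax.scan),
    # so roughly one batch of probes is held in memory at a time.
    batch_means = jax.lax.map(one_batch, jax.random.split(key, num_batches))
    # All batches have the same size, so the mean of batch means
    # equals the mean over all samples.
    return jnp.mean(batch_means)


key = jax.random.PRNGKey(1)
estimate = hutchinson_trace_batched(matvec, key, dim=2, num_batches=10, batch_size=10_000)
print(jnp.round(estimate))
```

In this scheme, the number of batches changes only how the work is scheduled, not the statistical quality of the estimate for a fixed total number of samples.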
Here are some more advanced tutorials:
- Control variates: Use control variates and multilevel schemes to reduce variance.
- Log-determinants: Use stochastic Lanczos quadrature to compute matrix functions.
- Higher moments and UQ: Compute means, variances, and other moments simultaneously.
- Vector calculus: Use matrix-free linear algebra to implement vector calculus.
- Pytree-valued states: Combine neural-network Jacobians with stochastic Lanczos quadrature.
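The log-determinant tutorial relies on stochastic Lanczos quadrature (SLQ). As a rough sketch of the idea in plain JAX (not matfree's API; lanczos_tridiagonal, quadrature_logdet_sample, and slq_logdet are hypothetical helpers), each Gaussian probe v is reduced to a small tridiagonal matrix T by the Lanczos iteration, and v^T log(M) v is approximated by Gauss quadrature built from the eigenvalues and eigenvectors of T. Averaging over probes estimates tr(log M) = log det(M). Python loops keep the sketch readable at the cost of speed.

```python
import jax
import jax.numpy as jnp

A = jnp.reshape(jnp.arange(12.0), (6, 2))


def matvec(x):
    return A.T @ (A @ x)


def lanczos_tridiagonal(matvec, v, order):
    """Run `order` Lanczos steps from v and return the tridiagonal matrix T."""
    basis = [v / jnp.linalg.norm(v)]
    alphas, betas = [], []
    for k in range(order):
        w = matvec(basis[k])
        alphas.append(jnp.dot(w, basis[k]))
        # Full reorthogonalisation: remove components along all basis vectors.
        Q = jnp.stack(basis)
        w = w - Q.T @ (Q @ w)
        if k < order - 1:
            beta = jnp.linalg.norm(w)
            betas.append(beta)
            basis.append(w / beta)
    T = jnp.diag(jnp.stack(alphas))
    if betas:
        off_diagonal = jnp.stack(betas)
        T = T + jnp.diag(off_diagonal, 1) + jnp.diag(off_diagonal, -1)
    return T


def quadrature_logdet_sample(matvec, v, order):
    """Gauss quadrature approximation of v^T log(M) v."""
    T = lanczos_tridiagonal(matvec, v, order)
    nodes, eigvecs = jnp.linalg.eigh(T)
    # Quadrature weights: squared first components of T's eigenvectors.
    weights = eigvecs[0, :] ** 2
    # v^T f(M) v = ||v||^2 q^T f(M) q with the unit vector q = v / ||v||.
    return jnp.dot(v, v) * jnp.sum(weights * jnp.log(nodes))


def slq_logdet(matvec, key, dim, order, num_samples):
    """Average over Gaussian probes, since E[v^T log(M) v] = log det(M)."""
    samples = []
    for subkey in jax.random.split(key, num_samples):
        v = jax.random.normal(subkey, (dim,))
        samples.append(quadrature_logdet_sample(matvec, v, order))
    return jnp.mean(jnp.stack(samples))


key = jax.random.PRNGKey(1)
estimate = slq_logdet(matvec, key, dim=2, order=2, num_samples=500)
# Compare with the dense log-determinant of A^T A.
print(estimate, jnp.linalg.slogdet(A.T @ A)[1])
```

For this 2x2 test problem, order 2 makes the quadrature exact for every probe (up to rounding), so only the Monte Carlo error remains; for large matrices, a small order well below the dimension is the point of the method.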
Let us know what you use matfree for!
Contributions are absolutely welcome!
Issues:
Most contributions start with an issue. Please don't hesitate to create issues to ask for features, give feedback on performance, or simply reach out.
Pull requests:
To make a pull request, proceed as follows:
- Fork the repository.
- Install all dependencies with pip install .[full] or pip install -e .[full].
- Make your changes.
- From the root of the project, run the tests via make test, and check out make format and make lint as well. Use the pre-commit hook if you like.
When making a pull request, keep in mind the following (rough) guidelines:
- Most PRs resolve an issue.
- Most PRs contain a single commit with a clear, descriptive commit message.
- Most enhancements (e.g. new features) are covered by tests.