
Matrix-free numerical linear algebra including trace and log-determinant estimation.


Weitheskmt/matfree

matfree: Matrix-free linear algebra in JAX


Randomised and deterministic matrix-free methods for trace estimation, matrix functions, and/or matrix factorisations. Builds on JAX.

  • ⚡ A stand-alone implementation of stochastic Lanczos quadrature
  • ⚡ Stochastic trace estimation including batching, control variates, and uncertainty quantification
  • ⚡ Matrix-free matrix decompositions for large sparse eigenvalue problems

and many other things. Everything is natively compatible with JAX's feature set: JIT compilation, automatic differentiation, vectorisation, and pytrees. Let us know what you think about matfree!
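To illustrate what stochastic Lanczos quadrature (SLQ) computes, and that such matrix-free code composes with `jax.jit` and `jax.vmap`, here is a minimal from-scratch sketch in plain JAX. It is not matfree's implementation; the helpers `lanczos_tridiag`, `slq_quadform_log`, and `slq_logdet` are hypothetical names, and the sketch assumes a symmetric positive-definite operator that is only available through `matvec`. It uses log det(A) = tr(log A), estimates the trace with Gaussian probes, and approximates each quadratic form vᵀ log(A) v by Gauss quadrature on the Lanczos tridiagonal matrix.

```python
import jax
import jax.numpy as jnp


def lanczos_tridiag(matvec, v0, order):
    """Run `order` Lanczos steps (order >= 2) with full reorthogonalisation.

    Hypothetical helper. Returns the diagonal (alphas) and off-diagonal
    (betas) of the tridiagonal matrix T; the Python loop unrolls under jit.
    """
    v = v0 / jnp.linalg.norm(v0)
    basis = [v]
    alphas, betas = [], []
    for i in range(order):
        w = matvec(v)
        alphas.append(jnp.dot(w, v))
        # Orthogonalise against all previous Lanczos vectors; this removes
        # the alpha_i * v_i and beta_{i-1} * v_{i-1} components.
        Q = jnp.stack(basis)
        w = w - Q.T @ (Q @ w)
        if i < order - 1:
            beta = jnp.linalg.norm(w)
            betas.append(beta)
            v = w / beta
            basis.append(v)
    return jnp.stack(alphas), jnp.stack(betas)


def slq_quadform_log(matvec, v, order):
    """Approximate v^T log(A) v by Gauss quadrature on the Lanczos matrix T."""
    alphas, betas = lanczos_tridiag(matvec, v, order)
    T = jnp.diag(alphas) + jnp.diag(betas, k=1) + jnp.diag(betas, k=-1)
    eigvals, eigvecs = jnp.linalg.eigh(T)
    # Nodes are the eigenvalues of T; weights are squared first components.
    weights = eigvecs[0, :] ** 2
    return jnp.dot(v, v) * jnp.sum(weights * jnp.log(eigvals))


def slq_logdet(matvec, key, dim, order, num_samples):
    """Estimate log det(A) = tr(log A) with Gaussian probes and SLQ."""
    probes = jax.random.normal(key, (num_samples, dim))
    quadforms = jax.vmap(lambda v: slq_quadform_log(matvec, v, order))(probes)
    return jnp.mean(quadforms)


# Usage on a small symmetric positive-definite test matrix.
M = jax.random.normal(jax.random.PRNGKey(0), (6, 6))
A = M @ M.T + 6.0 * jnp.eye(6)
matvec = lambda x: A @ x

estimate_fn = jax.jit(lambda k: slq_logdet(matvec, k, dim=6, order=6, num_samples=1000))
estimate = estimate_fn(jax.random.PRNGKey(1))
print(estimate, jnp.linalg.slogdet(A)[1])  # random estimate vs. exact log-determinant
```

With `order` equal to the dimension, the quadrature is exact in exact arithmetic, so only the Monte Carlo error remains; for large sparse problems, `order` would typically be much smaller than the dimension.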

Installation | Minimal example | Tutorials | Contributing | API docs

Installation

To install the package, run

pip install matfree

Important: This assumes you already have a working installation of JAX. To install JAX, follow the official JAX installation instructions. To combine matfree with a CPU version of JAX, run

pip install matfree[cpu]

which is equivalent to combining pip install jax[cpu] with pip install matfree. (matfree is not restricted to CPUs, though: with a GPU-enabled JAX installation, it runs on GPUs, too.)

Minimal example

Import matfree and JAX, and set up a test problem.

>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq

>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
>>> def matvec(x):
...     return A.T @ (A @ x)
...

Estimate the trace of the matrix:

>>> key = jax.random.PRNGKey(1)
>>> normal = montecarlo.normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
514.0
>>>
>>> # for comparison:
>>> print(jnp.round(jnp.trace(A.T @ A)))
506.0
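Hutchinson's estimator rests on the identity tr(B) = E[vᵀBv], which holds for random vectors v with E[vvᵀ] = I, such as standard Gaussian vectors. The plain-JAX sketch below implements that identity directly; it is an illustration, not matfree's implementation, and `hutchinson_trace` and `scaled_trace` are hypothetical names. It also applies `jax.grad` to the estimate to show that such a matrix-free estimator is differentiable.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(matvec, key, dim, num_samples):
    """Estimate tr(B) as the mean of v^T B v over Gaussian probes v."""
    probes = jax.random.normal(key, (num_samples, dim))
    quadforms = jax.vmap(lambda v: jnp.dot(v, matvec(v)))(probes)
    return jnp.mean(quadforms)


A = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: A.T @ (A @ x)
key = jax.random.PRNGKey(1)

estimate = hutchinson_trace(matvec, key, dim=2, num_samples=10_000)
print(estimate, jnp.trace(A.T @ A))  # random estimate; the exact trace is 506.0


# Autodiff: the estimate for theta * B is linear in theta, so its derivative
# matches the estimate above up to floating-point rounding.
def scaled_trace(theta):
    return hutchinson_trace(lambda x: theta * matvec(x), key, dim=2, num_samples=10_000)


print(jax.grad(scaled_trace)(1.0))
```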

Adjust the batch size to balance memory and runtime:

  • More, smaller batches reduce memory but increase the runtime.
  • Fewer, larger batches increase memory but reduce the runtime.

Change the number of batches as follows:

>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal, num_batches=10)
>>> print(jnp.round(trace))
508.0
>>>
>>> # for comparison:
>>> print(jnp.round(jnp.trace(A.T @ A)))
506.0
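The memory/runtime trade-off of batching can be sketched in plain JAX: probes within a batch are evaluated in parallel with `jax.vmap`, so memory grows with the batch size, while batches run one after another with `jax.lax.map`, so runtime grows with the number of batches. This is a minimal sketch under those assumptions, not a description of how matfree schedules batches internally; `batched_trace` and its `batch_size` parameter are hypothetical.

```python
import jax
import jax.numpy as jnp


def batched_trace(matvec, key, dim, num_batches, batch_size):
    """Hutchinson trace estimate computed in `num_batches` sequential batches."""
    keys = jax.random.split(key, num_batches)

    def one_batch(k):
        # All probes of one batch are processed at once (vectorised).
        probes = jax.random.normal(k, (batch_size, dim))
        quadforms = jax.vmap(lambda v: jnp.dot(v, matvec(v)))(probes)
        return jnp.mean(quadforms)

    # Batches are processed sequentially; equal-size batch means average
    # to the overall sample mean.
    return jnp.mean(jax.lax.map(one_batch, keys))


A = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: A.T @ (A @ x)

estimate = batched_trace(matvec, jax.random.PRNGKey(1), dim=2, num_batches=10, batch_size=1000)
print(estimate, jnp.trace(A.T @ A))  # random estimate vs. exact trace
```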

Tutorials

Here are some more advanced tutorials:

  • Control variates: Use control variates and multilevel schemes to reduce variances.
  • Log-determinants: Use stochastic Lanczos quadrature to compute matrix functions.
  • Higher moments and UQ: Compute means, variances, and other moments simultaneously.
  • Vector calculus: Use matrix-free linear algebra to implement vector calculus.
  • Pytree-valued states: Combine neural-network Jacobians with stochastic Lanczos quadrature.
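As a rough idea of what computing several moments at once buys, the sketch below reuses the same quadratic-form samples for both the trace estimate (their mean) and its Monte Carlo standard error (from their sample variance). It is a plain-JAX illustration under that assumption, not the tutorial's code; `trace_with_stderr` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def trace_with_stderr(matvec, key, dim, num_samples):
    """Return a Hutchinson trace estimate and its Monte Carlo standard error."""
    probes = jax.random.normal(key, (num_samples, dim))
    quadforms = jax.vmap(lambda v: jnp.dot(v, matvec(v)))(probes)
    # Both moments come from the same samples, so no extra matvecs are needed.
    mean = jnp.mean(quadforms)
    variance = jnp.var(quadforms, ddof=1)
    return mean, jnp.sqrt(variance / num_samples)


A = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: A.T @ (A @ x)

mean, stderr = trace_with_stderr(matvec, jax.random.PRNGKey(1), dim=2, num_samples=10_000)
print(mean, stderr)  # estimate and an uncertainty for it
```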

Let us know what you use matfree for!

Contributing

Contributions are absolutely welcome!

Issues:

Most contributions start with an issue. Please don't hesitate to open issues to request features, give feedback on performance, or simply reach out.

Pull requests:

To make a pull request, proceed as follows:

  • Fork the repository.
  • Install all dependencies with pip install .[full] or pip install -e .[full].
  • Make your changes.
  • From the root of the project, run the tests via make test, and run make format and make lint as well. Use the pre-commit hook if you like.

When making a pull request, keep in mind the following (rough) guidelines:
