# JAX Synergistic Memory Inspector

## Usage

In your JAX script:

```python
from jax_smi import initialise_tracking

initialise_tracking()

# some computation...
```

Open a shell and run:

```sh
jax-smi
```