Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hacky way of reducing test memory usage. #59

Merged
merged 2 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Hacky way of reducing test memory usage.
  • Loading branch information
patrick-kidger committed Feb 9, 2022
commit b9e2fd9b07e57209dc156ad3851a9cc665501e89
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: [ 3.7, 3.8, 3.9 ]
os: [ macOS-latest ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
Expand Down
21 changes: 21 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import gc
import random
import sys

import jax.config
import jax.random as jrandom
import psutil
import pytest


Expand All @@ -15,3 +18,21 @@ def _getkey():
return jrandom.PRNGKey(random.randint(0, 2**31 - 1))

return _getkey


# Hugely hacky way of reducing memory usage in tests.
# JAX can be a little over-happy with its caching; this is especially noticable when
# performing tests and therefore doing an unusual amount of compilation etc.
# This can be enough to exceed the 8GB RAM available to Ubuntu instances on GitHub
# Actions.
@pytest.fixture(autouse=True)
def clear_caches():
process = psutil.Process()
if process.memory_info().vms > 4 * 2**30: # >4GB memory usage
for module_name, module in sys.modules.items():
if module_name.startswith("jax"):
for obj_name in dir(module):
obj = getattr(module, obj_name)
if hasattr(obj, "cache_clear"):
obj.cache_clear()
gc.collect()