Running JAX demo script #1

Open
glegarda opened this issue Apr 23, 2024 · 4 comments

Comments

@glegarda

Hello,

I followed the JAX setup instructions and tried to run the demo script, but got the following error:

Traceback (most recent call last):
  File "src/demo_nolearning.py", line 8, in <module>
    from utils import initialize_meta_params, get_default_inits, run_single_simulation, str2bool
  File "/home/guillermo/Code/git-projects/collective_motion_actinf/jax/src/utils.py", line 15, in <module>
    from genprocess import get_observations, get_observations_special, advance_positions, init_gen_process, compute_Dgroup_and_rankings_t, compute_Dgroup_and_rankings_vmapped, compute_turning_magnitudes, compute_integrated_change_magnitude
  File "/home/guillermo/Code/git-projects/collective_motion_actinf/jax/src/genprocess/__init__.py", line 1, in <module>
    from .geometry import *
  File "/home/guillermo/Code/git-projects/collective_motion_actinf/jax/src/genprocess/geometry.py", line 3, in <module>
    from jax_md import space
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/jax_md/__init__.py", line 16, in <module>
    from jax_md import energy
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/jax_md/energy.py", line 28, in <module>
    import haiku as hk
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/__init__.py", line 19, in <module>
    from haiku import data_structures
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/data_structures.py", line 18, in <module>
    from haiku._src.data_structures import to_haiku_dict
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/_src/data_structures.py", line 30, in <module>
    from haiku._src import utils
  File "/home/guillermo/.venv/jax/lib/python3.8/site-packages/haiku/_src/utils.py", line 42, in <module>
    def auto_repr(cls: type[Any], *args, **kwargs) -> str:
TypeError: 'type' object is not subscriptable

I am working on an Ubuntu 20.04.6 LTS x86_64 machine with an NVIDIA GeForce RTX 3060 and Python 3.8.10, and I tried both the GPU and CPU versions of JAX, but the error remains.

Any clue as to what might be going on? Perhaps some version compatibility issue?

Thanks!

@arnauqb

arnauqb commented May 9, 2024

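You need to upgrade to Python >= 3.9. Python 3.8 does not support this kind of type annotation: subscripting built-in types at runtime, as in type[Any], only works from Python 3.9 onward (PEP 585).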
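
A minimal sketch of a version guard the demo could run before importing jax_md, assuming the fix is simply to require Python 3.9 or newer. The names `MIN_PYTHON`, `python_is_supported` and `require_python` are hypothetical and not part of this repository.

```python
import sys
from typing import Sequence, Tuple

# Assumed minimum: PEP 585 made built-ins such as type[Any] subscriptable in 3.9,
# and haiku evaluates such annotations when its modules are imported.
MIN_PYTHON: Tuple[int, int] = (3, 9)


def python_is_supported(version_info: Sequence[int] = sys.version_info) -> bool:
    """Return True if the (major, minor) part of version_info meets MIN_PYTHON."""
    return tuple(version_info[:2]) >= MIN_PYTHON


def require_python() -> None:
    """Exit with a readable message instead of haiku's TypeError on old interpreters."""
    if not python_is_supported():
        found = ".".join(str(part) for part in sys.version_info[:3])
        sys.exit(
            f"This demo needs Python >= {MIN_PYTHON[0]}.{MIN_PYTHON[1]}, found {found}."
        )
```

Calling require_python() at the top of src/demo_nolearning.py, before its imports from utils, would replace the long traceback with a one-line message. The guard uses typing.Tuple rather than tuple[...] on purpose, so that it can itself be imported on Python 3.8.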

@conorheins
Owner

Thanks @arnauqb for stepping in and helping. And apologies @glegarda for not being more specific about Python and JAX versions. I will go back and annotate the versions of each requirement more rigorously once I get a chance.

@glegarda
Author

Thank you both for your help! I had to do some tinkering, but eventually I got the example working. First, I upgraded to Python 3.11. This got rid of the original error, but triggered others due to further compatibility issues. In case it helps you annotate the required versions, @conorheins, these are the packages I had to install manually or reinstall:

  • jax 0.4.19: JAX 0.4.28, which is installed by default, throws an error because the 'KeyArray' attribute of 'jax.random' was deprecated in version 0.4.16 and removed in 0.4.24. Version 0.4.19 is the earliest JAX release that is also compatible with the rest of the installed packages (jax-md, from the requirements, pulls in flax 0.8.3, which requires JAX >= 0.4.19). Note that I had to reinstall JAX after installing the requirements, because installing jax-md forces the (re)installation of JAX 0.4.28.
  • scipy 1.12.0: the latest SciPy version (1.13.0) throws an error because the attribute 'tril' was removed from 'scipy.linalg'.
  • PyQt5: I also had to install this one so that matplotlib could display the figure (it provides an interactive Qt backend).

With these modifications, I was able to run the example.
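
The versions reported above can be checked in one go. The sketch below compares installed distributions against those pins using the standard library's importlib.metadata. The `PINS` mapping only restates the findings in this comment (it is not an official requirements file), `check_pins` is a hypothetical helper, and PyQt5 is checked for presence only because no specific version was reported. Since installing jax-md pulls JAX 0.4.28 back in, a check like this is most useful after the final reinstall of JAX 0.4.19.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional, Tuple

# Versions reported to work in this thread (Python 3.11); None means "any version".
PINS: Dict[str, Optional[str]] = {
    "jax": "0.4.19",    # jax.random.KeyArray still present (removed in 0.4.24)
    "scipy": "1.12.0",  # scipy.linalg.tril still present (removed in 1.13.0)
    "PyQt5": None,      # only needed for an interactive matplotlib window
}


def check_pins(
    pins: Dict[str, Optional[str]] = PINS,
    get_version: Callable[[str], str] = version,
) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Return {package: (wanted, installed)} for every package that does not match."""
    problems: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
    for package, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(package)
        except PackageNotFoundError:
            installed = None
        missing = installed is None
        wrong_version = wanted is not None and installed != wanted
        if missing or wrong_version:
            problems[package] = (wanted, installed)
    return problems


if __name__ == "__main__":
    for package, (wanted, installed) in check_pins().items():
        print(f"{package}: wanted {wanted or 'any'}, found {installed or 'not installed'}")
```

The `get_version` parameter defaults to the real lookup but can be swapped for a stub, which keeps the comparison logic testable without touching the environment.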

@conorheins
Owner

conorheins commented May 14, 2024

Thanks a lot for documenting this so thoroughly @glegarda. Good to know about the deprecation of the KeyArray attribute in newer versions of JAX. Because JAX is still in experimental development, deprecations and breaks in backward compatibility like this one unfortunately come up frustratingly often. So I'll either (A) freeze the requirements to a JAX version earlier than 0.4.24 (like 0.4.19) that is still new enough to be compatible with the remaining packages (jax-md, flax 0.8.3, etc.), or (B) update the code to be consistent with the latest versions of JAX (0.4.24 and later).
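
For option (B), the change reported in this thread is the removed jax.random.KeyArray; the JAX changelog suggests annotating PRNG keys with jax.Array instead. Below is a rough, hypothetical text-rewrite sketch, assuming the code only spells the type as `jax.random.KeyArray` or `random.KeyArray` (other aliases would be missed, and rewritten files must `import jax`). It is regex-based, not a real refactoring tool, so its output should be reviewed by hand.

```python
import re
from typing import List

# Matches `jax.random.KeyArray` or a bare `random.KeyArray`, but not e.g.
# `np.random.KeyArray` (the lookbehind rejects a preceding identifier char or dot).
_KEYARRAY = re.compile(r"(?<![\w.])(?:jax\.)?random\.KeyArray\b")


def migrate_keyarray(source: str) -> str:
    """Replace removed jax.random.KeyArray annotations with jax.Array."""
    return _KEYARRAY.sub("jax.Array", source)


def find_keyarray_lines(source: str) -> List[int]:
    """Return 1-based line numbers that still mention random.KeyArray."""
    return [
        number
        for number, line in enumerate(source.splitlines(), start=1)
        if _KEYARRAY.search(line)
    ]
```

Running find_keyarray_lines over each source file first gives a list of places to inspect before applying migrate_keyarray.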
