This is an attempt to recreate many reinforcement learning algorithms in the JAX (Flax) ecosystem as single-file implementations.
Clone the project
git clone https://github.com/MyNameIsArko/RL-Flax
Go to the project directory
cd RL-Flax
Install basic dependencies
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax tensorflow-probability
Install environment-specific dependencies
Run the algorithm you want!
- DQN
- Rainbow DQN
- A2C
- A3C
Any kind of contribution is welcome!
If you know a bit of JAX and Flax and know the ins and outs of some algorithm, make a pull request. I'll gladly accept it, as this is a big project for one person.