RLs: Reinforcement Learning Algorithm Based On PyTorch.

RLs

This project includes SOTA or classic reinforcement learning (single and multi-agent) algorithms used for training agents by interacting with Unity through ml-agents Release 18 or with gym.

About

The goal of this framework is to provide stable implementations of standard RL algorithms and simultaneously enable fast prototyping of new methods. It aims to fill the need for a small, easily grokked codebase in which users can freely experiment with wild ideas (speculative research).

Characteristics

This project offers:

  • Support for Windows, Linux, and OSX.
  • Single- and multi-agent training.
  • Multiple types of observation sensors as input.
  • Only 3 steps are needed to implement a new algorithm:
    1. policy: write a .py file in the rls/algorithms/{single/multi} directory and make the policy inherit from the super-class defined in rls/algorithms/base
    2. config: write a .yaml file in the rls/configs/algorithms/ directory and specify the super config type defined in rls/configs/algorithms/general.yaml
    3. register: register the new algorithm in rls/algorithms/__init__.py
  • Only 3 steps are needed to adapt to a new training environment:
    1. wrapper: write environment wrappers in the rls/envs/{new platform} directory and make them inherit from the super-class defined in rls/envs/env_base.py
    2. config: write the default configuration in rls/configs/{new platform}
    3. register: register the new environment platform in rls/envs/__init__.py
  • Compatible with several environment platforms:
    • Unity3D ml-agents.
    • PettingZoo
    • gym. For now only two data types are compatible: [Box, Discrete]. Parallel training with gym envs is supported; use --copies to specify how many environment copies to train in parallel.
      • environments:
      • observation -> action:
        • Discrete -> Discrete (observation type -> action type)
        • Discrete -> Box
        • Box -> Discrete
        • Box -> Box
        • Box/Discrete -> Tuple(Discrete, Discrete, Discrete)
  • Four types of replay buffer; the default is ER.
  • Noisy Net for better exploration.
  • Intrinsic Curiosity Module for almost all off-policy algorithms implemented.
  • Parallel training of multiple scenes for Gym.
  • Unified data format
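
The three steps for adding a new algorithm (policy, config, register) can be illustrated with a small, self-contained sketch. Everything in it is hypothetical: `BasePolicy`, `ALGORITHM_REGISTRY`, `register`, `DEFAULT_CONFIGS`, and `make_policy` are illustrative names, not the actual classes or functions in rls/algorithms, and a plain dict stands in for the .yaml config file. It only shows the general pattern of inheriting from a base class and registering the result under a command-line name.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, Type


class BasePolicy(ABC):
    """Hypothetical stand-in for a super-class defined in rls/algorithms/base."""

    def __init__(self, **config):
        self.config = config

    @abstractmethod
    def select_action(self, obs):
        """Return an action for the given observation."""


# Hypothetical registry mapping a command-line name to a policy class.
ALGORITHM_REGISTRY: Dict[str, Type[BasePolicy]] = {}


def register(name: str) -> Callable[[Type[BasePolicy]], Type[BasePolicy]]:
    """Step 3 (register): make the policy discoverable by name."""
    def decorator(cls: Type[BasePolicy]) -> Type[BasePolicy]:
        if name in ALGORITHM_REGISTRY:
            raise KeyError(f"algorithm {name!r} is already registered")
        ALGORITHM_REGISTRY[name] = cls
        return cls
    return decorator


# Step 1 (policy): inherit from the base class and implement its interface.
@register("constant")
class ConstantPolicy(BasePolicy):
    def select_action(self, obs):
        # Always returns the configured action, whatever the observation is.
        return self.config.get("action", 0)


# Step 2 (config): defaults that would normally live in a .yaml file.
DEFAULT_CONFIGS = {"constant": {"action": 1}}


def make_policy(name: str, **overrides) -> BasePolicy:
    """Look up a registered policy and build it from defaults plus overrides."""
    config = {**DEFAULT_CONFIGS.get(name, {}), **overrides}
    return ALGORITHM_REGISTRY[name](**config)


policy = make_policy("constant")
print(policy.select_action(obs=None))
```

Adapting to a new environment platform follows the same shape: a wrapper class, a default configuration, and a registration step.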

Installation

method 1:

$ git clone https://github.com/StepNeverStop/RLs.git
$ cd RLs
$ conda create -n rls python=3.8
$ conda activate rls
# Windows
$ pip install -e .[windows]
# Linux or Mac OS
$ pip install -e .

method 2:

$ conda env create -f environment.yaml

If using ml-agents:

$ pip install -e .[unity]

You can pull the pre-built Docker image:

$ docker pull keavnn/rls:latest

If you want to send a PR, please format all code files first:

$ pip install -e .[pr]
$ python auto_format.py -d ./

Implemented Algorithms

For now, these algorithms are available:

Algorithms Discrete Continuous Image RNN Command parameter
PG ✓ ✓ ✓ ✓ pg
AC ✓ ✓ ✓ ✓ ac
A2C ✓ ✓ ✓ ✓ a2c
NPG ✓ ✓ ✓ ✓ npg
TRPO ✓ ✓ ✓ ✓ trpo
PPO ✓ ✓ ✓ ✓ ppo
DQN ✓ ✓ ✓ dqn
Double DQN ✓ ✓ ✓ ddqn
Dueling Double DQN ✓ ✓ ✓ dddqn
Averaged DQN ✓ ✓ ✓ averaged_dqn
Bootstrapped DQN ✓ ✓ ✓ bootstrappeddqn
Soft Q-Learning ✓ ✓ ✓ sql
C51 ✓ ✓ ✓ c51
QR-DQN ✓ ✓ ✓ qrdqn
IQN ✓ ✓ ✓ iqn
Rainbow ✓ ✓ ✓ rainbow
DPG ✓ ✓ ✓ ✓ dpg
DDPG ✓ ✓ ✓ ✓ ddpg
TD3 ✓ ✓ ✓ ✓ td3
SAC(has V network) ✓ ✓ ✓ ✓ sac_v
SAC ✓ ✓ ✓ ✓ sac
TAC sac ✓ ✓ ✓ tac
MaxSQN ✓ ✓ ✓ maxsqn
OC ✓ ✓ ✓ ✓ oc
AOC ✓ ✓ ✓ ✓ aoc
PPOC ✓ ✓ ✓ ✓ ppoc
IOC ✓ ✓ ✓ ✓ ioc
PlaNet ✓ ✓ 1 planet
Dreamer ✓ ✓ ✓ 1 dreamer
DreamerV2 ✓ ✓ ✓ 1 dreamerv2
VDN ✓ ✓ ✓ vdn
QMIX ✓ ✓ ✓ qmix
Qatten ✓ ✓ ✓ qatten
QPLEX ✓ ✓ ✓ qplex
QTRAN ✓ ✓ ✓ qtran
MADDPG ✓ ✓ ✓ ✓ maddpg
MASAC ✓ ✓ ✓ ✓ masac
CQL ✓ ✓ ✓ cql_dqn
BCQ ✓ ✓ ✓ ✓ bcq
MVE ✓ ✓ mve

1 means the algorithm must use an RNN, or an RNN is used by default.

Getting started

"""
usage: run.py [-h] [-c COPIES] [--seed SEED] [-r]
              [-p {gym,unity,pettingzoo}]
              [-a {maddpg,masac,vdn,qmix,qatten,qtran,qplex,aoc,ppoc,oc,ioc,planet,dreamer,dreamerv2,mve,cql_dqn,bcq,pg,npg,trpo,ppo,a2c,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn}]
              [-i] [-l LOAD_PATH] [-m MODELS] [-n NAME]
              [--config-file CONFIG_FILE] [--store-dir STORE_DIR]
              [--episode-length EPISODE_LENGTH] [--hostname] [-e ENV_NAME]
              [-f FILE_NAME] [-s] [-d DEVICE] [-t MAX_TRAIN_STEP]

optional arguments:
  -h, --help            show this help message and exit
  -c COPIES, --copies COPIES
                        nums of environment copies that collect data in
                        parallel
  --seed SEED           specify the random seed of module random, numpy and
                        pytorch
  -r, --render          whether render game interface
  -p {gym,unity,pettingzoo}, --platform {gym,unity,pettingzoo}
                        specify the platform of training environment
  -a {maddpg,masac,vdn,qmix,qatten,qtran,qplex,aoc,ppoc,oc,ioc,planet,dreamer,dreamerv2,mve,cql_dqn,bcq,pg,npg,trpo,ppo,a2c,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn}, --algorithm {maddpg,masac,vdn,qmix,qatten,qtran,qplex,aoc,ppoc,oc,ioc,planet,dreamer,dreamerv2,mve,cql_dqn,bcq,pg,npg,trpo,ppo,a2c,ac,dpg,ddpg,td3,sac_v,sac,tac,dqn,ddqn,dddqn,averaged_dqn,c51,qrdqn,rainbow,iqn,maxsqn,sql,bootstrappeddqn}
                        specify the training algorithm
  -i, --inference       inference the trained model, not train policies
  -l LOAD_PATH, --load-path LOAD_PATH
                        specify the name of pre-trained model that need to
                        load
  -m MODELS, --models MODELS
                        specify the number of trails that using different
                        random seeds
  -n NAME, --name NAME  specify the name of this training task
  --config-file CONFIG_FILE
                        specify the path of training configuration file
  --store-dir STORE_DIR
                        specify the directory that store model, log and
                        others
  --episode-length EPISODE_LENGTH
                        specify the maximum step per episode
  --hostname            whether concatenate hostname with the training name
  -e ENV_NAME, --env-name ENV_NAME
                        specify the environment name
  -f FILE_NAME, --file-name FILE_NAME
                        specify the path of builded training environment of
                        UNITY3D
  -s, --save            specify whether save models/logs/summaries while
                        training or not
  -d DEVICE, --device DEVICE
                        specify the device that operate Torch.Tensor
  -t MAX_TRAIN_STEP, --max-train-step MAX_TRAIN_STEP
                        specify the maximum training steps
"""

Example:

python run.py -s    # save model and log while train
python run.py -p gym -a dqn -e CartPole-v0 -c 12 -n dqn_cartpole
python run.py -p unity -a ppo -n run_with_unity -c 1
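
To see how the second example command maps onto the options listed in the usage text, here is a minimal argparse sketch covering only a subset of the flags. It is not the parser used by run.py: `build_parser` is a hypothetical name, the default for `--copies` is an assumption, and the real CLI restricts `-a` to the algorithm names shown above.

```python
import argparse
import shlex


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring a subset of the flags in the usage text."""
    parser = argparse.ArgumentParser(prog="run.py")
    # The default of 1 is an assumption, not taken from the repository.
    parser.add_argument("-c", "--copies", type=int, default=1,
                        help="number of environment copies collecting data in parallel")
    parser.add_argument("-p", "--platform", choices=["gym", "unity", "pettingzoo"],
                        help="platform of the training environment")
    # The real CLI limits this to the listed algorithm names; left open here.
    parser.add_argument("-a", "--algorithm", help="training algorithm")
    parser.add_argument("-e", "--env-name", help="environment name")
    parser.add_argument("-n", "--name", help="name of this training task")
    parser.add_argument("-s", "--save", action="store_true",
                        help="save models/logs/summaries while training")
    return parser


command = "-p gym -a dqn -e CartPole-v0 -c 12 -n dqn_cartpole"
args = build_parser().parse_args(shlex.split(command))
print(args.platform, args.algorithm, args.env_name, args.copies, args.name, args.save)
# prints: gym dqn CartPole-v0 12 dqn_cartpole False
```

Note that argparse turns `--env-name` into the attribute `env_name`, and that `-s` stays False unless it is passed explicitly.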

The main training loop in this repo, expressed as pseudo-code, is:

agent.episode_reset()  # initialize rnn hidden state or something else
obs = env.reset()
while True:
    env_rets = env.step(agent(obs))
    agent.episode_step(obs, env_rets)  # store experience, save model, and train off-policy algorithms
    obs = env_rets['obs']
    if env_rets['done']:
        break
agent.episode_end()  # train on-policy algorithms
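
The loop above can be exercised end to end with toy stand-ins. In the sketch below, `ToyEnv`, `ToyAgent`, and `run_episode` are hypothetical and do not correspond to classes in this repository; they only implement the method names used in the pseudo-code (`episode_reset`, `__call__`, `episode_step`, `episode_end`) and an env whose `step` returns a dict with `obs` and `done` keys. The `reward` key is an extra assumption added for illustration.

```python
import random


class ToyEnv:
    """Hypothetical environment: the episode ends after `horizon` steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # the observation is simply the step counter

    def step(self, action):
        self.t += 1
        return {"obs": self.t, "reward": float(action), "done": self.t >= self.horizon}


class ToyAgent:
    """Hypothetical agent exposing the hooks used by the training loop."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.buffer = []
        self.episodes_ended = 0

    def episode_reset(self):
        self.buffer.clear()  # a real agent might reset RNN hidden state here

    def __call__(self, obs):
        return self.rng.choice([0, 1])  # random action from two choices

    def episode_step(self, obs, env_rets):
        # Store the transition; an off-policy algorithm could train here.
        self.buffer.append((obs, env_rets["reward"], env_rets["obs"], env_rets["done"]))

    def episode_end(self):
        self.episodes_ended += 1  # an on-policy algorithm could train here


def run_episode(agent, env):
    """Run one episode following the loop structure shown above."""
    agent.episode_reset()
    obs = env.reset()
    steps = 0
    while True:
        env_rets = env.step(agent(obs))
        agent.episode_step(obs, env_rets)
        obs = env_rets["obs"]
        steps += 1
        if env_rets["done"]:
            break
    agent.episode_end()
    return steps


agent, env = ToyAgent(), ToyEnv(horizon=5)
print(run_episode(agent, env))  # prints: 5
```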

Giving credit

If using this repository for your research, please cite:

@misc{RLs,
  author = {Keavnn},
  title = {RLs: A Featureless Reinforcement Learning Repository},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/StepNeverStop/RLs}},
}

Issues

If you have questions about this project or find errors, please let me know here.