Life is a library for reinforce learning implemented by PyTorch.
- Sarsa
- multi-Sarsa
- Q-Learning
- Dyna-Q
- DQN
- Double-DQN
- Dueling-DQN
- REINFORCE策略梯度
- Actor-Critic
- PPO
- DDPG
- SAC
- BC
- GAIL
- CQL
pip install rllife
你可以在PyPI上面下载.gz文件,然后通过本地安装。
pyyaml==6.0
ipykernel==6.15.1
jupyter==1.0.0
matplotlib==3.5.3
seaborn==0.12.1
dill==0.3.5.1
argparse==1.4.0
pandas==1.3.5
pyglet==1.5.26
importlib-metadata<5.0
setuptools==65.2.0
gym==0.25.2
numpy==1.21.6
pandas==1.3.4
torch==1.10.0
tqdm==4.64.1
- 基于目前主流的深度学习框架pytorch,支持gpu加速。
- 简洁易用,仅需寥寥几行代码,即可实现强化学习算法的构建与训练。
- 覆盖面广,从传统的QLearning,到一些最新的强化学习算法都有实现。
- 所有超参均支持自定义,同时可自定义深度神经网络的结构,封装程度低而又简介易用。
- 传统的强化学习算法,如Sarsa;
- 只基于值函数的深度强化学习算法,如DQN;
- 基于策略函数和值函数的深度强化学习算法,如AC;
- 模仿强化学习算法,如BC;
- 离线强化学习算法,如CQL。
训练器的名称和算法的名称是一一对应的,如要训练DQN
,则其训练函数的名称为:
train_dqn
- dqn.py中为传统DQN算法
- dqn_improved.py中为一些改进的DQN算法
- trainer中包含了以上各种dqn算法的训练函数
要使用Life进行强化学习,仅需简单的三步,下面以DQN在CartPole环境上的训练为例进行快速入门:
from life.dqn.dqn import DQN # 导入模型
from life.dqn.trainer import train_dqn # 导入训练器
from life.envs.dis_env_demo import make # 环境的一个例子
from life.utils.replay.replay_buffer import ReplayBuffer # 回放池
import torch
import matplotlib.pyplot as plt
# 设置超参数
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cpu") # 也可指定为gpu : torch.device("cuda")
env=make() # 建立环境,这里为 CartPole-v0
replay_buffer = ReplayBuffer(buffer_size) # 回放池
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
# 建立模型
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
target_update, device) # DQN模型
注意,如果你足够细心,你会发现在上述建立DQN的过程中,我们没有传入一个Neural Network,这是因为在建立深度强化学习时,Life提供了一个默认的双层神经网络作为建立DQN的默认神经网络。当然,你也可以使用自己设计的神经网络结构:
class YourNet:
"""your network for your task"""
pass
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
target_update, device, q_net=YourNet) # DQN模型
此时,原本用于传递给默认神经网络的超参数state_dim,hidden_dim,action_dim就没有用了,可随意设置。
result=train_dqn(agent,env,replay_buffer,minimal_size,batch_size,num_episodes)
episodes_list = list(range(len(result)))
plt.figure(figsize=(8,6))
plt.plot(episodes_list, result)
plt.xlabel("Episodes")
plt.ylabel("Returns")
plt.title("DQN on {}".format("Cart Pole v1"))
plt.show()
其中,return_list
为:训练过程中每个回合的汇报,agent
为训练好的智能体。
return_agent
默认为False
。
可见,除了超参数的设置之外,我们构建DQN算法只使用了两行代码:
from life.dqn.dqn import DQN
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)
我们训练DQN同样只使用了两行代码:
from life.dqn.trainer import train_dqn
result=train_dqn(agent,env,replay_buffer,minimal_size,batch_size,num_episodes)