forked from XinJingHao/DRL-Pytorch
Commit 1a7c6b3 (1 parent: 2c36bc5)
Showing 12 changed files with 435 additions and 0 deletions.
Categorical_DQN.py
@@ -0,0 +1,150 @@
import torch.nn as nn
import numpy as np
import torch
import copy


def build_net(layer_shape, activation, output_activation):
    '''Build a fully-connected net with a for loop.'''
    layers = []
    for j in range(len(layer_shape) - 1):
        act = activation if j < len(layer_shape) - 2 else output_activation
        layers += [nn.Linear(layer_shape[j], layer_shape[j + 1]), act()]
    return nn.Sequential(*layers)
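
# Illustrative example (added; values follow the defaults in main.py, net_width=200):
# build_net([4, 200, 200, 2 * 51], nn.ReLU, nn.Identity) produces
# Linear(4, 200) -> ReLU -> Linear(200, 200) -> ReLU -> Linear(200, 102) -> Identity,
# i.e. one logit per (action, atom) pair for CartPole-v1 (state_dim=4, action_dim=2).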


class Categorical_Q_Net(nn.Module):
    def __init__(self, state_dim, action_dim, hid_shape, atoms):
        super(Categorical_Q_Net, self).__init__()
        self.atoms = atoms
        self.n_atoms = len(atoms)
        self.action_dim = action_dim

        layers = [state_dim] + list(hid_shape) + [action_dim * self.n_atoms]
        self.net = build_net(layers, nn.ReLU, nn.Identity)

    def _predict(self, state):
        logits = self.net(state)  # (batch_size, action_dim*n_atoms)
        distributions = torch.softmax(logits.view(len(state), self.action_dim, self.n_atoms), dim=2)  # (batch_size, a_dim, n_atoms)
        q_values = (distributions * self.atoms).sum(2)  # (batch_size, a_dim)
        return distributions, q_values

    def forward(self, state, action=None):
        distributions, q_values = self._predict(state)
        if action is None:
            action = torch.argmax(q_values, dim=1)
        return action, distributions[torch.arange(len(state)), action]
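
# Note (added): Q(s, a) is the expectation of the categorical distribution over the
# fixed atoms, Q(s, a) = sum_i z_i * p_i(s, a). The fancy indexing
# distributions[torch.arange(len(state)), action] selects, for each sample in the
# batch, the distribution of its given (or greedy) action, yielding a
# (batch_size, n_atoms) tensor.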


class CDQN_agent(object):
    def __init__(self, **kwargs):
        # Init hyperparameters for the agent, e.g. self.gamma = opt.gamma, self.lambd = opt.lambd, ...
        self.__dict__.update(kwargs)
        self.atoms = torch.linspace(self.v_min, self.v_max, steps=self.n_atoms, device=self.dvc)
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
        self.m = torch.zeros((self.batch_size, self.n_atoms), device=self.dvc)  # projected target distribution

        self.q_net = Categorical_Q_Net(self.state_dim, self.action_dim, (self.net_width, self.net_width), self.atoms).to(self.dvc)
        self.q_net_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.lr)
        self.q_target = copy.deepcopy(self.q_net)
        # Freeze the target network with respect to the optimizer (it is only updated via Polyak averaging)
        for p in self.q_target.parameters(): p.requires_grad = False

        self.offset = torch.linspace(0, (self.batch_size - 1) * self.n_atoms, self.batch_size, device=self.dvc).unsqueeze(-1).long()
        self.replay_buffer = ReplayBuffer(self.state_dim, self.dvc, max_size=int(1e6))
        self.tau = 0.005

    def select_action(self, state, deterministic):
        # Only used when interacting with the env
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.dvc)
            if (not deterministic) and (np.random.rand() < self.exp_noise):
                return np.random.randint(0, self.action_dim)  # epsilon-greedy random action
            else:
                a, _ = self.q_net(state)
                return a.cpu().item()

    def train(self):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)  # dw (terminate): die or win

        '''Compute the target distribution:'''
        with torch.no_grad():
            # Note: the original paper uses single Q-learning, but we find Double Q-learning more stable.
            if self.DQL:
                argmax_a_next, _ = self.q_net(s_next)  # (batch_size,)
                _, batched_next_distribution = self.q_target(s_next, argmax_a_next)  # _, (batch_size, n_atoms); Double Q-learning
            else:
                _, batched_next_distribution = self.q_target(s_next)  # _, (batch_size, n_atoms); single Q-learning

            self.m *= 0
            t_z = (r + (~dw) * self.gamma * self.atoms).clamp(self.v_min, self.v_max)  # (batch_size, n_atoms)
            b = (t_z - self.v_min) / self.delta_z  # b ∈ [0, n_atoms-1]; shape: (batch_size, n_atoms)
            l = b.floor().long()  # (batch_size, n_atoms)
            u = b.ceil().long()  # (batch_size, n_atoms)

            # When b_j is exactly an integer, b_j.floor() == b_j.ceil(), so u should be +1.
            # E.g. b_j = 1 gives l = 1, and u should be 2.
            delta_m_l = (u + (l == u) - b) * batched_next_distribution  # (batch_size, n_atoms)
            delta_m_u = (b - l) * batched_next_distribution  # (batch_size, n_atoms)

            '''Distribute the probability with tensor operations. Much faster than the for loop in the original paper.'''
            self.m.view(-1).index_add_(0, (l + self.offset).view(-1), delta_m_l.view(-1))
            self.m.view(-1).index_add_(0, (u + self.offset).view(-1), delta_m_u.view(-1))
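
            # Note (added): self.offset shifts each row's atom indices by row * n_atoms,
            # so a single index_add_ on the flattened (batch_size * n_atoms,) view of m
            # scatters each sample's mass into its own row. This realizes the projection
            # step of Algorithm 1 in Bellemare et al. (2017):
            #   m_l += p_j * (u - b_j),  m_u += p_j * (b_j - l)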

        # Get the current estimate:
        _, batched_distribution = self.q_net(s, a.flatten())  # _, (batch_size, n_atoms)

        # Compute the Cross Entropy loss:
        # q_loss = (-(self.m * batched_distribution.log()).sum(-1)).mean()  # original Cross Entropy loss, not stable
        q_loss = (-(self.m * batched_distribution.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean()  # more stable

        self.q_net_optimizer.zero_grad()
        q_loss.backward()
        self.q_net_optimizer.step()

        # Update the frozen target network:
        for param, target_param in zip(self.q_net.parameters(), self.q_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, algo, EnvName, steps):
        torch.save(self.q_net.state_dict(), "./model/{}_{}_{}k.pth".format(algo, EnvName, steps))

    def load(self, algo, EnvName, steps):
        self.q_net.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo, EnvName, steps)))
        self.q_target.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo, EnvName, steps)))


class ReplayBuffer(object):
    def __init__(self, state_dim, dvc, max_size=int(1e6)):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0

        self.s = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, 1), dtype=torch.long, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)

    def add(self, s, a, r, s_next, dw):
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = a
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw

        self.ptr = (self.ptr + 1) % self.max_size  # circular overwrite of the oldest transition
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
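
# Illustrative usage (added; mirrors the interaction loop in main.py):
#   buffer = ReplayBuffer(state_dim=4, dvc=torch.device('cpu'))
#   buffer.add(s, a, r, s_next, dw)            # s, s_next: np.ndarray of shape (state_dim,)
#   s, a, r, s_next, dw = buffer.sample(256)   # uniform random mini-batch of tensors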
LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 XinJingHao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
@@ -0,0 +1,75 @@
# C51: Categorical-DQN-Pytorch
A **clean and robust Pytorch implementation of Categorical DQN (C51)**:

Render | Training curve
:-----------------------:|:-----------------------:
<img src="https://github.com/XinJingHao/C51-Categorical-DQN-Pytorch/blob/main/Images/lld.gif" width="80%" height="auto"> | <img src="https://github.com/XinJingHao/C51-Categorical-DQN-Pytorch/blob/main/Images/training_curve.svg" width="100%" height="auto">

**Other RL algorithms by Pytorch can be found [here](https://github.com/XinJingHao/RL-Algorithms-by-Pytorch).**

## Dependencies
```python
gymnasium==0.29.1
matplotlib==3.8.2
numpy==1.26.1
pytorch==2.1.0

python==3.11.5
```

## How to use my code
### Train from scratch
```bash
python main.py
```
where the default environment is 'CartPole'.

### Change Environment
If you want to train on a different environment, just run:
```bash
python main.py --EnvIdex 1
```

The --EnvIdex can be set to 0 or 1, where
```bash
'--EnvIdex 0' for 'CartPole-v1'
'--EnvIdex 1' for 'LunarLander-v2'
```

Note: if you want to play on LunarLander, you need to install [box2d-py](https://gymnasium.farama.org/environments/box2d/) first. You can install box2d-py via: ```pip install gymnasium[box2d]```

### Play with trained model
```bash
python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 60 # Play with CartPole
```
```bash
python main.py --EnvIdex 1 --render True --Loadmodel True --ModelIdex 320 # Play with LunarLander
```

### Visualize the training curve
You can use [tensorboard](https://pytorch.org/docs/stable/tensorboard.html) to record and visualize the training curve.

- Installation (please make sure Pytorch is installed already):
```bash
pip install tensorboard
pip install packaging
```
- Record (the training curves will be saved at '**runs**'):
```bash
python main.py --write True
```

- Visualization:
```bash
tensorboard --logdir runs
```

### Hyperparameter Setting
For more details of the hyperparameter setting, please check 'main.py'.
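
For intuition, C51 models the return distribution on a fixed, evenly spaced support of `n_atoms` atoms spanning `[v_min, v_max]`. The sketch below (added here for illustration; it mirrors `Categorical_DQN.py` with the defaults from `main.py`) shows how these hyperparameters define that support:
```python
import torch

v_min, v_max, n_atoms = -100.0, 100.0, 51      # defaults from main.py
atoms = torch.linspace(v_min, v_max, n_atoms)  # support z_1..z_N of the return distribution
delta_z = (v_max - v_min) / (n_atoms - 1)      # spacing between adjacent atoms

# Q(s, a) is recovered as the expectation over this support:
# Q(s, a) = (p(s, a) * atoms).sum(-1)
```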

### References
[Bellemare M G, Dabney W, Munos R. A distributional perspective on reinforcement learning[C]//International Conference on Machine Learning. PMLR, 2017: 449-458.](https://proceedings.mlr.press/v70/bellemare17a/bellemare17a.pdf)
main.py
@@ -0,0 +1,128 @@
from utils import evaluate_policy, render_policy, str2bool
from datetime import datetime
from Categorical_DQN import CDQN_agent
import gymnasium as gym
import os, shutil
import argparse
import torch


'''Hyperparameter Setting'''
parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cuda', help='running device: cuda or cpu')
parser.add_argument('--EnvIdex', type=int, default=0, help='CP-v1, LLd-v2')
parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
parser.add_argument('--ModelIdex', type=int, default=320, help='which model to load')

parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--Max_train_steps', type=int, default=int(400e3), help='Max training steps')
parser.add_argument('--save_interval', type=int, default=int(20e3), help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=int(2e3), help='Model evaluating interval, in steps.')
parser.add_argument('--random_steps', type=int, default=int(3e3), help='steps for random policy to explore')
parser.add_argument('--update_every', type=int, default=50, help='training frequency')

parser.add_argument('--gamma', type=float, default=0.99, help='Discount Factor')
parser.add_argument('--net_width', type=int, default=200, help='Hidden net width')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=256, help='batch size of sampled transitions')
parser.add_argument('--exp_noise', type=float, default=0.2, help='explore noise')
parser.add_argument('--noise_decay', type=float, default=0.99, help='decay rate of explore noise')

parser.add_argument('--DQL', type=str2bool, default=True, help='Whether to use Double Q-learning')
parser.add_argument('--v_min', type=float, default=-100, help='Vmin')
parser.add_argument('--v_max', type=float, default=100, help='Vmax')
parser.add_argument('--n_atoms', type=int, default=51, help='number of atoms')

opt = parser.parse_args()
opt.dvc = torch.device(opt.dvc)  # from str to torch.device
print(opt)


def main():
    EnvName = ['CartPole-v1', 'LunarLander-v2']
    BriefEnvName = ['CPV1', 'LLdV2']
    env = gym.make(EnvName[opt.EnvIdex], render_mode="human" if opt.render else None)
    eval_env = gym.make(EnvName[opt.EnvIdex])
    opt.state_dim = env.observation_space.shape[0]
    opt.action_dim = env.action_space.n
    opt.max_e_steps = env._max_episode_steps
    opt.action_info = {0: ['Left', 'Right'], 1: ['Noop', 'LeftEngine', 'MainEngine', 'RightEngine']}
    algo_name = 'C51_' + ('DDQN' if opt.DQL else 'DQN')  # parentheses matter: without them, opt.DQL=False would yield just 'DQN'

    # Seed everything
    env_seed = opt.seed
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Random Seed: {}".format(opt.seed))

    print('Algorithm:', algo_name, ' Env:', BriefEnvName[opt.EnvIdex], ' state_dim:', opt.state_dim,
          ' action_dim:', opt.action_dim, ' Random Seed:', opt.seed, ' max_e_steps:', opt.max_e_steps, '\n')

    if opt.write:
        from torch.utils.tensorboard import SummaryWriter
        timenow = str(datetime.now())[0:-10]
        timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
        writepath = 'runs/{}_{}'.format(algo_name, BriefEnvName[opt.EnvIdex]) + timenow
        if os.path.exists(writepath): shutil.rmtree(writepath)
        writer = SummaryWriter(log_dir=writepath)

    # Build model and replay buffer
    if not os.path.exists('model'): os.mkdir('model')
    agent = CDQN_agent(**vars(opt))
    if opt.Loadmodel: agent.load(algo_name, BriefEnvName[opt.EnvIdex], opt.ModelIdex)

    if opt.render:
        render_policy(env, agent, opt)
    else:
        total_steps = 0
        while total_steps < opt.Max_train_steps:
            s, info = env.reset(seed=env_seed)  # Do not use opt.seed directly, or the agent can overfit to opt.seed
            env_seed += 1
            done = False

            '''Interact & train'''
            while not done:
                # e-greedy exploration
                if total_steps < opt.random_steps: a = env.action_space.sample()
                else: a = agent.select_action(s, deterministic=False)
                s_next, r, dw, tr, info = env.step(a)  # dw: dead & win; tr: truncated
                done = (dw or tr)

                agent.replay_buffer.add(s, a, r, s_next, dw)
                s = s_next

                '''Update if it's time'''
                # Train 50 times every 50 steps, rather than once per step. Better!
                if total_steps >= opt.random_steps and total_steps % opt.update_every == 0:
                    for j in range(opt.update_every): agent.train()

                '''Record & log'''
                if total_steps % opt.eval_interval == 0:
                    agent.exp_noise *= opt.noise_decay
                    score = evaluate_policy(eval_env, agent, turns=3)
                    if opt.write:
                        writer.add_scalar('ep_r', score, global_step=total_steps)
                        writer.add_scalar('noise', agent.exp_noise, global_step=total_steps)
                    print('EnvName:', BriefEnvName[opt.EnvIdex], 'seed:', opt.seed,
                          'steps: {}k'.format(int(total_steps/1000)), 'score:', int(score))
                total_steps += 1

                '''Save model'''
                if total_steps % opt.save_interval == 0:
                    agent.save(algo_name, BriefEnvName[opt.EnvIdex], int(total_steps/1000))
        env.close()
        eval_env.close()


if __name__ == '__main__':
    main()
Binary file not shown.
Binary file not shown.
Binary file added (+2.84 KB): ...DQN_C51/runs/C51_DDQN_CPV1 2024-01-08 14_56/events.out.tfevents.1704696993.3060SG.43170.0
Binary file added (+3.37 KB): ...DQN_C51/runs/C51_DDQN_CPV1 2024-01-08 15_03/events.out.tfevents.1704697439.3060SG.43661.0
Binary file added (+17.5 KB): ...QN_C51/runs/C51_DDQN_LLdV2 2024-01-08 15_06/events.out.tfevents.1704697567.3060SG.43782.0