Commit

Add files via upload
XinJingHao authored Jan 8, 2024
1 parent 2c36bc5 commit 1a7c6b3
Showing 12 changed files with 435 additions and 0 deletions.
150 changes: 150 additions & 0 deletions 2.4_Categorical-DQN_C51/Categorical_DQN.py
@@ -0,0 +1,150 @@
import torch.nn as nn
import numpy as np
import torch
import copy


def build_net(layer_shape, activation, output_activation):
'''Build a fully connected net with a for loop'''
layers = []
for j in range(len(layer_shape)-1):
act = activation if j < len(layer_shape)-2 else output_activation
layers += [nn.Linear(layer_shape[j], layer_shape[j+1]), act()]
return nn.Sequential(*layers)
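# Example (illustrative, using CartPole-v1's state_dim=4, action_dim=2, n_atoms=51 and net_width=200):
# build_net([4, 200, 200, 2*51], nn.ReLU, nn.Identity) produces
# Linear(4,200) -> ReLU -> Linear(200,200) -> ReLU -> Linear(200,102) -> Identity,
# i.e. a flat vector of action_dim * n_atoms logits that Categorical_Q_Net reshapes later.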

class Categorical_Q_Net(nn.Module):
def __init__(self, state_dim, action_dim, hid_shape, atoms):
super(Categorical_Q_Net, self).__init__()
self.atoms = atoms
self.n_atoms = len(atoms)
self.action_dim = action_dim

layers = [state_dim] + list(hid_shape) + [action_dim*self.n_atoms]
self.net = build_net(layers, nn.ReLU, nn.Identity)

def _predict(self, state):
logits = self.net(state) # (batch_size, action_dim*n_atoms)
distributions = torch.softmax(logits.view(len(state), self.action_dim, self.n_atoms), dim=2) # (batch_size, a_dim, n_atoms)
q_values = (distributions * self.atoms).sum(2) # (batch_size, a_dim)
return distributions, q_values

def forward(self, state, action=None):
distributions, q_values = self._predict(state)
if action is None:
action = torch.argmax(q_values, dim=1)
return action, distributions[torch.arange(len(state)), action]
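# Usage (illustrative): `a, dist = q_net(state)` returns the greedy actions and their distributions,
# while `_, dist = q_net(state, action)` gathers the distributions of a given batch of actions.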




class CDQN_agent(object):
def __init__(self, **kwargs):
# Initialize hyperparameters for the agent, e.g. self.gamma = opt.gamma, self.lambd = opt.lambd, ...
self.__dict__.update(kwargs)
self.atoms = torch.linspace(self.v_min, self.v_max, steps=self.n_atoms,device=self.dvc)
self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)
self.m = torch.zeros((self.batch_size, self.n_atoms), device=self.dvc)
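# E.g. with main.py's defaults (v_min=-100, v_max=100, n_atoms=51): delta_z = 200/50 = 4.0 and
# atoms = [-100, -96, ..., 96, 100]. self.m is a reusable buffer that will hold the projected
# target distribution of each sample in the batch.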

self.q_net = Categorical_Q_Net(self.state_dim, self.action_dim, (self.net_width,self.net_width),self.atoms).to(self.dvc)
self.q_net_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.lr)
self.q_target = copy.deepcopy(self.q_net)
# Freeze target networks with respect to optimizers (only update via polyak averaging)
for p in self.q_target.parameters(): p.requires_grad = False

self.offset = torch.linspace(0, (self.batch_size - 1) * self.n_atoms, self.batch_size,device=self.dvc).unsqueeze(-1).long()
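# offset shifts each row's atom indices into the flattened (batch_size * n_atoms,) view of self.m,
# so that row i writes into slice [i*n_atoms, (i+1)*n_atoms). E.g. (illustrative) batch_size=2 and
# n_atoms=3 give offset = [[0], [3]].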
self.replay_buffer = ReplayBuffer(self.state_dim, self.dvc, max_size=int(1e6))
self.tau = 0.005

def select_action(self, state, deterministic):
# Only used when interacting with the env
with torch.no_grad():
state = torch.FloatTensor(state.reshape(1, -1)).to(self.dvc)
if (not deterministic) and (np.random.rand() < self.exp_noise):
return np.random.randint(0,self.action_dim)
else:
a, _ = self.q_net(state)
return a.cpu().item()


def train(self):
s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size) # dw(terminate): die or win

'''Compute the target distribution:'''
with torch.no_grad():
# Note that the original paper uses single Q-learning, but we find Double Q-learning more stable.
if self.DQL:
argmax_a_next, _ = self.q_net(s_next) # (batch_size,)
_, batched_next_distribution = self.q_target(s_next, argmax_a_next) # _, (batch_size, n_atoms) # Double Q-learning
else:
_, batched_next_distribution = self.q_target(s_next) # _, (batch_size, n_atoms) # Single Q-learning

self.m *= 0
t_z = (r + (~dw) * self.gamma * self.atoms).clamp(self.v_min, self.v_max) # (batch_size, n_atoms)
b = (t_z - self.v_min)/self.delta_z # b∈[0,n_atoms-1]; shape: (batch_size, n_atoms)
l = b.floor().long() # (batch_size, n_atoms)
u = b.ceil().long() # (batch_size, n_atoms)

# When bj is exactly an integer, bj.floor() == bj.ceil(), so u should be shifted up by 1.
# E.g. bj=1 gives l=1 and u=1, but u should be 2; the (l == u) term below makes that correction.
delta_m_l = (u + (l == u) - b) * batched_next_distribution # (batch_size, n_atoms)
delta_m_u = (b - l) * batched_next_distribution # (batch_size, n_atoms)
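# Worked example (illustrative, with main.py's defaults, so delta_z=4): a target value t_z = 1.0
# gives b = (1.0 + 100)/4 = 25.25, l = 25, u = 26, so the lower bin receives (u - b) = 0.75 of the
# probability mass and the upper bin receives (b - l) = 0.25.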


'''Distribute probability with tensor operations. Much faster than the for loop in the original paper.'''
self.m.view(-1).index_add_(0, (l + self.offset).view(-1), delta_m_l.view(-1))
self.m.view(-1).index_add_(0, (u + self.offset).view(-1), delta_m_u.view(-1))
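# For reference, an equivalent (but much slower) loop form of the two index_add_ calls above
# (an illustrative sketch only, not executed):
# for i in range(self.batch_size):
#     for j in range(self.n_atoms):
#         self.m[i, l[i, j]] += delta_m_l[i, j]
#         self.m[i, u[i, j]] += delta_m_u[i, j]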

# Get current estimate:
_, batched_distribution = self.q_net(s, a.flatten()) # _, (batch_size, n_atoms)

# Compute Cross Entropy Loss:
# q_loss = (-(self.m * batched_distribution.log()).sum(-1)).mean() # Original Cross Entropy loss, not stable
q_loss = (-(self.m * batched_distribution.clamp(min=1e-5, max=1 - 1e-5).log()).sum(-1)).mean() # more stable

self.q_net_optimizer.zero_grad()
q_loss.backward()
self.q_net_optimizer.step()

# Update the frozen target models
for param, target_param in zip(self.q_net.parameters(), self.q_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


def save(self,algo,EnvName,steps):
torch.save(self.q_net.state_dict(), "./model/{}_{}_{}k.pth".format(algo,EnvName,steps))

def load(self,algo,EnvName,steps):
self.q_net.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps)))
self.q_target.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps)))


class ReplayBuffer(object):
def __init__(self, state_dim, dvc, max_size=int(1e6)):
self.max_size = max_size
self.dvc = dvc
self.ptr = 0
self.size = 0

self.s = torch.zeros((max_size, state_dim),dtype=torch.float,device=self.dvc)
self.a = torch.zeros((max_size, 1),dtype=torch.long,device=self.dvc)
self.r = torch.zeros((max_size, 1),dtype=torch.float,device=self.dvc)
self.s_next = torch.zeros((max_size, state_dim),dtype=torch.float,device=self.dvc)
self.dw = torch.zeros((max_size, 1),dtype=torch.bool,device=self.dvc)

def add(self, s, a, r, s_next, dw):
self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
self.a[self.ptr] = a
self.r[self.ptr] = r
self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
self.dw[self.ptr] = dw

self.ptr = (self.ptr + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)

def sample(self, batch_size):
ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
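# Usage (illustrative): add() is called once per environment step in main.py, and sample() once per
# training iteration in CDQN_agent.train().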




Binary file added 2.4_Categorical-DQN_C51/Images/lld.gif
1 change: 1 addition & 0 deletions 2.4_Categorical-DQN_C51/Images/training_curve.svg
21 changes: 21 additions & 0 deletions 2.4_Categorical-DQN_C51/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 XinJingHao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
75 changes: 75 additions & 0 deletions 2.4_Categorical-DQN_C51/README.md
@@ -0,0 +1,75 @@
# C51: Categorical-DQN-Pytorch
A **clean and robust PyTorch implementation of Categorical DQN (C51)**

Render | Training curve
:-----------------------:|:-----------------------:|
<img src="https://github.com/XinJingHao/C51-Categorical-DQN-Pytorch/blob/main/Images/lld.gif" width="80%" height="auto"> | <img src="https://github.com/XinJingHao/C51-Categorical-DQN-Pytorch/blob/main/Images/training_curve.svg" width="100%" height="auto">

**Other RL algorithms implemented in PyTorch can be found [here](https://github.com/XinJingHao/RL-Algorithms-by-Pytorch).**



## Dependencies
```python
gymnasium==0.29.1
matplotlib==3.8.2
numpy==1.26.1
pytorch==2.1.0

python==3.11.5
```

## How to use my code
### Train from scratch
```bash
python main.py
```
where the default environment is 'CartPole-v1'.

### Change Environment
If you want to train on a different environment, just run:
```bash
python main.py --EnvIdex 1
```

The `--EnvIdex` flag can be set to 0 or 1, where
```bash
'--EnvIdex 0' for 'CartPole-v1'
'--EnvIdex 1' for 'LunarLander-v2'
```

Note: if you want to play on LunarLander, you need to install [box2d-py](https://gymnasium.farama.org/environments/box2d/) first. You can install box2d-py via: ```pip install gymnasium[box2d]```


### Play with trained model
```bash
python main.py --EnvIdex 0 --render True --Loadmodel True --ModelIdex 60 # Play with CartPole
```
```bash
python main.py --EnvIdex 1 --render True --Loadmodel True --ModelIdex 320 # Play with LunarLander
```

### Visualize the training curve
You can use [tensorboard](https://pytorch.org/docs/stable/tensorboard.html) to record and visualize the training curve.

- Installation (please make sure Pytorch is installed already):
```bash
pip install tensorboard
pip install packaging
```
- Record (the training curves will be saved at '**runs**'):
```bash
python main.py --write True
```

- Visualization:
```bash
tensorboard --logdir runs
```


### Hyperparameter Setting
For more details on the hyperparameter settings, please check 'main.py'.

### References
[Bellemare M G, Dabney W, Munos R. A distributional perspective on reinforcement learning[C]//International conference on machine learning. PMLR, 2017: 449-458.](https://proceedings.mlr.press/v70/bellemare17a/bellemare17a.pdf)
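
For reference, here is a sketch of the projected distributional Bellman update that `Categorical_DQN.py` implements, in the notation of the paper above (here $z_i$ are the atoms, $\Delta z$ the atom spacing, $p_j$ the target network's probabilities, $d$ the termination flag, and $a^*$ the greedy next action):

```math
\hat{\mathcal{T}} z_j = \operatorname{clip}\big(r + \gamma (1-d)\, z_j,\; V_{\min},\; V_{\max}\big), \qquad
m_i = \sum_{j} \Big[ 1 - \frac{\lvert \hat{\mathcal{T}} z_j - z_i \rvert}{\Delta z} \Big]_0^1 \, p_j(s', a^*), \qquad
\mathcal{L} = -\sum_{i} m_i \log p_i(s, a)
```

The cross-entropy term $\mathcal{L}$ is the loss minimized in `CDQN_agent.train()`.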
128 changes: 128 additions & 0 deletions 2.4_Categorical-DQN_C51/main.py
@@ -0,0 +1,128 @@
from utils import evaluate_policy, render_policy, str2bool
from datetime import datetime
from Categorical_DQN import CDQN_agent
import gymnasium as gym
import os, shutil
import argparse
import torch


'''Hyperparameter Setting'''
parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cuda', help='running device: cuda or cpu')
parser.add_argument('--EnvIdex', type=int, default=0, help='0: CartPole-v1, 1: LunarLander-v2')
parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
parser.add_argument('--ModelIdex', type=int, default=320, help='which model to load')

parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--Max_train_steps', type=int, default=int(400e3), help='Max training steps')
parser.add_argument('--save_interval', type=int, default=int(20e3), help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=int(2e3), help='Model evaluating interval, in steps.')
parser.add_argument('--random_steps', type=int, default=int(3e3), help='steps for random policy to explore')
parser.add_argument('--update_every', type=int, default=50, help='training frequency')

parser.add_argument('--gamma', type=float, default=0.99, help='Discounted Factor')
parser.add_argument('--net_width', type=int, default=200, help='Hidden net width')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size sampled from the replay buffer')
parser.add_argument('--exp_noise', type=float, default=0.2, help='explore noise')
parser.add_argument('--noise_decay', type=float, default=0.99, help='decay rate of explore noise')

parser.add_argument('--DQL', type=str2bool, default=True, help='Whether to use Double Q-learning')
parser.add_argument('--v_min', type=float, default=-100, help='Vmin')
parser.add_argument('--v_max', type=float, default=100, help='Vmax')
parser.add_argument('--n_atoms', type=int, default=51, help='number of atoms')

opt = parser.parse_args()
opt.dvc = torch.device(opt.dvc) # from str to torch.device
print(opt)


def main():
EnvName = ['CartPole-v1','LunarLander-v2']
BriefEnvName = ['CPV1', 'LLdV2']
env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
eval_env = gym.make(EnvName[opt.EnvIdex])
opt.state_dim = env.observation_space.shape[0]
opt.action_dim = env.action_space.n
opt.max_e_steps = env._max_episode_steps
opt.action_info = {0: ['Left', 'Right'], 1: ['Noop', 'LeftEngine', 'MainEngine', 'RightEngine']}
algo_name = 'C51_' + ('DDQN' if opt.DQL else 'DQN') # parentheses needed: the conditional expression binds looser than '+'

# Seed Everything
env_seed = opt.seed
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
print("Random Seed: {}".format(opt.seed))

print('Algorithm:',algo_name,' Env:',BriefEnvName[opt.EnvIdex],' state_dim:',opt.state_dim,
' action_dim:',opt.action_dim,' Random Seed:',opt.seed, ' max_e_steps:',opt.max_e_steps, '\n')

if opt.write:
from torch.utils.tensorboard import SummaryWriter
timenow = str(datetime.now())[0:-10]
timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
writepath = 'runs/{}_{}'.format(algo_name,BriefEnvName[opt.EnvIdex]) + timenow
if os.path.exists(writepath): shutil.rmtree(writepath)
writer = SummaryWriter(log_dir=writepath)

#Build model and replay buffer
if not os.path.exists('model'): os.mkdir('model')
agent = CDQN_agent(**vars(opt))
if opt.Loadmodel: agent.load(algo_name,BriefEnvName[opt.EnvIdex],opt.ModelIdex)

if opt.render:
render_policy(env, agent, opt)
else:
total_steps = 0
while total_steps < opt.Max_train_steps:
s, info = env.reset(seed=env_seed) # Do not use opt.seed directly, or it can overfit to opt.seed
env_seed += 1
done = False

'''Interact & train'''
while not done:
#e-greedy exploration
if total_steps < opt.random_steps: a = env.action_space.sample()
else: a = agent.select_action(s, deterministic=False)
s_next, r, dw, tr, info = env.step(a) # dw: dead&win; tr: truncated
done = (dw or tr)

agent.replay_buffer.add(s, a, r, s_next, dw)
s = s_next

'''update if its time'''
# train 50 times every 50 steps, rather than once per step. Works better!
if total_steps >= opt.random_steps and total_steps % opt.update_every == 0:
for j in range(opt.update_every): agent.train()

'''record & log'''
if total_steps % opt.eval_interval == 0:
agent.exp_noise *= opt.noise_decay
score = evaluate_policy(eval_env, agent, turns = 3)
if opt.write:
writer.add_scalar('ep_r', score, global_step=total_steps)
writer.add_scalar('noise', agent.exp_noise, global_step=total_steps)
print('EnvName:',BriefEnvName[opt.EnvIdex],'seed:',opt.seed,'steps: {}k'.format(int(total_steps/1000)),'score:', int(score))
total_steps += 1

'''save model'''
if total_steps % opt.save_interval == 0:
agent.save(algo_name,BriefEnvName[opt.EnvIdex],int(total_steps/1000))
env.close()
eval_env.close()

if __name__ == '__main__':
main()







