Commit E2O
zhaokai committed Jun 16, 2023
1 parent 56b9680 commit 4cdac9d
Showing 124 changed files with 23,779 additions and 0 deletions.
50 changes: 50 additions & 0 deletions offline-rl-algorithms/E2O/E2O-offline.py
@@ -0,0 +1,50 @@
import argparse
import d3rlpy_new.d3rlpy
from sklearn.model_selection import train_test_split


def main():

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='halfcheetah-medium-expert-v2')
parser.add_argument('--n_critic', type=int, default=10)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

dataset, env = d3rlpy_new.d3rlpy.datasets.get_dataset(args.dataset)

d3rlpy_new.d3rlpy.seed(args.seed)
env.seed(args.seed)

_, test_episodes = train_test_split(dataset, test_size=0.2)

encoder = d3rlpy_new.d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256])

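    # offline pretraining agent: CQL with an ensemble of `n_critic` Q-functions;
    # E2O-online.py later restores this agent and fine-tunes it online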
e2o = d3rlpy_new.d3rlpy.algos.CQL(actor_learning_rate=3e-5,
critic_learning_rate=3e-4,
temp_learning_rate=1e-4,
actor_encoder_factory=encoder,
critic_encoder_factory=encoder,
batch_size=256,
n_action_samples=10,
alpha_learning_rate=0.0,
alpha_threshold=5.0,
conservative_weight=5.0,
n_critics=args.n_critic,
use_gpu=args.gpu)

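    # 1M offline gradient steps; the held-out episodes are scored by environment
    # return and by the average Q-value scale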
e2o.fit(dataset.episodes,
eval_episodes=test_episodes,
n_steps=1000000,
n_steps_per_epoch=1000,
save_interval=1000,
scorers={
'environment': d3rlpy_new.d3rlpy.metrics.evaluate_on_environment(env),
'value_scale': d3rlpy_new.d3rlpy.metrics.average_value_estimation_scorer,
},
experiment_name=f"E2O-CQL-{args.n_critic}_{args.dataset}_{args.seed}")


if __name__ == '__main__':
main()
51 changes: 51 additions & 0 deletions offline-rl-algorithms/E2O/E2O-online.py
@@ -0,0 +1,51 @@
import argparse
import gym
import d3rlpy_new.d3rlpy
import os


def main():

parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='HalfCheetah-v2')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

env = gym.make(args.env)
eval_env = gym.make(args.env)

# fix seed
d3rlpy_new.d3rlpy.seed(args.seed)
env.seed(args.seed)
eval_env.seed(args.seed)

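    # locate the saved offline run (the logged directory name may carry a suffix
    # such as a timestamp, hence the substring match) and restore its config and
    # final 1M-step checkpoint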
file_name = 'E2O-CQL-10_halfcheetah-medium-expert-v2_' + str(args.seed)
file_path = "./d3rlpy_logs"
filelist = os.listdir(file_path)
for file in filelist:
if file_name in file:
file_name = file
e2o = d3rlpy_new.d3rlpy.algos.E2O.from_json(
'd3rlpy_logs/' + file_name + '/params.json',
use_gpu=args.gpu
)
e2o.load_model(
'd3rlpy_logs/' + file_name + '/model_1000000.pt'
)

buffer = d3rlpy_new.d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env)

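    # online fine-tuning: 250k environment steps, one gradient update per step
    # after 1,000 warm-up steps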
e2o.fit_online(env,
buffer,
eval_env=eval_env,
n_steps=250000,
n_steps_per_epoch=1000,
update_interval=1,
update_start_step=1000,
save_interval=1000000,
experiment_name=f"E2O-CQL-10_online_{args.env}-medium-expert_seed{args.seed}")


if __name__ == '__main__':
main()
78 changes: 78 additions & 0 deletions offline-rl-algorithms/E2O/PEX-main/README.md
@@ -0,0 +1,78 @@
# Policy Expansion (PEX) [ICLR 2023]

## Installation
The training environment (PyTorch and dependencies) can be installed as follows:

```
git clone https://github.com/Haichao-Zhang/PEX.git
cd PEX
python3 -m venv .venv_pex
source .venv_pex/bin/activate
pip3 install -e .
```

## Train

### Offline Training

Set `root_dir` to the path where the experimental results will be saved. In the command below, `--tau` is the IQL expectile and `--beta` the inverse temperature used for advantage-weighted policy extraction.

Then run:

```
CUDA_VISIBLE_DEVICES=0 python main_offline.py --log_dir=$root_dir/antmaze-large-play-v0_offline_run1 --env_name antmaze-large-play-v0 --tau 0.9 --beta 10.0
```

### Online Training
First set the path to the offline checkpoint:
```
path_to_offline_ckpt=$root_dir/antmaze-large-play-v0_offline_run1/offline_ckpt
```

and select an algorithm (one of `scratch`, `direct`, `buffer`, or `pex`):
```
algorithm=pex
```

and then run:
```
CUDA_VISIBLE_DEVICES=0 python ./main_online.py --log_dir=$root_dir/antmaze-large-play-v0_run1_$algorithm --env_name=antmaze-large-play-v0 --tau 0.9 --beta 10.0 --ckpt_path=$path_to_offline_ckpt --eval_episode_num=100 --algorithm=$algorithm
```
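
For intuition, `pex` keeps the offline policy frozen, adds a second policy that is trained online, and lets the critic arbitrate at each step between the actions the two policies propose. The snippet below is only a minimal sketch of that selection rule, assuming a single unbatched observation and a critic that returns one value per state-action pair; `pex_select_action` and `inv_temp` are illustrative names, not part of this repository's API.

```
# Illustrative sketch of PEX-style action selection (not the repository's API).
import torch

def pex_select_action(obs, offline_policy, online_policy, critic, inv_temp=1.0):
    # candidate actions: one from the frozen offline policy, one from the online policy
    a_offline = offline_policy(obs)
    a_online = online_policy(obs)
    # the critic scores both candidates for the current observation
    q_values = torch.stack([critic(obs, a_offline), critic(obs, a_online)]).reshape(-1)
    # Boltzmann weights over the two proposals: higher-valued actions are chosen more often
    probs = torch.softmax(inv_temp * q_values, dim=0)
    choice = torch.multinomial(probs, num_samples=1).item()
    return a_offline if choice == 0 else a_online
```

Because the offline proposal always remains available, the pretrained behaviour is retained while the new policy explores.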


### Example on Locomotion Task

```
CUDA_VISIBLE_DEVICES=0 python main_offline.py --log_dir=$root_dir/halfcheetah-random-v2_offline_run1 --env_name halfcheetah-random-v2 --tau 0.9 --beta 10.0
path_to_offline_ckpt=$root_dir/halfcheetah-random-v2_offline_run1/offline_ckpt
CUDA_VISIBLE_DEVICES=0 python ./main_online.py --log_dir=$root_dir/halfcheetah-random-v2_run1_$algorithm --env_name=halfcheetah-random-v2 --tau 0.9 --beta 10.0 --ckpt_path=$path_to_offline_ckpt --eval_episode_num=10 --algorithm=$algorithm
```


## Paper

<b>[Policy Expansion for Bridging Offline-to-Online Reinforcement Learning](https://arxiv.org/pdf/2302.00935.pdf)</b> <br>

[Haichao Zhang](https://sites.google.com/site/hczhang1/),
Wei Xu,
Haonan Yu

*International Conference on Learning Representations* (ICLR), 2023



## Cite

Please cite our work if you find it useful:

```
@inproceedings{PEX,
author = {Haichao Zhang and Wei Xu and Haonan Yu},
title = {Policy Expansion for Bridging Offline-to-Online Reinforcement Learning},
booktitle = {International Conference on Learning Representations ({ICLR})},
year = {2023},
}
```
77 changes: 77 additions & 0 deletions offline-rl-algorithms/E2O/PEX-main/main_offline.py
@@ -0,0 +1,77 @@
import os
from pathlib import Path
import torch
from tqdm import trange

from pex.algorithms.iql import IQL
from pex.networks.policy import GaussianPolicy
from pex.networks.value_functions import DoubleCriticNetwork, ValueNetwork
from pex.utils.util import (
set_seed, DEFAULT_DEVICE, sample_batch,
eval_policy, set_default_device, get_env_and_dataset)


def main(args):
torch.set_num_threads(1)
if os.path.exists(args.log_dir):
print(f"The directory {args.log_dir} exists. Please specify a different one.")
return
else:
print(f"Creating directory {args.log_dir}")
os.mkdir(args.log_dir)

env, dataset, _ = get_env_and_dataset(args.env_name, args.max_episode_steps)
obs_dim = dataset['observations'].shape[1]
act_dim = dataset['actions'].shape[1]

if args.seed is not None:
set_seed(args.seed, env=env)

if torch.cuda.is_available():
set_default_device()

action_space = env.action_space
policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num, action_space=action_space, scale_distribution=False, state_dependent_std=False)

iql = IQL(
critic=DoubleCriticNetwork(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num),
vf=ValueNetwork(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num),
policy=policy,
optimizer_ctor=lambda params: torch.optim.Adam(params, lr=args.learning_rate),
max_steps=args.num_steps,
tau=args.tau,
beta=args.beta,
target_update_rate=args.target_update_rate,
discount=args.discount
)

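    # offline IQL training: expectile regression for the value function, TD
    # learning for the critics, and advantage-weighted regression for the policy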
for step in trange(args.num_steps):
iql.update(**sample_batch(dataset, args.batch_size))
if (step + 1) % args.eval_period == 0:
eval_policy(env, args.env_name, iql, args.max_episode_steps, args.eval_episode_num, args.seed)

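    # final offline checkpoint; the online script consumes it via --ckpt_path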
torch.save(iql.state_dict(), args.log_dir + '/offline_ckpt')


if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--env_name', required=True)
parser.add_argument('--log_dir', required=True)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--discount', type=float, default=0.99)
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--hidden_num', type=int, default=2)
    parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                        help='maximum number of training steps (default: 1000001)')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--learning_rate', type=float, default=3e-4)
parser.add_argument('--target_update_rate', type=float, default=0.005)
parser.add_argument('--tau', type=float, default=0.7)
parser.add_argument('--beta', type=float, default=10.0,
help='IQL inverse temperature')
parser.add_argument('--eval_period', type=int, default=1000)
    parser.add_argument('--eval_episode_num', type=int, default=100,
                        help='number of evaluation episodes (default: 100)')
parser.add_argument('--max_episode_steps', type=int, default=1000)
main(parser.parse_args())