Forked from TJU-DRL-LAB/AI-Optimizer

Commit 4cdac9d (parent 56b9680), committed by zhaokai on Jun 16, 2023.
Showing 124 changed files with 23,779 additions and 0 deletions.
New file (+50 lines): E2O offline pretraining script.
import argparse
import d3rlpy_new.d3rlpy
from sklearn.model_selection import train_test_split


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='halfcheetah-medium-expert-v2')
    parser.add_argument('--n_critic', type=int, default=10)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    # Load the D4RL dataset together with its evaluation environment.
    dataset, env = d3rlpy_new.d3rlpy.datasets.get_dataset(args.dataset)

    # Fix seeds for reproducibility.
    d3rlpy_new.d3rlpy.seed(args.seed)
    env.seed(args.seed)

    # Hold out 20% of the episodes for the evaluation scorers.
    _, test_episodes = train_test_split(dataset, test_size=0.2)

    encoder = d3rlpy_new.d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256])

    # Offline stage of E2O: CQL with an ensemble of n_critic Q-functions.
    e2o = d3rlpy_new.d3rlpy.algos.CQL(actor_learning_rate=3e-5,
                                      critic_learning_rate=3e-4,
                                      temp_learning_rate=1e-4,
                                      actor_encoder_factory=encoder,
                                      critic_encoder_factory=encoder,
                                      batch_size=256,
                                      n_action_samples=10,
                                      alpha_learning_rate=0.0,
                                      alpha_threshold=5.0,
                                      conservative_weight=5.0,
                                      n_critics=args.n_critic,
                                      use_gpu=args.gpu)

    e2o.fit(dataset.episodes,
            eval_episodes=test_episodes,
            n_steps=1000000,
            n_steps_per_epoch=1000,
            save_interval=1000,
            scorers={
                'environment': d3rlpy_new.d3rlpy.metrics.evaluate_on_environment(env),
                'value_scale': d3rlpy_new.d3rlpy.metrics.average_value_estimation_scorer,
            },
            experiment_name=f"E2O-CQL-{args.n_critic}_{args.dataset}_{args.seed}")


if __name__ == '__main__':
    main()
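
The `n_critic` flag is the distinctive ingredient here: E2O trains CQL with an ensemble of Q-functions rather than the usual pair, so the online stage can move from pessimistic offline estimates toward more optimistic ones. As a rough illustration of that idea (a hypothetical sketch for exposition only — `ensemble_q` and its signature are invented, not this repo's API):

```
import torch

def ensemble_q(qs: torch.Tensor, mode: str = "pessimistic") -> torch.Tensor:
    """Aggregate ensemble Q-estimates; qs has shape [n_critics, batch]."""
    if mode == "pessimistic":
        # Offline: take the minimum to guard against overestimating
        # out-of-distribution actions.
        return qs.min(dim=0).values
    if mode == "optimistic":
        # Online: add an ensemble-spread bonus to encourage exploration
        # once real environment feedback is available.
        return qs.mean(dim=0) + qs.std(dim=0)
    return qs.mean(dim=0)
```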
New file (+51 lines): E2O online fine-tuning script.
import argparse
import gym
import d3rlpy_new.d3rlpy
import os


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    env = gym.make(args.env)
    eval_env = gym.make(args.env)

    # fix seed
    d3rlpy_new.d3rlpy.seed(args.seed)
    env.seed(args.seed)
    eval_env.seed(args.seed)

    # Locate the offline run directory (its name starts with the experiment
    # name used during pretraining) and restore the pretrained E2O agent.
    file_name = 'E2O-CQL-10_halfcheetah-medium-expert-v2_' + str(args.seed)
    file_path = "./d3rlpy_logs"
    filelist = os.listdir(file_path)
    for file in filelist:
        if file_name in file:
            file_name = file
    e2o = d3rlpy_new.d3rlpy.algos.E2O.from_json(
        'd3rlpy_logs/' + file_name + '/params.json',
        use_gpu=args.gpu
    )
    e2o.load_model(
        'd3rlpy_logs/' + file_name + '/model_1000000.pt'
    )

    # Online fine-tuning starts from a fresh replay buffer.
    buffer = d3rlpy_new.d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env)

    e2o.fit_online(env,
                   buffer,
                   eval_env=eval_env,
                   n_steps=250000,
                   n_steps_per_epoch=1000,
                   update_interval=1,
                   update_start_step=1000,
                   save_interval=1000000,
                   experiment_name=f"E2O-CQL-10_online_{args.env}-medium-expert_seed{args.seed}")


if __name__ == '__main__':
    main()
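
One fragile spot in the script above: the directory search keeps whichever matching run `os.listdir` happens to return last, which is not necessarily the newest run. A small hardening sketch (assuming d3rlpy's convention of suffixing run directories with a timestamp; `latest_matching_run` is an illustrative helper, not part of the repo):

```
import os

def latest_matching_run(log_root: str, prefix: str) -> str:
    """Return the most recently modified run directory whose name contains prefix."""
    matches = [d for d in os.listdir(log_root) if prefix in d]
    if not matches:
        raise FileNotFoundError(f"no run matching {prefix!r} under {log_root}")
    return max(matches, key=lambda d: os.path.getmtime(os.path.join(log_root, d)))
```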
New file (+78 lines): PEX README.
# Policy Expansion (PEX) [ICLR 2023]

## Installation
The training environment (PyTorch and dependencies) can be installed as follows:

```
git clone https://github.com/Haichao-Zhang/PEX.git
cd PEX
python3 -m venv .venv_pex
source .venv_pex/bin/activate
pip3 install -e .
```

## Train

### Offline Training

Set `root_dir` to the path where the experimental results will be saved.

Then run:

```
CUDA_VISIBLE_DEVICES=0 python main_offline.py --log_dir=$root_dir/antmaze-large-play-v0_offline_run1 --env_name antmaze-large-play-v0 --tau 0.9 --beta 10.0
```

### Online Training
First set the path to the offline checkpoint:
```
path_to_offline_ckpt=$root_dir/antmaze-large-play-v0_offline_run1/offline_ckpt
```

and select an algorithm (`pex`, or any of the other variants: `scratch`, `direct`, `buffer`):
```
algorithm=pex
```

and then run:
```
CUDA_VISIBLE_DEVICES=0 python ./main_online.py --log_dir=$root_dir/antmaze-large-play-v0_run1_$algorithm --env_name=antmaze-large-play-v0 --tau 0.9 --beta 10.0 --ckpt_path=$path_to_offline_ckpt --eval_episode_num=100 --algorithm=$algorithm
```
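
For context on what the `pex` option does (as described in the paper; the interfaces below are illustrative, not this repo's actual API): the pretrained offline policy is frozen and kept in a policy set alongside a freshly initialized online policy. At each state both policies propose an action, and a Boltzmann distribution over the critic's scores of the two proposals decides which one is executed. A minimal hypothetical sketch, assuming `offline_policy`, `online_policy`, and `critic` are plain callables returning tensors:

```
import torch
import torch.nn.functional as F

def pex_select_action(state, offline_policy, online_policy, critic, temperature=1.0):
    # Each member of the expanded policy set proposes an action.
    with torch.no_grad():
        a_offline = offline_policy(state)  # frozen policy from offline pretraining
        a_online = online_policy(state)    # new policy being trained online
        # Score both proposals with the critic (assumed to return a scalar
        # tensor per state-action pair) and form a Boltzmann distribution.
        q = torch.stack([critic(state, a_offline), critic(state, a_online)])
        probs = F.softmax(q / temperature, dim=0)
        # Sample which proposal to execute.
        idx = torch.multinomial(probs, num_samples=1).item()
    return a_offline if idx == 0 else a_online
```

Only the online policy (and the critic) continue to train, so the agent can fall back on offline competence while exploration improves the new policy.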

### Example on Locomotion Task

```
CUDA_VISIBLE_DEVICES=0 python main_offline.py --log_dir=$root_dir/halfcheetah-random-v2_offline_run1 --env_name halfcheetah-random-v2 --tau 0.9 --beta 10.0
path_to_offline_ckpt=$root_dir/halfcheetah-random-v2_offline_run1/offline_ckpt
CUDA_VISIBLE_DEVICES=0 python ./main_online.py --log_dir=$root_dir/halfcheetah-random-v2_run1_$algorithm --env_name=halfcheetah-random-v2 --tau 0.9 --beta 10.0 --ckpt_path=$path_to_offline_ckpt --eval_episode_num=10 --algorithm=$algorithm
```

## Paper

<b>[Policy Expansion for Bridging Offline-to-Online Reinforcement Learning](https://arxiv.org/pdf/2302.00935.pdf)</b> <br>

[Haichao Zhang](https://sites.google.com/site/hczhang1/),
Wei Xu,
Haonan Yu

*International Conference on Learning Representations* (ICLR), 2023
## Cite

Please cite our work if you find it useful:

```
@inproceedings{PEX,
    author    = {Haichao Zhang and Wei Xu and Haonan Yu},
    title     = {Policy Expansion for Bridging Offline-to-Online Reinforcement Learning},
    booktitle = {International Conference on Learning Representations ({ICLR})},
    year      = {2023},
}
```
New file (+77 lines): PEX offline training script (main_offline.py, the entry point invoked in the README above).
import os
import torch
from tqdm import trange

from pex.algorithms.iql import IQL
from pex.networks.policy import GaussianPolicy
from pex.networks.value_functions import DoubleCriticNetwork, ValueNetwork
from pex.utils.util import (
    set_seed, DEFAULT_DEVICE, sample_batch,
    eval_policy, set_default_device, get_env_and_dataset)


def main(args):
    torch.set_num_threads(1)

    # Refuse to overwrite an existing experiment directory.
    if os.path.exists(args.log_dir):
        print(f"The directory {args.log_dir} exists. Please specify a different one.")
        return
    print(f"Creating directory {args.log_dir}")
    os.mkdir(args.log_dir)

    env, dataset, _ = get_env_and_dataset(args.env_name, args.max_episode_steps)
    obs_dim = dataset['observations'].shape[1]
    act_dim = dataset['actions'].shape[1]

    if args.seed is not None:
        set_seed(args.seed, env=env)

    if torch.cuda.is_available():
        set_default_device()

    action_space = env.action_space
    policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim,
                            n_hidden=args.hidden_num, action_space=action_space,
                            scale_distribution=False, state_dependent_std=False)

    # IQL with a double critic and a separate state-value network; tau is the
    # expectile for value regression, beta the inverse temperature for
    # advantage-weighted policy extraction.
    iql = IQL(
        critic=DoubleCriticNetwork(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num),
        vf=ValueNetwork(obs_dim, hidden_dim=args.hidden_dim, n_hidden=args.hidden_num),
        policy=policy,
        optimizer_ctor=lambda params: torch.optim.Adam(params, lr=args.learning_rate),
        max_steps=args.num_steps,
        tau=args.tau,
        beta=args.beta,
        target_update_rate=args.target_update_rate,
        discount=args.discount
    )

    # Offline training loop with periodic evaluation.
    for step in trange(args.num_steps):
        iql.update(**sample_batch(dataset, args.batch_size))
        if (step + 1) % args.eval_period == 0:
            eval_policy(env, args.env_name, iql, args.max_episode_steps, args.eval_episode_num, args.seed)

    # Save the checkpoint that the online stage loads via --ckpt_path.
    torch.save(iql.state_dict(), args.log_dir + '/offline_ckpt')


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--env_name', required=True)
    parser.add_argument('--log_dir', required=True)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--hidden_num', type=int, default=2)
    parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                        help='maximum number of training steps (default: 1000001)')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=3e-4)
    parser.add_argument('--target_update_rate', type=float, default=0.005)
    parser.add_argument('--tau', type=float, default=0.7)
    parser.add_argument('--beta', type=float, default=10.0,
                        help='IQL inverse temperature')
    parser.add_argument('--eval_period', type=int, default=1000)
    parser.add_argument('--eval_episode_num', type=int, default=100,
                        help='number of evaluation episodes (default: 100)')
    parser.add_argument('--max_episode_steps', type=int, default=1000)
    main(parser.parse_args())
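
The two flags that matter most above are `tau` and `beta`. In IQL, `tau` is the expectile used to regress the value function toward the Q-function (values approach a maximum over actions as `tau` goes to 1), and `beta` is the inverse temperature of advantage-weighted policy extraction. A schematic sketch of both pieces (illustrative function names operating on 1-D tensors of Q and V estimates; not the `pex` package's actual code):

```
import torch

def expectile_loss(q: torch.Tensor, v: torch.Tensor, tau: float) -> torch.Tensor:
    # Asymmetric squared error: with tau > 0.5, underestimating Q is
    # penalized more, pushing V toward an upper expectile of Q.
    u = q - v
    weight = torch.abs(tau - (u < 0).float())
    return (weight * u.pow(2)).mean()

def awr_weights(q: torch.Tensor, v: torch.Tensor, beta: float, clip: float = 100.0) -> torch.Tensor:
    # exp(beta * advantage), clipped for stability; used to weight the
    # policy's log-likelihood on dataset actions.
    return torch.exp(beta * (q - v)).clamp(max=clip)
```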