
Commit

add torch ppo (PaddlePaddle#213)
* add ppo

* fix bugs

* yapf
banma12956 authored Jun 11, 2020
1 parent 2c7340f commit 2deefa8
Showing 10 changed files with 854 additions and 0 deletions.
103 changes: 103 additions & 0 deletions benchmark/torch/ppo/arguments.py
@@ -0,0 +1,103 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import torch


def get_args():
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument(
        '--lr', type=float, default=3e-4, help='learning rate (default: 3e-4)')
    parser.add_argument(
        '--eps',
        type=float,
        default=1e-5,
        help='RMSprop optimizer epsilon (default: 1e-5)')
    parser.add_argument(
        '--gamma',
        type=float,
        default=0.99,
        help='discount factor for rewards (default: 0.99)')
    parser.add_argument(
        '--gae-lambda',
        type=float,
        default=0.95,
        help='gae lambda parameter (default: 0.95)')
    parser.add_argument(
        '--entropy-coef',
        type=float,
        default=0.,
        help='entropy term coefficient (default: 0.)')
    parser.add_argument(
        '--value-loss-coef',
        type=float,
        default=0.5,
        help='value loss coefficient (default: 0.5)')
    parser.add_argument(
        '--max-grad-norm',
        type=float,
        default=0.5,
        help='max norm of gradients (default: 0.5)')
    parser.add_argument(
        '--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument(
        '--num-steps',
        type=int,
        default=2048,
        help='number of maximum forward steps in ppo (default: 2048)')
    parser.add_argument(
        '--ppo-epoch',
        type=int,
        default=10,
        help='number of ppo epochs (default: 10)')
    parser.add_argument(
        '--num-mini-batch',
        type=int,
        default=32,
        help='number of batches for ppo (default: 32)')
    parser.add_argument(
        '--clip-param',
        type=float,
        default=0.2,
        help='ppo clip parameter (default: 0.2)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1,
        help='log interval, one log per n updates (default: 1)')
    parser.add_argument(
        '--eval-interval',
        type=int,
        default=10,
        help='eval interval, one eval per n updates (default: 10)')
    parser.add_argument(
        '--num-env-steps',
        type=int,
        default=10e5,
        help='number of environment steps to train (default: 10e5)')
    parser.add_argument(
        '--env-name',
        default='Hopper-v2',
        help='environment to train on (default: Hopper-v2)')
    parser.add_argument(
        '--use-linear-lr-decay',
        action='store_true',
        default=False,
        help='use a linear schedule on the learning rate')
    args = parser.parse_args()

    args.cuda = torch.cuda.is_available()

    return args
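
Not part of the commit: a minimal sketch, under assumptions, of how a training script could consume these arguments. Deriving the update count from the step budget and annealing the learning rate linearly is the usual pattern for this argument set; the loop body and variable names here are illustrative, not code from this commit.

# Illustrative sketch (not from this commit): deriving the number of PPO
# updates and an optional linear learning-rate decay from the parsed args.
from arguments import get_args

args = get_args()

# One update consumes args.num_steps environment steps, so the step budget
# determines how many PPO updates will be performed in total.
num_updates = int(args.num_env_steps) // args.num_steps

for update in range(num_updates):
    if args.use_linear_lr_decay:
        # Anneal the learning rate linearly from args.lr towards zero.
        current_lr = args.lr * (1 - update / num_updates)
    # ... collect args.num_steps transitions, then run args.ppo_epoch epochs
    # over args.num_mini_batch mini-batches of the rollout ...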
56 changes: 56 additions & 0 deletions benchmark/torch/ppo/evaluation.py
@@ -0,0 +1,56 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

import utils
from wrapper import make_env


def evaluate(agent, ob_rms, env_name, seed, device):
    if seed is not None:
        seed += 1
    eval_envs = make_env(env_name, seed, None)
    vec_norm = utils.get_vec_normalize(eval_envs)
    if vec_norm is not None:
        vec_norm.eval()
        vec_norm.ob_rms = ob_rms

    eval_episode_rewards = []

    obs = eval_envs.reset()
    eval_masks = torch.zeros(1, 1, device=device)

    while len(eval_episode_rewards) < 10:
        with torch.no_grad():
            action = agent.predict(obs)

        # Observe reward and next obs
        obs, _, done, infos = eval_envs.step(action)

        eval_masks = torch.tensor(
            [[0.0] if done_ else [1.0] for done_ in done],
            dtype=torch.float32,
            device=device)

        for info in infos:
            if 'episode' in info.keys():
                eval_episode_rewards.append(info['episode']['r'])

    eval_envs.close()

    print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
        len(eval_episode_rewards), np.mean(eval_episode_rewards)))
    return np.mean(eval_episode_rewards)
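
Not part of the commit: a sketch of how the training loop might invoke this helper every args.eval_interval updates. Reusing the training environments' running observation statistics (ob_rms) is implied by the vec_norm handling above; the wrapper name maybe_evaluate and its arguments are assumptions.

# Illustrative call-site sketch (not from this commit).
import utils
from evaluation import evaluate


def maybe_evaluate(agent, train_envs, args, update, device):
    """Run a 10-episode evaluation every args.eval_interval updates (sketch)."""
    if update % args.eval_interval != 0:
        return None
    # Copy the running observation-normalization statistics from the training
    # envs so the evaluation envs normalize observations the same way.
    vec_norm = utils.get_vec_normalize(train_envs)
    ob_rms = vec_norm.ob_rms if vec_norm is not None else None
    return evaluate(agent, ob_rms, args.env_name, args.seed, device)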
78 changes: 78 additions & 0 deletions benchmark/torch/ppo/mujoco_agent.py
@@ -0,0 +1,78 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch


class MujocoAgent(parl.Agent):
    def __init__(self, algorithm, device):
        self.alg = algorithm
        self.device = device

    def predict(self, obs):
        obs = torch.from_numpy(obs).float().to(self.device)
        action = self.alg.predict(obs)
        return action.cpu().numpy()

    def sample(self, obs):
        obs = torch.from_numpy(obs).to(self.device)
        value, action, action_log_probs = self.alg.sample(obs)
        return value.cpu().numpy(), action.cpu().numpy(), \
            action_log_probs.cpu().numpy()

    def learn(self, next_value, gamma, gae_lambda, ppo_epoch, num_mini_batch,
              rollouts):
        value_loss_epoch = 0
        action_loss_epoch = 0
        dist_entropy_epoch = 0

        for e in range(ppo_epoch):
            data_generator = rollouts.sample_batch(next_value, gamma,
                                                   gae_lambda, num_mini_batch)

            for sample in data_generator:
                obs_batch, actions_batch, \
                    value_preds_batch, return_batch, old_action_log_probs_batch, \
                    adv_targ = sample

                # Move the sampled mini-batch to the agent's device.
                obs_batch = torch.from_numpy(obs_batch).to(self.device)
                actions_batch = torch.from_numpy(actions_batch).to(self.device)
                value_preds_batch = torch.from_numpy(value_preds_batch).to(
                    self.device)
                return_batch = torch.from_numpy(return_batch).to(self.device)
                old_action_log_probs_batch = torch.from_numpy(
                    old_action_log_probs_batch).to(self.device)
                adv_targ = torch.from_numpy(adv_targ).to(self.device)

                value_loss, action_loss, dist_entropy = self.alg.learn(
                    obs_batch, actions_batch, value_preds_batch, return_batch,
                    old_action_log_probs_batch, adv_targ)

                value_loss_epoch += value_loss
                action_loss_epoch += action_loss
                dist_entropy_epoch += dist_entropy

        num_updates = ppo_epoch * num_mini_batch

        value_loss_epoch /= num_updates
        action_loss_epoch /= num_updates
        dist_entropy_epoch /= num_updates

        return value_loss_epoch, action_loss_epoch, dist_entropy_epoch

    def value(self, obs):
        obs = torch.from_numpy(obs).to(self.device)
        return self.alg.value(obs).cpu().numpy()
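
Not part of the commit: a wiring sketch for the pieces above. The PPO algorithm class lives in one of this commit's files that is not shown on this page, so the import path and constructor signature below are assumptions; MujocoModel, MujocoAgent, and get_args are the ones defined in this commit.

# Illustrative wiring sketch (not from this commit); the PPO constructor
# signature is an assumption, since the algorithm file is not shown above.
import torch

from arguments import get_args
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from ppo import PPO  # assumed module/class name for the algorithm file

args = get_args()
device = torch.device("cuda" if args.cuda else "cpu")

obs_dim, act_dim = 11, 3  # Hopper-v2 observation and action sizes
model = MujocoModel(obs_dim, act_dim).to(device)

# Assumed hyperparameter names; they mirror the argparse flags above.
algorithm = PPO(
    model,
    clip_param=args.clip_param,
    value_loss_coef=args.value_loss_coef,
    entropy_coef=args.entropy_coef,
    initial_lr=args.lr,
    eps=args.eps,
    max_grad_norm=args.max_grad_norm)
agent = MujocoAgent(algorithm, device)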
64 changes: 64 additions & 0 deletions benchmark/torch/ppo/mujoco_model.py
@@ -0,0 +1,64 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class MujocoModel(parl.Model):
    def __init__(self, obs_dim, act_dim):
        super(MujocoModel, self).__init__()
        self.actor = Actor(obs_dim, act_dim)
        self.critic = Critic(obs_dim)

    def policy(self, obs):
        return self.actor(obs)

    def value(self, obs):
        return self.critic(obs)


class Actor(parl.Model):
    def __init__(self, obs_dim, act_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, 64)

        self.fc_mean = nn.Linear(64, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs):
        x = torch.tanh(self.fc1(obs))
        x = torch.tanh(self.fc2(x))

        mean = self.fc_mean(x)
        return mean, self.log_std


class Critic(parl.Model):
    def __init__(self, obs_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, obs):
        x = torch.tanh(self.fc1(obs))
        x = torch.tanh(self.fc2(x))
        value = self.fc3(x)

        return value
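
Not part of the commit: the Actor returns a mean and a state-independent log standard deviation rather than a distribution object, so the caller is expected to build the Gaussian itself. Below is a small sketch of that usual pattern (sampling, log-probabilities, entropy); the algorithm file that actually does this in the commit is not shown on this page.

# Illustrative sketch (not from this commit) of turning the actor's
# (mean, log_std) output into a diagonal Gaussian policy.
import torch
from torch.distributions import Normal

from mujoco_model import MujocoModel

model = MujocoModel(obs_dim=11, act_dim=3)  # Hopper-v2 sizes
obs = torch.randn(4, 11)  # batch of 4 observations

mean, log_std = model.policy(obs)
dist = Normal(mean, log_std.exp())

action = dist.sample()
# Sum over the action dimension to get one log-probability per sample.
action_log_prob = dist.log_prob(action).sum(dim=-1)
entropy = dist.entropy().sum(dim=-1).mean()
value = model.value(obs)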
(The remaining 6 changed files in this commit are not shown here.)
