[RLlib] Issue 8319 DDPG (MA or num_envs_per_worker > 1) broken. (ray-…
sven1977 authored May 8, 2020
1 parent 5f278c6 commit d7eaacb
Showing 5 changed files with 100 additions and 55 deletions.
5 changes: 4 additions & 1 deletion rllib/agents/ddpg/tests/test_ddpg.py
@@ -22,11 +22,12 @@ def test_ddpg_compilation(self):
"""Test whether a DDPGTrainer can be built with both frameworks."""
config = ddpg.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["num_envs_per_worker"] = 2 # Run locally.

num_iterations = 2

# Test against all frameworks.
for _ in framework_iterator(config, ("torch", "tf")):
for _ in framework_iterator(config, ("tf", "torch")):
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
for i in range(num_iterations):
results = trainer.train()
@@ -366,6 +367,8 @@ def test_ddpg_loss_function(self):
else:
check(tf_var, torch_var, rtol=0.07)

trainer.stop()

def _get_batch_helper(self, obs_size, actions, batch_size):
return {
SampleBatch.CUR_OBS: np.random.random(size=obs_size),
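
The num_envs_per_worker = 2 line added above exercises the multi-env setting reported in issue 8319. A minimal stand-alone reproduction sketch, under the same assumptions as the test (Pendulum-v0, sampling in the local worker only), could look roughly like this:

import ray
from ray.rllib.agents import ddpg

ray.init()
config = ddpg.DEFAULT_CONFIG.copy()
config["num_workers"] = 0          # sample in the local worker only
config["num_envs_per_worker"] = 2  # >1 envs per worker used to break DDPG
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
print(trainer.train()["episode_reward_mean"])
trainer.stop()
ray.shutdown()
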
11 changes: 4 additions & 7 deletions rllib/examples/env/simple_corridor.py
@@ -12,24 +12,21 @@ def __init__(self, config):
self.end_pos = config["corridor_length"]
self.cur_pos = 0
self.action_space = Discrete(2)
self.observation_space = Box(
0.0, self.end_pos, shape=(1, ), dtype=np.float32)
self.observation_space = Box(0.0, 999.0, shape=(1, ), dtype=np.float32)

def set_corridor_length(self, length):
self.end_pos = length
self.observation_space = Box(
0.0, self.end_pos, shape=(1, ), dtype=np.float32)
print("Updated corridor length to {}".format(length))

def reset(self):
self.cur_pos = 0
self.cur_pos = 0.0
return [self.cur_pos]

def step(self, action):
assert action in [0, 1], action
if action == 0 and self.cur_pos > 0:
self.cur_pos -= 1
self.cur_pos -= 1.0
elif action == 1:
self.cur_pos += 1
self.cur_pos += 1.0
done = self.cur_pos >= self.end_pos
return [self.cur_pos], 1 if done else 0, done, {}
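
With this change the observation space keeps a fixed upper bound of 999.0, so set_corridor_length() only moves end_pos and no longer rebuilds a Box that disagrees with the space the policy was constructed for; positions are now tracked as floats to match the float32 Box. A short usage sketch (import path as laid out in this repo):

from ray.rllib.examples.env.simple_corridor import SimpleCorridor

env = SimpleCorridor({"corridor_length": 5})
env.set_corridor_length(10)             # only end_pos changes; the Box stays 0.0..999.0
obs = env.reset()                       # -> [0.0]
obs, reward, done, info = env.step(1)   # move right -> [1.0], reward 0, done False
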
10 changes: 10 additions & 0 deletions rllib/models/tf/tf_action_dist.py
@@ -326,6 +326,11 @@ def _unsquash(self, values):
unsquashed = tf.math.atanh(save_normed_values)
return unsquashed

@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2


class Beta(TFActionDistribution):
"""
@@ -371,6 +376,11 @@ def _squash(self, raw_values):
def _unsquash(self, values):
return (values - self.low) / (self.high - self.low)

@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2


class Deterministic(TFActionDistribution):
"""Action distribution that returns the input values directly.
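
Both new required_model_output_shape() implementations report a flat size of 2 * prod(action_space.shape): SquashedGaussian needs a mean and a log-std per action dimension, Beta an alpha and a beta. A small illustration with a hypothetical 3-dimensional Box action space:

import numpy as np
from gym.spaces import Box

action_space = Box(-1.0, 1.0, shape=(3, ), dtype=np.float32)
# Mean + log-std (or alpha + beta) per action dimension.
flat_size = int(np.prod(action_space.shape) * 2)
assert flat_size == 6
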
61 changes: 33 additions & 28 deletions rllib/tests/test_multi_agent_pendulum.py
@@ -2,9 +2,10 @@
import unittest

import ray
from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
from ray.tune import run_experiments
from ray.tune.registry import register_env
from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
from ray.rllib.utils.test_utils import framework_iterator


class TestMultiAgentPendulum(unittest.TestCase):
@@ -17,34 +18,38 @@ def tearDown(self) -> None:
def test_multi_agent_pendulum(self):
register_env("multi_agent_pendulum",
lambda _: MultiAgentPendulum({"num_agents": 1}))
trials = run_experiments({
"test": {
"run": "PPO",
"env": "multi_agent_pendulum",
"stop": {
"timesteps_total": 500000,
"episode_reward_mean": -200,
},
"config": {
"train_batch_size": 2048,
"vf_clip_param": 10.0,
"num_workers": 0,
"num_envs_per_worker": 10,
"lambda": 0.1,
"gamma": 0.95,
"lr": 0.0003,
"sgd_minibatch_size": 64,
"num_sgd_iter": 10,
"model": {
"fcnet_hiddens": [128, 128],

# Test for both torch and tf.
for fw in framework_iterator(frameworks=["torch", "tf"]):
trials = run_experiments({
"test": {
"run": "PPO",
"env": "multi_agent_pendulum",
"stop": {
"timesteps_total": 500000,
"episode_reward_mean": -300.0,
},
"config": {
"train_batch_size": 2048,
"vf_clip_param": 10.0,
"num_workers": 0,
"num_envs_per_worker": 10,
"lambda": 0.1,
"gamma": 0.95,
"lr": 0.0003,
"sgd_minibatch_size": 64,
"num_sgd_iter": 10,
"model": {
"fcnet_hiddens": [128, 128],
},
"batch_mode": "complete_episodes",
"use_pytorch": fw == "torch",
},
"batch_mode": "complete_episodes",
},
}
})
if trials[0].last_result["episode_reward_mean"] < -200:
raise ValueError("Did not get to -200 reward",
trials[0].last_result)
}
})
if trials[0].last_result["episode_reward_mean"] < -300.0:
raise ValueError("Did not get to -200 reward",
trials[0].last_result)


if __name__ == "__main__":
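
The test now runs once per framework: framework_iterator yields the framework strings, and the tune config picks the backend via the use_pytorch flag (the option name used in this Ray version, as in the diff above). In outline:

from ray.rllib.utils.test_utils import framework_iterator

for fw in framework_iterator(frameworks=["torch", "tf"]):
    config = {
        "num_workers": 0,
        "use_pytorch": fw == "torch",  # torch backend on the "torch" pass
    }
    # ... pass `config` into run_experiments as in the test above ...
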
68 changes: 49 additions & 19 deletions rllib/utils/exploration/random.py
@@ -1,4 +1,4 @@
from gym.spaces import Discrete, MultiDiscrete, Tuple
from gym.spaces import Discrete, Box, MultiDiscrete
import numpy as np
import tree
from typing import Union
@@ -9,6 +9,7 @@
from ray.rllib.utils import force_tuple
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.space_utils import get_base_struct_from_space

tf = try_import_tf()
torch, _ = try_import_torch()
@@ -35,13 +36,8 @@ def __init__(self, action_space, *, model, framework, **kwargs):
framework=framework,
**kwargs)

# Determine py_func types, depending on our action-space.
if isinstance(self.action_space, (Discrete, MultiDiscrete)) or \
(isinstance(self.action_space, Tuple) and
isinstance(self.action_space[0], (Discrete, MultiDiscrete))):
self.dtype_sample, self.dtype = (tf.int64, tf.int32)
else:
self.dtype_sample, self.dtype = (tf.float64, tf.float32)
self.action_space_struct = get_base_struct_from_space(
self.action_space)

@override(Exploration)
def get_exploration_action(self,
@@ -59,14 +55,46 @@ def get_exploration_action(self,

def get_tf_exploration_action_op(self, action_dist, explore):
def true_fn():
action = tf.py_function(self.action_space.sample, [],
self.dtype_sample)
# Will be unnecessary, once we support batch/time-aware Spaces.
return tf.expand_dims(tf.cast(action, dtype=self.dtype), 0)
batch_size = 1
req = force_tuple(
action_dist.required_model_output_shape(
self.action_space, self.model.model_config))
# Add a batch dimension?
if len(action_dist.inputs.shape) == len(req) + 1:
batch_size = tf.shape(action_dist.inputs)[0]

# Function to produce random samples from primitive space
# components: (Multi)Discrete or Box.
def random_component(component):
if isinstance(component, Discrete):
return tf.random.uniform(
shape=(batch_size, ) + component.shape,
maxval=component.n,
dtype=component.dtype)
elif isinstance(component, MultiDiscrete):
return tf.random.uniform(
shape=(batch_size, ) + component.shape,
maxval=component.nvec,
dtype=component.dtype)
elif isinstance(component, Box):
if component.bounded_above.all() and \
component.bounded_below.all():
return tf.random.uniform(
shape=(batch_size, ) + component.shape,
minval=component.low,
maxval=component.high,
dtype=component.dtype)
else:
return tf.random.normal(
shape=(batch_size, ) + component.shape,
dtype=component.dtype)

actions = tree.map_structure(random_component,
self.action_space_struct)
return actions

def false_fn():
return tf.cast(
action_dist.deterministic_sample(), dtype=self.dtype)
return action_dist.deterministic_sample()

action = tf.cond(
pred=tf.constant(explore, dtype=tf.bool)
@@ -81,15 +109,17 @@ def false_fn():

def get_torch_exploration_action(self, action_dist, explore):
if explore:
# Unsqueeze will be unnecessary, once we support batch/time-aware
# Spaces.
a = self.action_space.sample()
req = force_tuple(
action_dist.required_model_output_shape(
self.action_space, self.model.model_config))
# Add a batch dimension.
# Add a batch dimension?
if len(action_dist.inputs.shape) == len(req) + 1:
a = np.expand_dims(a, 0)
batch_size = action_dist.inputs.shape[0]
a = np.stack(
[self.action_space.sample() for _ in range(batch_size)])
else:
a = self.action_space.sample()
# Convert action to torch tensor.
action = torch.from_numpy(a).to(self.device)
else:
action = action_dist.deterministic_sample()
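
The rewritten TF branch above no longer routes action_space.sample() through tf.py_function; instead it flattens the space with get_base_struct_from_space() and draws a correctly batched random sample per primitive component (Discrete, MultiDiscrete, or Box), while the torch branch stacks one space sample per batch row. Below is a framework-free sketch of the same idea using numpy and dm-tree; batched_random_actions is a hypothetical helper written only for illustration (not an RLlib API), and only Discrete and Box are handled for brevity:

import numpy as np
import tree  # dm-tree
from gym.spaces import Box, Discrete, Tuple
from ray.rllib.utils.space_utils import get_base_struct_from_space

def batched_random_actions(action_space, batch_size):
    # Convert nested Tuple/Dict spaces into plain python structures so
    # tree.map_structure can walk them.
    space_struct = get_base_struct_from_space(action_space)

    def random_component(component):
        if isinstance(component, Discrete):
            return np.random.randint(component.n, size=(batch_size, ))
        elif isinstance(component, Box):
            if component.bounded_below.all() and component.bounded_above.all():
                return np.random.uniform(
                    component.low, component.high,
                    size=(batch_size, ) + component.shape)
            # Unbounded Box: fall back to a normal distribution.
            return np.random.normal(size=(batch_size, ) + component.shape)
        raise NotImplementedError(component)

    return tree.map_structure(random_component, space_struct)

# One batch of 5 random actions for a Tuple(Discrete, Box) space.
actions = batched_random_actions(
    Tuple([Discrete(4), Box(-1.0, 1.0, shape=(2, ))]), batch_size=5)
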
