[rllib] Fix rollout.py with tuple action space (ray-project#5201)
* fix it

* update doc too

* fix rollout
ericl authored Jul 16, 2019
1 parent 8065243 commit 047f4cc
Showing 4 changed files with 17 additions and 6 deletions.
ci/travis/format.sh (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@ YAPF_VERSION=$(yapf --version | awk '{print $2}')
 tool_version_check() {
     if [[ $2 != $3 ]]; then
         echo "WARNING: Ray uses $1 $3, You currently are using $2. This might generate different results."
-        read -p "Do you want to continue?[y/n]" answer
+        read -p "Do you want to continue? [y/n] " answer
         if ! [ $answer = 'y' ] && ! [ $answer = 'Y' ]; then
             exit 1
         fi
doc/source/rllib-concepts.rst (14 changes: 13 additions & 1 deletion)
@@ -539,7 +539,19 @@ In summary, the main differences between the PyTorch and TensorFlow policy build
 Extending Existing Policies
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-(todo)
+You can use the ``with_updates`` method on Trainers and Policy objects built with ``make_*`` to create a copy of the object with some changes, for example:
+
+.. code-block:: python
+
+    from ray.rllib.agents.ppo import PPOTrainer
+    from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
+
+    CustomPolicy = PPOTFPolicy.with_updates(
+        name="MyCustomPPOTFPolicy",
+        loss_fn=some_custom_loss_fn)
+
+    CustomTrainer = PPOTrainer.with_updates(
+        default_policy=CustomPolicy)
 
 Policy Evaluation
 -----------------
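For context, the new doc snippet composes into a runnable program roughly as follows. This is a minimal sketch, assuming the (policy, batch_tensors) loss signature used by the policy builders of this Ray version; the delegation to ppo_surrogate_loss is a hypothetical stand-in for a real custom loss:

    import ray
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy, ppo_surrogate_loss

    def some_custom_loss_fn(policy, batch_tensors):
        # Hypothetical stand-in: a real override would build its own
        # scalar loss tensor here instead of delegating to stock PPO.
        return ppo_surrogate_loss(policy, batch_tensors)

    CustomPolicy = PPOTFPolicy.with_updates(
        name="MyCustomPPOTFPolicy",
        loss_fn=some_custom_loss_fn)

    CustomTrainer = PPOTrainer.with_updates(default_policy=CustomPolicy)

    ray.init()
    trainer = CustomTrainer(env="CartPole-v0")
    print(trainer.train())

Because with_updates returns a copy, the stock PPOTrainer and PPOTFPolicy remain unchanged and usable alongside the customized classes.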
python/ray/rllib/evaluation/episode.py (5 changes: 1 addition & 4 deletions)
@@ -196,9 +196,6 @@ def _flatten_action(action):
     if isinstance(action, list) or isinstance(action, tuple):
         expanded = []
         for a in action:
-            if not hasattr(a, "shape") or len(a.shape) == 0:
-                expanded.append(np.expand_dims(a, 1))
-            else:
-                expanded.append(a)
+            expanded.append(np.reshape(a, [-1]))
         action = np.concatenate(expanded, axis=0).flatten()
     return action
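The fix reshapes every component of a tuple action to 1-D before concatenating, so scalar and multi-dimensional components are handled uniformly rather than branching on shape. A small sketch of the new behavior (the sample values are made up for illustration):

    import numpy as np

    # A tuple action, e.g. sampled from a gym.spaces.Tuple of a
    # Discrete and a Box component.
    action = (1, np.array([[0.5, -0.5]]))

    # np.reshape(a, [-1]) flattens scalars and nested arrays alike.
    expanded = [np.reshape(a, [-1]) for a in action]
    print(np.concatenate(expanded, axis=0).flatten())  # [ 1.   0.5 -0.5]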
python/ray/rllib/rollout.py (2 changes: 2 additions & 0 deletions)
@@ -15,6 +15,7 @@
 from ray.rllib.agents.registry import get_agent_class
 from ray.rllib.env import MultiAgentEnv
 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
+from ray.rllib.evaluation.episode import _flatten_action
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.tune.util import merge_dicts

@@ -176,6 +177,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True):
                         prev_action=prev_actions[agent_id],
                         prev_reward=prev_rewards[agent_id],
                         policy_id=policy_id)
+                    a_action = _flatten_action(a_action)  # tuple actions
                     action_dict[agent_id] = a_action
                     prev_actions[agent_id] = a_action
             action = action_dict
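With this change, rollout.py flattens each computed action before caching it, so a tuple action from a Tuple action space can be stored in prev_actions and fed back uniformly on the next compute_action call. A minimal sketch, with a made-up action standing in for what agent.compute_action() might return:

    import numpy as np
    from ray.rllib.evaluation.episode import _flatten_action

    # Hypothetical tuple action for an env with a Tuple action space:
    # (discrete choice, continuous parameters).
    a_action = (0, np.array([0.25, -0.75]))

    a_action = _flatten_action(a_action)  # tuple actions -> flat ndarray
    print(a_action)  # [ 0.    0.25 -0.75]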
