Skip to content

Commit

Permalink
[rllib] validate observation in NoPreprocessor (ray-project#4546)
Browse files Browse the repository at this point in the history
  • Loading branch information
joneswong authored and ericl committed Apr 7, 2019
1 parent f9b8e77 commit da5a471
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 25 deletions.
6 changes: 3 additions & 3 deletions python/ray/rllib/examples/parametric_action_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def __init__(self, max_avail_actions):
self.wrapped = gym.make("CartPole-v0")
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(max_avail_actions, )),
"avail_actions": Box(-1, 1, shape=(max_avail_actions, 2)),
"avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
"cart": self.wrapped.observation_space,
})

def update_avail_actions(self):
self.action_assignments = [[0, 0]] * self.action_space.n
self.action_mask = [0] * self.action_space.n
self.action_assignments = np.array([[0., 0.]] * self.action_space.n)
self.action_mask = np.array([0.] * self.action_space.n)
self.left_idx, self.right_idx = random.sample(
range(self.action_space.n), 2)
self.action_assignments[self.left_idx] = self.left_action_embed
Expand Down
28 changes: 25 additions & 3 deletions python/ray/rllib/models/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128, )
VALIDATION_INTERVAL = 100

logger = logging.getLogger(__name__)

Expand All @@ -31,6 +32,7 @@ def __init__(self, obs_space, options=None):
self._options = options or {}
self.shape = self._init_shape(obs_space, options)
self._size = int(np.product(self.shape))
self._i = 0

@PublicAPI
def _init_shape(self, obs_space, options):
Expand All @@ -46,6 +48,23 @@ def write(self, observation, array, offset):
"""Alternative to transform for more efficient flattening."""
array[offset:offset + self._size] = self.transform(observation)

def check_shape(self, observation):
"""Checks the shape of the given observation."""
if self._i % VALIDATION_INTERVAL == 0:
if type(observation) is list and isinstance(
self._obs_space, gym.spaces.Box):
observation = np.array(observation)
try:
if not self._obs_space.contains(observation):
raise ValueError(
"Observation outside expected value range",
self._obs_space, observation)
except AttributeError:
raise ValueError(
"Observation for a Box space should be an np.array, "
"not a Python list.", observation)
self._i += 1

@property
@PublicAPI
def size(self):
Expand Down Expand Up @@ -85,6 +104,7 @@ def _init_shape(self, obs_space, options):
@override(Preprocessor)
def transform(self, observation):
"""Downsamples images from (210, 160, 3) by the configured factor."""
self.check_shape(observation)
scaled = observation[25:-25, :, :]
if self._dim < 84:
scaled = cv2.resize(scaled, (84, 84))
Expand All @@ -111,6 +131,7 @@ def _init_shape(self, obs_space, options):

@override(Preprocessor)
def transform(self, observation):
self.check_shape(observation)
return (observation - 128) / 128


Expand All @@ -121,10 +142,8 @@ def _init_shape(self, obs_space, options):

@override(Preprocessor)
def transform(self, observation):
self.check_shape(observation)
arr = np.zeros(self._obs_space.n)
if not self._obs_space.contains(observation):
raise ValueError("Observation outside expected value range",
self._obs_space, observation)
arr[observation] = 1
return arr

Expand All @@ -140,6 +159,7 @@ def _init_shape(self, obs_space, options):

@override(Preprocessor)
def transform(self, observation):
self.check_shape(observation)
return observation

@override(Preprocessor)
Expand Down Expand Up @@ -169,6 +189,7 @@ def _init_shape(self, obs_space, options):

@override(Preprocessor)
def transform(self, observation):
self.check_shape(observation)
array = np.zeros(self.shape)
self.write(observation, array, 0)
return array
Expand Down Expand Up @@ -201,6 +222,7 @@ def _init_shape(self, obs_space, options):

@override(Preprocessor)
def transform(self, observation):
self.check_shape(observation)
array = np.zeros(self.shape)
self.write(observation, array, 0)
return array
Expand Down
3 changes: 2 additions & 1 deletion python/ray/rllib/tests/test_avail_actions_qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import numpy as np
from gym.spaces import Tuple, Discrete, Dict, Box

import ray
Expand All @@ -20,7 +21,7 @@ class AvailActionsTestEnv(MultiAgentEnv):
def __init__(self, env_config):
self.state = None
self.avail = env_config["avail_action"]
self.action_mask = [0] * 10
self.action_mask = np.array([0] * 10)
self.action_mask[env_config["avail_action"]] = 1

def reset(self):
Expand Down
4 changes: 2 additions & 2 deletions python/ray/rllib/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ class TupleEnv(object):
def __init__(self):
self.observation_space = Tuple(
[Discrete(5),
Box(0, 1, shape=(3, ), dtype=np.float32)])
Box(0, 5, shape=(3, ), dtype=np.float32)])

p1 = ModelCatalog.get_preprocessor(TupleEnv())
self.assertEqual(p1.shape, (8, ))
self.assertEqual(
list(p1.transform((0, [1, 2, 3]))),
list(p1.transform((0, np.array([1, 2, 3])))),
[float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]])

def testCustomPreprocessor(self):
Expand Down
13 changes: 11 additions & 2 deletions python/ray/rllib/tests/test_checkpoint_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import os
import shutil
import gym
import numpy as np
import ray

Expand Down Expand Up @@ -63,9 +64,11 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
if "DDPG" in alg_name:
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
env = gym.make("Pendulum-v0")
else:
alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
env = gym.make("CartPole-v0")

for _ in range(3):
res = alg1.train()
Expand All @@ -79,9 +82,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures):

for _ in range(10):
if "DDPG" in alg_name:
obs = np.random.uniform(size=3)
obs = np.clip(
np.random.uniform(size=3),
env.observation_space.low,
env.observation_space.high)
else:
obs = np.random.uniform(size=4)
obs = np.clip(
np.random.uniform(size=4),
env.observation_space.low,
env.observation_space.high)
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
Expand Down
36 changes: 36 additions & 0 deletions python/ray/rllib/tests/test_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym
import time
import unittest

import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph


class TestPerf(unittest.TestCase):
# Tested on Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
# 11/23/18: Samples per second 8501.125113727468
# 03/01/19: Samples per second 8610.164353268685
def testBaselinePerformance(self):
for _ in range(20):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
batch_steps=100)
start = time.time()
count = 0
while time.time() - start < 1:
count += ev.sample().count
print()
print("Samples per second {}".format(
count / (time.time() - start)))
print()


if __name__ == "__main__":
ray.init(num_cpus=5)
unittest.main(verbosity=2)
14 changes: 0 additions & 14 deletions python/ray/rllib/tests/test_policy_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,6 @@ def testBatchIds(self):
self.assertEqual(
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)

# 11/23/18: Samples per second 8501.125113727468
def testBaselinePerformance(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
batch_steps=100)
start = time.time()
count = 0
while time.time() - start < 1:
count += ev.sample().count
print()
print("Samples per second {}".format(count / (time.time() - start)))
print()

def testGlobalVarsUpdate(self):
agent = A2CTrainer(
env="CartPole-v0",
Expand Down

0 comments on commit da5a471

Please sign in to comment.