[RLlib] Issue 18231: Better (earlier) env validation and error message improvement. (ray-project#18249)
sven1977 authored Sep 2, 2021
1 parent 6621bb5 commit 2357bbc
Showing 7 changed files with 100 additions and 40 deletions.
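Note on what this change does in practice: env validation now happens eagerly at RolloutWorker construction (i.e. during Trainer init) instead of failing later, mid-rollout. A minimal sketch of the failure mode it catches, assuming the RLlib 1.x PGTrainer API and a hypothetical env class BrokenObsEnv:

    import gym
    import numpy as np
    import ray
    from ray.rllib.agents.pg import PGTrainer
    from ray.tune.registry import register_env

    class BrokenObsEnv(gym.Env):
        # Declares a [0, 1] box, but reset() returns an out-of-bounds value.
        observation_space = gym.spaces.Box(0.0, 1.0, shape=(1,))
        action_space = gym.spaces.Discrete(2)

        def reset(self):
            return np.array([2.0], dtype=np.float32)  # not in observation_space

        def step(self, action):
            return np.array([0.5], dtype=np.float32), 0.0, True, {}

    ray.init()
    register_env("broken-obs", lambda cfg: BrokenObsEnv())

    # With this commit, the dummy reset() inside _validate_env() raises
    # EnvError here, at construction time, naming the offending space.
    trainer = PGTrainer(env="broken-obs", config={"num_workers": 0})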
rllib/env/vector_env.py (2 changes: 1 addition & 1 deletion)
@@ -111,7 +111,7 @@ def try_render_at(self, index: Optional[int] = None) -> \
 
 
 class _VectorizedGymEnv(VectorEnv):
-    """Internal wrapper to translate any gym envs into a VectorEnv object.
+    """Internal wrapper to translate any gym.Envs into a VectorEnv object.
     """
 
     def __init__(
rllib/env/wrappers/model_vector_env.py (4 changes: 2 additions & 2 deletions)
@@ -27,15 +27,15 @@ def model_vector_env(env: EnvType) -> BaseEnv:
     worker_index = worker.worker_index
     if worker_index:
         env = _VectorizedModelGymEnv(
-            make_env=worker.make_env_fn,
+            make_env=worker.make_sub_env_fn,
             existing_envs=[env],
             num_envs=worker.num_envs,
             observation_space=env.observation_space,
             action_space=env.action_space,
         )
     return BaseEnv.to_base_env(
         env,
-        make_env=worker.make_env_fn,
+        make_env=worker.make_sub_env_fn,
         num_envs=worker.num_envs,
         remote_envs=False,
         remote_env_batch_wait_ms=0)
rllib/evaluation/rollout_worker.py (112 changes: 82 additions & 30 deletions)
Expand Up @@ -32,6 +32,7 @@
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.error import EnvError
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
@@ -417,18 +418,29 @@ def gen_rollouts():
         self.global_vars: dict = None
         self.fake_sampler: bool = fake_sampler
 
-        self.env = None
-
         # Update the global seed for numpy/random/tf-eager/torch.
         update_global_seed_if_necessary(policy_config.get("framework"), seed)
 
-        # Create an env for this worker.
+        # A single environment provided by the user (via config.env). This may
+        # also remain None.
+        # 1) Create the env using the user provided env_creator. This may
+        #    return a gym.Env (incl. MultiAgentEnv), an already vectorized
+        #    VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
+        # 2) Wrap - if applicable - with Atari/recording/rendering wrappers.
+        # 3) Seed the env, if necessary.
+        # 4) Vectorize the existing single env by creating more clones of
+        #    this env and wrapping it with the RLlib BaseEnv class.
+        self.env = None
+
+        # Create a (single) env for this worker.
         if not (worker_index == 0 and num_workers > 0
                 and not policy_config.get("create_env_on_driver")):
+            # Run the `env_creator` function passing the EnvContext.
             self.env = env_creator(env_context)
 
         if self.env is not None:
             # Validate environment (general validation function).
-            self.env = _validate_env(self.env)
+            _validate_env(self.env, env_context=self.env_context)
+            # Custom validation function given.
             if validate_env is not None:
                 validate_env(self.env, self.env_context)
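For reference (not part of the diff): the user-supplied validate_env callable invoked above takes the env plus its EnvContext. A hedged sketch of such a validator; require_finite_horizon is a hypothetical example, not RLlib API:

    from ray.rllib.env.env_context import EnvContext
    from ray.rllib.utils.error import EnvError

    def require_finite_horizon(env, env_context: EnvContext) -> None:
        # Hypothetical user check: reject envs without an episode-step limit.
        spec = getattr(env, "spec", None)
        if spec is not None and spec.max_episode_steps is None:
            raise EnvError(
                f"Sub-env at vector_index={env_context.vector_index} has no "
                "max_episode_steps; this experiment needs finite episodes.")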
@@ -507,28 +519,39 @@ def wrap(env):
 
             # Wrap env through the correct wrapper.
             self.env: EnvType = wrap(self.env)
-            # Ideally, we would use the same make_env() function below
+            # Ideally, we would use the same make_sub_env() function below
             # to create self.env, but wrap(env) and self.env has a cyclic
             # dependency on each other right now, so we would settle on
             # duplicating the random seed setting logic for now.
             _update_env_seed_if_necessary(self.env, seed, worker_index, 0)
 
-        def make_env(vector_index):
-            env = wrap(
-                env_creator(
-                    env_context.copy_with_overrides(
-                        worker_index=worker_index,
-                        vector_index=vector_index,
-                        remote=remote_worker_envs)))
-            # make_env() is used to created additional environments
-            # during environment vectorization below.
-            # So we make sure a deterministic random seed is set on
-            # all the environments if specified.
+        def make_sub_env(vector_index):
+            # Used to created additional environments during environment
+            # vectorization.
+
+            # Create the env context (config dict + meta-data) for
+            # this particular sub-env within the vectorized one.
+            env_ctx = env_context.copy_with_overrides(
+                worker_index=worker_index,
+                vector_index=vector_index,
+                remote=remote_worker_envs)
+            # Create the sub-env.
+            env = env_creator(env_ctx)
+            # Validate first.
+            _validate_env(env, env_context=env_ctx)
+            # Custom validation function given by user.
+            if validate_env is not None:
+                validate_env(env, env_ctx)
+            # Use our wrapper, defined above.
+            env = wrap(env)
+
+            # Make sure a deterministic random seed is set on
+            # all the sub-environments if specified.
             _update_env_seed_if_necessary(env, seed, worker_index,
                                           vector_index)
             return env
 
-        self.make_env_fn = make_env
+        self.make_sub_env_fn = make_sub_env
         self.spaces = spaces
 
         policy_dict = _determine_spaces_for_multi_agent_dict(
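To make the sub-env bookkeeping above concrete: copy_with_overrides stamps each clone with its own vector_index while keeping the shared config dict. A small sketch; grid_size is an arbitrary example key, not anything RLlib defines:

    from ray.rllib.env.env_context import EnvContext

    base_ctx = EnvContext({"grid_size": 8}, worker_index=1)
    # Each clone made during vectorization gets a distinct vector_index;
    # _update_env_seed_if_necessary combines it with the worker index to
    # derive a deterministic per-sub-env seed.
    sub_ctx = base_ctx.copy_with_overrides(vector_index=3, remote=False)
    assert sub_ctx.vector_index == 3 and sub_ctx["grid_size"] == 8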
@@ -609,18 +632,22 @@ def make_env(vector_index):
         if self.worker_index == 0:
             logger.info("Built filter map: {}".format(self.filters))
 
         # Vectorize environment, if any.
         self.num_envs: int = num_envs
 
+        # This RolloutWorker has no env.
         if self.env is None:
             self.async_env = None
+        # Use a custom env-vectorizer and call it providing self.env.
         elif "custom_vector_env" in policy_config:
-            custom_vec_wrapper = policy_config["custom_vector_env"]
-            self.async_env = custom_vec_wrapper(self.env)
+            self.async_env = policy_config["custom_vector_env"](self.env)
+        # Default: Vectorize self.env via the make_sub_env function. This adds
+        # further clones of self.env and creates a RLlib BaseEnv (which is
+        # vectorized under the hood).
         else:
             # Always use vector env for consistency even if num_envs = 1.
             self.async_env: BaseEnv = BaseEnv.to_base_env(
                 self.env,
-                make_env=make_env,
+                make_env=self.make_sub_env_fn,
                 num_envs=num_envs,
                 remote_envs=remote_worker_envs,
                 remote_env_batch_wait_ms=remote_env_batch_wait_ms,
@@ -1478,18 +1505,43 @@ def _determine_spaces_for_multi_agent_dict(
     return multi_agent_dict
 
 
-def _validate_env(env: Any) -> EnvType:
-    # Allow this as a special case (assumed gym.Env).
-    if hasattr(env, "observation_space") and hasattr(env, "action_space"):
-        return env
+def _validate_env(env: EnvType, env_context: EnvContext = None):
+    # Base message for checking the env for vector-index=0
+    msg = f"Validating sub-env at vector index={env_context.vector_index} ..."
 
     allowed_types = [
         gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv,
         ray.actor.ActorHandle
     ]
     if not any(isinstance(env, tpe) for tpe in allowed_types):
-        raise ValueError(
-            "Returned env should be an instance of gym.Env, MultiAgentEnv, "
-            "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
-            "function returned {} ({}).".format(env, type(env)))
-    return env
+        # Allow this as a special case (assumed gym.Env).
+        # TODO: Disallow this early-out. Everything should conform to a few
+        #  supported classes, i.e. gym.Env/MultiAgentEnv/etc...
+        if hasattr(env, "observation_space") and hasattr(env, "action_space"):
+            logger.warning(msg + f" (warning; invalid env-type={type(env)})")
+            return
+        else:
+            logger.warning(msg + " (NOT OK)")
+            raise EnvError(
+                "Returned env should be an instance of gym.Env (incl. "
+                "MultiAgentEnv), ExternalEnv, VectorEnv, or BaseEnv. "
+                f"The provided env creator function returned {env} "
+                f"(type={type(env)}).")
+
+    # Do some test runs with the provided env.
+    if isinstance(env, gym.Env):
+        # Make sure the gym.Env has the two space attributes properly set.
+        assert hasattr(env, "observation_space") and hasattr(
+            env, "action_space")
+        # Get a dummy observation by resetting the env.
+        dummy_obs = env.reset()
+        # Check, if observation is ok (part of the observation space). If not,
+        # error.
+        if not env.observation_space.contains(dummy_obs):
+            logger.warning(msg + " (NOT OK)")
+            raise EnvError(
+                f"Env's `observation_space` {env.observation_space} does not "
+                f"contain returned observation after a reset ({dummy_obs})!")
+
+    # Log that everything is ok.
+    logger.info(msg + " (ok)")
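And the other path through the rewritten _validate_env: a creator that returns a non-env object now fails with EnvError rather than ValueError. A sketch that calls the private helper directly (for illustration only; _validate_env is internal API):

    from ray.rllib.env.env_context import EnvContext
    from ray.rllib.evaluation.rollout_worker import _validate_env
    from ray.rllib.utils.error import EnvError

    try:
        # object() is neither a gym.Env nor any other supported env type.
        _validate_env(object(), env_context=EnvContext({}, worker_index=0))
    except EnvError as err:
        print(err)  # "Returned env should be an instance of gym.Env (incl. ..."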
rllib/evaluation/sampler.py (4 changes: 2 additions & 2 deletions)
@@ -6,7 +6,7 @@
 import threading
 import time
 import tree  # pip install dm_tree
-from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
+from typing import Any, Callable, Dict, List, Iterator, Optional, Set, Tuple,\
     Type, TYPE_CHECKING, Union
 
 from ray.util.debug import log_once
@@ -460,7 +460,7 @@ def _env_runner(
         observation_fn: "ObservationFunction",
         sample_collector: Optional[SampleCollector] = None,
         render: bool = None,
-) -> Iterable[SampleBatchType]:
+) -> Iterator[SampleBatchType]:
     """This implements the common experience collection logic.
 
     Args:
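The Iterable to Iterator change above is not cosmetic: _env_runner is a generator, and its caller advances it with next(), which only Iterator guarantees. A tiny illustration (counter is a hypothetical stand-in):

    from typing import Iterator

    def counter(n: int) -> Iterator[int]:
        # A generator function returns an Iterator, which supports next().
        for i in range(n):
            yield i

    batches = counter(3)
    first = next(batches)  # fine: Iterator guarantees __next__
    # A plain Iterable only guarantees __iter__; e.g. next([1, 2, 3])
    # raises TypeError even though a list is Iterable.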
rllib/evaluation/tests/test_rollout_worker.py (11 changes: 6 additions & 5 deletions)
@@ -157,13 +157,14 @@ def test_global_vars_update(self):
     def test_no_step_on_init(self):
         register_env("fail", lambda _: FailOnStepEnv())
         for fw in framework_iterator():
-            pg = PGTrainer(
+            # We expect this to fail already on Trainer init due
+            # to the env sanity check right after env creation (inside
+            # RolloutWorker).
+            self.assertRaises(Exception, lambda: PGTrainer(
                 env="fail", config={
-                    "num_workers": 1,
+                    "num_workers": 2,
                     "framework": fw,
-                })
-            self.assertRaises(Exception, lambda: pg.train())
-            pg.stop()
+                }))
 
     def test_callbacks(self):
         for fw in framework_iterator(frameworks=("torch", "tf")):
rllib/examples/env/env_using_remote_actor.py (1 change: 1 addition & 0 deletions)
@@ -28,6 +28,7 @@ def __init__(self, env_config):
         self._handler = ray.get_actor(
             env_config.get("param_server", "param-server"))
         self.rng_seed = None
+        self.np_random, _ = seeding.np_random(self.rng_seed)
 
     def seed(self, rng_seed: int = None):
         if not rng_seed:
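The added line guarantees self.np_random exists even if seed() is never called. For reference, a sketch of the gym seeding utility used here (return types as in the gym releases current in 2021):

    from gym.utils import seeding

    # Returns (rng, seed_actually_used); passing None draws a random seed.
    np_random, used_seed = seeding.np_random(None)
    sample = np_random.uniform(0.0, 1.0)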
rllib/utils/error.py (6 changes: 6 additions & 0 deletions)
@@ -5,3 +5,9 @@
 class UnsupportedSpaceException(Exception):
     """Error for an unsupported action or observation space."""
     pass
+
+
+@PublicAPI
+class EnvError(Exception):
+    """Error if we encounter an error during RL environment validation."""
+    pass
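Because EnvError is @PublicAPI and subclasses Exception, user code can catch env-setup failures separately from other trainer errors. A hedged sketch, reusing the hypothetical "broken-obs" env registered in the first example above:

    from ray.rllib.agents.pg import PGTrainer
    from ray.rllib.utils.error import EnvError

    try:
        trainer = PGTrainer(env="broken-obs", config={"num_workers": 0})
    except EnvError as err:
        # Surface env validation problems distinctly from other failures.
        print(f"Environment failed validation: {err}")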
