Skip to content

Commit

Permalink
Wrapper for the dm_env interface (ray-project#6468)
Browse files Browse the repository at this point in the history
  • Loading branch information
gehring authored and ericl committed Dec 26, 2019
1 parent b98b288 commit b40869d
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 1 deletion.
3 changes: 2 additions & 1 deletion rllib/env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
 from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.env.dm_env_wrapper import DMEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.external_env import ExternalEnv
 from ray.rllib.env.serving_env import ServingEnv
Expand All @@ -7,5 +8,5 @@

 __all__ = [
     "BaseEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv", "ServingEnv",
-    "EnvContext"
+    "EnvContext", "DMEnv"
 ]
98 changes: 98 additions & 0 deletions rllib/env/dm_env_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym
from gym import spaces

import numpy as np

try:
from dm_env import specs
except ImportError:
specs = None


def _convert_spec_to_space(spec):
    """Convert a `dm_env` spec (or a dict of specs) into a Gym space.

    Args:
        spec: A `dm_env.specs` array spec, or a dict mapping keys to specs.
            Dicts are converted recursively, value by value.

    Returns:
        A `spaces.Dict` for a dict, a `spaces.Discrete` for a
        `DiscreteArray`, a bounded `spaces.Box` for a `BoundedArray` and an
        unbounded `spaces.Box` for a plain `Array`.

    Raises:
        NotImplementedError: If `spec` is none of the supported types.
    """
    if isinstance(spec, dict):
        return spaces.Dict(
            {k: _convert_spec_to_space(v)
             for k, v in spec.items()})

    # The order of the checks matters: in `dm_env`, `DiscreteArray` derives
    # from `BoundedArray`, which in turn derives from `Array`, so the most
    # specific type has to be tested first.
    if isinstance(spec, specs.DiscreteArray):
        return spaces.Discrete(spec.num_values)
    elif isinstance(spec, specs.BoundedArray):
        # `minimum`/`maximum` may be scalars or per-element arrays. The
        # previous `np.asscalar` call only accepted size-1 bounds and has
        # been removed from NumPy; expanding the bounds to the full shape
        # handles both cases. `np.array` makes a writable copy of the
        # read-only view returned by `np.broadcast_to`.
        low = np.array(np.broadcast_to(spec.minimum, spec.shape))
        high = np.array(np.broadcast_to(spec.maximum, spec.shape))
        return spaces.Box(low=low, high=high, dtype=spec.dtype)
    elif isinstance(spec, specs.Array):
        return spaces.Box(
            low=-float("inf"),
            high=float("inf"),
            shape=spec.shape,
            dtype=spec.dtype)

    raise NotImplementedError(
        ("Could not convert `Array` spec of type {} to Gym space. "
         "Attempted to convert: {}").format(type(spec), spec))


class DMEnv(gym.Env):
    """A `gym.Env` wrapper for the `dm_env` API.

    Wraps an environment exposing `step()`, `reset()`, `action_spec()`,
    `observation_spec()` and `reward_spec()` and presents it through the
    Gym interface. The Gym spaces are derived from the wrapped
    environment's specs each time the corresponding property is read.
    """

    metadata = {"render.modes": ["rgb_array"]}

    def __init__(self, dm_env):
        """Wrap `dm_env`.

        Args:
            dm_env: The environment instance to wrap.

        Raises:
            RuntimeError: If the `dm_env.specs` module could not be imported.
        """
        super(DMEnv, self).__init__()

        # Fail before storing anything: without `specs` none of the space
        # conversions can work.
        if specs is None:
            raise RuntimeError((
                "The `specs` module from `dm_env` was not imported. Make sure "
                "`dm_env` is installed and visible in the current python "
                "environment."))

        self._env = dm_env
        # Most recent observation, returned by `render()`. Stays `None`
        # until the first `reset()` or `step()`.
        self._prev_obs = None

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            A `(observation, reward, done, info)` tuple. `reward` is `0.`
            when the timestep carries no reward, `done` is the timestep's
            `last()` flag and `info` holds the timestep's discount under
            the key `"discount"`.
        """
        ts = self._env.step(action)
        # Remember the observation so that `render()` has something to
        # return; previously `_prev_obs` was never updated.
        self._prev_obs = ts.observation

        reward = ts.reward
        if reward is None:
            reward = 0.

        return ts.observation, reward, ts.last(), {"discount": ts.discount}

    def reset(self):
        """Reset the wrapped environment and return the first observation."""
        ts = self._env.reset()
        self._prev_obs = ts.observation
        return ts.observation

    def render(self, mode="rgb_array"):
        """Return the most recent observation.

        Args:
            mode: Render mode; only `"rgb_array"` is supported.

        Returns:
            The observation from the latest `reset()` or `step()` call,
            returned as-is.

        Raises:
            ValueError: If no observation has been received yet.
            NotImplementedError: If `mode` is not `"rgb_array"`.
        """
        if self._prev_obs is None:
            raise ValueError(
                "Environment not started. Make sure to reset before rendering."
            )

        if mode == "rgb_array":
            return self._prev_obs
        else:
            raise NotImplementedError(
                "Render mode '{}' is not supported.".format(mode))

    @property
    def action_space(self):
        """Gym space converted from the wrapped environment's action spec."""
        spec = self._env.action_spec()
        return _convert_spec_to_space(spec)

    @property
    def observation_space(self):
        """Gym space converted from the wrapped environment's observation
        spec."""
        spec = self._env.observation_spec()
        return _convert_spec_to_space(spec)

    @property
    def reward_range(self):
        """`(min, max)` reward bounds; infinite unless the reward spec is a
        `BoundedArray`."""
        spec = self._env.reward_spec()
        if isinstance(spec, specs.BoundedArray):
            return spec.minimum, spec.maximum
        return -float("inf"), float("inf")
33 changes: 33 additions & 0 deletions rllib/examples/dmlab_watermaze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from deepmind_lab import dmenv_module

from ray.rllib import env


class Watermaze(env.DMEnv):
    """`DMEnv` wrapper around the DeepMind Lab `rooms_watermaze` level.

    Args:
        env_config: Passed through as the `config` argument of
            `dmenv_module.Lab`.
    """

    def __init__(self, env_config):
        level_name = "contributed/dmlab30/rooms_watermaze"
        observation_names = ["RGBD"]
        lab_env = dmenv_module.Lab(
            level_name, observation_names, config=env_config)
        super(Watermaze, self).__init__(lab_env)


# NOTE: this rebinds `env`, shadowing the `ray.rllib.env` module imported
# above; the module is not used again after this point.
env = Watermaze({"width": "320", "height": "160"})
print(env.action_space)

# Names of the action components sent on every step, all set to zero.
action_keys = (
    "CROUCH",
    "FIRE",
    "JUMP",
    "LOOK_DOWN_UP_PIXELS_PER_FRAME",
    "LOOK_LEFT_RIGHT_PIXELS_PER_FRAME",
    "MOVE_BACK_FORWARD",
    "STRAFE_LEFT_RIGHT",
)

# Take two steps with the all-zero action and print each result.
for _ in range(2):
    zero_action = dict.fromkeys(action_keys, 0.)
    print(env.step(zero_action))

0 comments on commit b40869d

Please sign in to comment.