forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Wrapper for the dm_env interface (ray-project#6468)
- Loading branch information
Showing
3 changed files
with
133 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import gym | ||
from gym import spaces | ||
|
||
import numpy as np | ||
|
||
try: | ||
from dm_env import specs | ||
except ImportError: | ||
specs = None | ||
|
||
|
||
def _convert_spec_to_space(spec): | ||
if isinstance(spec, dict): | ||
return spaces.Dict( | ||
{k: _convert_spec_to_space(v) | ||
for k, v in spec.items()}) | ||
if isinstance(spec, specs.DiscreteArray): | ||
return spaces.Discrete(spec.num_values) | ||
elif isinstance(spec, specs.BoundedArray): | ||
return spaces.Box( | ||
low=np.asscalar(spec.minimum), | ||
high=np.asscalar(spec.maximum), | ||
shape=spec.shape, | ||
dtype=spec.dtype) | ||
elif isinstance(spec, specs.Array): | ||
return spaces.Box( | ||
low=-float("inf"), | ||
high=float("inf"), | ||
shape=spec.shape, | ||
dtype=spec.dtype) | ||
|
||
raise NotImplementedError( | ||
("Could not convert `Array` spec of type {} to Gym space. " | ||
"Attempted to convert: {}").format(type(spec), spec)) | ||
|
||
|
||
class DMEnv(gym.Env): | ||
"""A `gym.Env` wrapper for the `dm_env` API. | ||
""" | ||
|
||
metadata = {"render.modes": ["rgb_array"]} | ||
|
||
def __init__(self, dm_env): | ||
super(DMEnv, self).__init__() | ||
self._env = dm_env | ||
self._prev_obs = None | ||
|
||
if specs is None: | ||
raise RuntimeError(( | ||
"The `specs` module from `dm_env` was not imported. Make sure " | ||
"`dm_env` is installed and visible in the current python " | ||
"environment.")) | ||
|
||
def step(self, action): | ||
ts = self._env.step(action) | ||
|
||
reward = ts.reward | ||
if reward is None: | ||
reward = 0. | ||
|
||
return ts.observation, reward, ts.last(), {"discount": ts.discount} | ||
|
||
def reset(self): | ||
ts = self._env.reset() | ||
return ts.observation | ||
|
||
def render(self, mode="rgb_array"): | ||
if self._prev_obs is None: | ||
raise ValueError( | ||
"Environment not started. Make sure to reset before rendering." | ||
) | ||
|
||
if mode == "rgb_array": | ||
return self._prev_obs | ||
else: | ||
raise NotImplementedError( | ||
"Render mode '{}' is not supported.".format(mode)) | ||
|
||
@property | ||
def action_space(self): | ||
spec = self._env.action_spec() | ||
return _convert_spec_to_space(spec) | ||
|
||
@property | ||
def observation_space(self): | ||
spec = self._env.observation_spec() | ||
return _convert_spec_to_space(spec) | ||
|
||
@property | ||
def reward_range(self): | ||
spec = self._env.reward_spec() | ||
if isinstance(spec, specs.BoundedArray): | ||
return spec.minimum, spec.maximum | ||
return -float("inf"), float("inf") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from deepmind_lab import dmenv_module | ||
|
||
from ray.rllib import env | ||
|
||
|
||
class Watermaze(env.DMEnv): | ||
def __init__(self, env_config): | ||
lab = dmenv_module.Lab( | ||
"contributed/dmlab30/rooms_watermaze", | ||
["RGBD"], | ||
config=env_config, | ||
) | ||
super(Watermaze, self).__init__(lab) | ||
|
||
|
||
env = Watermaze({"width": "320", "height": "160"}) | ||
print(env.action_space) | ||
|
||
for i in range(2): | ||
print( | ||
env.step({ | ||
"CROUCH": 0., | ||
"FIRE": 0., | ||
"JUMP": 0., | ||
"LOOK_DOWN_UP_PIXELS_PER_FRAME": 0., | ||
"LOOK_LEFT_RIGHT_PIXELS_PER_FRAME": 0., | ||
"MOVE_BACK_FORWARD": 0., | ||
"STRAFE_LEFT_RIGHT": 0. | ||
})) |