# external_env.py (forked from ray-project/ray)
import gymnasium as gym
import queue
import threading
import uuid
from typing import Callable, Tuple, Optional, TYPE_CHECKING
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.utils.annotations import override, OldAPIStack
from ray.rllib.utils.typing import (
EnvActionType,
EnvInfoDict,
EnvObsType,
EnvType,
MultiEnvDict,
)
from ray.rllib.utils.deprecation import deprecation_warning
if TYPE_CHECKING:
from ray.rllib.models.preprocessors import Preprocessor
@OldAPIStack
class ExternalEnv(threading.Thread):
"""An environment that interfaces with external agents.
Unlike simulator envs, control is inverted: The environment queries the
policy to obtain actions and in return logs observations and rewards for
training. This is in contrast to gym.Env, where the algorithm drives the
simulation through env.step() calls.
You can use ExternalEnv as the backend for policy serving (by serving HTTP
requests in the run loop), for ingesting offline logs data (by reading
offline transitions in the run loop), or other custom use cases not easily
expressed through gym.Env.
ExternalEnv supports both on-policy actions (through self.get_action()),
and off-policy actions (through self.log_action()).
This env is thread-safe, but individual episodes must be executed serially.
.. testcode::
:skipif: True
from ray.tune import register_env
from ray.rllib.algorithms.dqn import DQN
YourExternalEnv = ...
register_env("my_env", lambda config: YourExternalEnv(config))
algo = DQN(env="my_env")
while True:
print(algo.train())
"""
def __init__(
self,
action_space: gym.Space,
observation_space: gym.Space,
        max_concurrent: Optional[int] = None,
):
"""Initializes an ExternalEnv instance.
Args:
action_space: Action space of the env.
observation_space: Observation space of the env.
"""
threading.Thread.__init__(self)
self.daemon = True
self.action_space = action_space
self.observation_space = observation_space
self._episodes = {}
self._finished = set()
self._results_avail_condition = threading.Condition()
if max_concurrent is not None:
deprecation_warning(
"The `max_concurrent` argument has been deprecated. Please configure"
"the number of episodes using the `rollout_fragment_length` and"
"`batch_mode` arguments. Please raise an issue on the Ray Github if "
"these arguments do not support your expected use case for ExternalEnv",
error=True,
)
def run(self):
"""Override this to implement the run loop.
Your loop should continuously:
1. Call self.start_episode(episode_id)
2. Call self.[get|log]_action(episode_id, obs, [action]?)
3. Call self.log_returns(episode_id, reward)
4. Call self.end_episode(episode_id, obs)
5. Wait if nothing to do.
Multiple episodes may be started at the same time.
"""
raise NotImplementedError
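    # A minimal run-loop sketch (illustrative only): it walks through the five
    # steps listed in the `run()` docstring. `self._serve_one_request()` is a
    # hypothetical helper standing in for whatever I/O (HTTP serving, offline
    # log reading, etc.) produces observations and rewards; it is not part of
    # this class.
    #
    #     def run(self):
    #         while True:
    #             episode_id = self.start_episode()
    #             obs = self._serve_one_request()  # block until a client connects
    #             terminated = False
    #             while not terminated:
    #                 action = self.get_action(episode_id, obs)
    #                 obs, reward, terminated = self._serve_one_request(action)
    #                 self.log_returns(episode_id, reward)
    #             self.end_episode(episode_id, obs)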
def start_episode(
self, episode_id: Optional[str] = None, training_enabled: bool = True
) -> str:
"""Record the start of an episode.
Args:
episode_id: Unique string id for the episode or
None for it to be auto-assigned and returned.
training_enabled: Whether to use experiences for this
episode to improve the policy.
Returns:
Unique string id for the episode.
"""
if episode_id is None:
episode_id = uuid.uuid4().hex
if episode_id in self._finished:
raise ValueError("Episode {} has already completed.".format(episode_id))
if episode_id in self._episodes:
raise ValueError("Episode {} is already started".format(episode_id))
self._episodes[episode_id] = _ExternalEnvEpisode(
episode_id, self._results_avail_condition, training_enabled
)
return episode_id
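    # Example sketch (illustrative only): start an evaluation episode whose
    # experiences should not be used to improve the policy.
    #
    #     eval_episode_id = self.start_episode(training_enabled=False)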
def get_action(self, episode_id: str, observation: EnvObsType) -> EnvActionType:
"""Record an observation and get the on-policy action.
Args:
episode_id: Episode id returned from start_episode().
observation: Current environment observation.
Returns:
Action from the env action space.
"""
episode = self._get(episode_id)
return episode.wait_for_action(observation)
def log_action(
self, episode_id: str, observation: EnvObsType, action: EnvActionType
) -> None:
"""Record an observation and (off-policy) action taken.
Args:
episode_id: Episode id returned from start_episode().
observation: Current environment observation.
action: Action for the observation.
"""
episode = self._get(episode_id)
episode.log_action(observation, action)
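    # Sketch contrasting on-policy and off-policy usage inside a run loop
    # (illustrative only; `use_served_policy` and `legacy_controller` are
    # hypothetical placeholders):
    #
    #     if use_served_policy:
    #         # On-policy: ask the served policy for the action.
    #         action = self.get_action(episode_id, obs)
    #     else:
    #         # Off-policy: the action was chosen elsewhere; just record it.
    #         action = legacy_controller(obs)
    #         self.log_action(episode_id, obs, action)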
def log_returns(
self, episode_id: str, reward: float, info: Optional[EnvInfoDict] = None
) -> None:
"""Records returns (rewards and infos) from the environment.
The reward will be attributed to the previous action taken by the
episode. Rewards accumulate until the next action. If no reward is
logged before the next action, a reward of 0.0 is assumed.
Args:
episode_id: Episode id returned from start_episode().
reward: Reward from the environment.
info: Optional info dict.
"""
episode = self._get(episode_id)
episode.cur_reward += reward
if info:
episode.cur_info = info or {}
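    # Reward-accumulation sketch: rewards logged between two actions are summed
    # and credited to the most recent action (illustrative values only):
    #
    #     action = self.get_action(episode_id, obs)
    #     self.log_returns(episode_id, 1.0)
    #     self.log_returns(episode_id, 0.5)  # a total of 1.5 is credited to `action`
    #     next_action = self.get_action(episode_id, next_obs)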
def end_episode(self, episode_id: str, observation: EnvObsType) -> None:
"""Records the end of an episode.
Args:
episode_id: Episode id returned from start_episode().
observation: Current environment observation.
"""
episode = self._get(episode_id)
self._finished.add(episode.episode_id)
episode.done(observation)
def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
"""Get a started episode by its ID or raise an error."""
if episode_id in self._finished:
raise ValueError("Episode {} has already completed.".format(episode_id))
if episode_id not in self._episodes:
raise ValueError("Episode {} not found.".format(episode_id))
return self._episodes[episode_id]
def to_base_env(
self,
make_env: Optional[Callable[[int], EnvType]] = None,
num_envs: int = 1,
remote_envs: bool = False,
remote_env_batch_wait_ms: int = 0,
restart_failed_sub_environments: bool = False,
) -> "BaseEnv":
"""Converts an RLlib MultiAgentEnv into a BaseEnv object.
The resulting BaseEnv is always vectorized (contains n
sub-environments) to support batched forward passes, where n may
also be 1. BaseEnv also supports async execution via the `poll` and
`send_actions` methods and thus supports external simulators.
Args:
make_env: A callable taking an int as input (which indicates
the number of individual sub-environments within the final
vectorized BaseEnv) and returning one individual
sub-environment.
num_envs: The number of sub-environments to create in the
resulting (vectorized) BaseEnv. The already existing `env`
will be one of the `num_envs`.
remote_envs: Whether each sub-env should be a @ray.remote
actor. You can set this behavior in your config via the
`remote_worker_envs=True` option.
remote_env_batch_wait_ms: The wait time (in ms) to poll remote
sub-environments for, if applicable. Only used if
`remote_envs` is True.
Returns:
The resulting BaseEnv object.
"""
if num_envs != 1:
raise ValueError(
"External(MultiAgent)Env does not currently support "
"num_envs > 1. One way of solving this would be to "
"treat your Env as a MultiAgentEnv hosting only one "
"type of agent but with several copies."
)
env = ExternalEnvWrapper(self)
return env
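
# Illustrative sketch only (not part of the RLlib API): a hypothetical
# ExternalEnv subclass whose run loop pulls transitions from a user-supplied
# `transition_source` callable with the assumed signature
# `transition_source(action) -> (obs, reward, terminated)` and queries the
# served policy for on-policy actions.
class _ExampleExternalEnvSketch(ExternalEnv):
    def __init__(self, action_space, observation_space, transition_source):
        super().__init__(action_space, observation_space)
        self._transition_source = transition_source

    def run(self):
        while True:
            episode_id = self.start_episode()
            # Passing None as the action is this sketch's convention for
            # "reset and return the first observation".
            obs, _, terminated = self._transition_source(None)
            while not terminated:
                action = self.get_action(episode_id, obs)
                obs, reward, terminated = self._transition_source(action)
                self.log_returns(episode_id, reward)
            self.end_episode(episode_id, obs)


# Hypothetical usage (kept as comments to avoid side effects at import time):
#
#     from ray.tune import register_env
#     register_env(
#         "sketch_env",
#         lambda config: _ExampleExternalEnvSketch(
#             action_space, observation_space, transition_source
#         ),
#     )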
@OldAPIStack
class _ExternalEnvEpisode:
"""Tracked state for each active episode."""
def __init__(
self,
episode_id: str,
results_avail_condition: threading.Condition,
training_enabled: bool,
multiagent: bool = False,
):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.training_enabled = training_enabled
self.multiagent = multiagent
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
if multiagent:
self.new_observation_dict = None
self.new_action_dict = None
self.cur_reward_dict = {}
self.cur_terminated_dict = {"__all__": False}
self.cur_truncated_dict = {"__all__": False}
self.cur_info_dict = {}
else:
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_terminated = False
self.cur_truncated = False
self.cur_info = {}
def get_data(self):
if self.data_queue.empty():
return None
return self.data_queue.get_nowait()
def log_action(self, observation, action):
if self.multiagent:
self.new_observation_dict = observation
self.new_action_dict = action
else:
self.new_observation = observation
self.new_action = action
self._send()
self.action_queue.get(True, timeout=60.0)
def wait_for_action(self, observation):
if self.multiagent:
self.new_observation_dict = observation
else:
self.new_observation = observation
self._send()
return self.action_queue.get(True, timeout=300.0)
def done(self, observation):
if self.multiagent:
self.new_observation_dict = observation
self.cur_terminated_dict = {"__all__": True}
# TODO(sven): External env API does not currently support truncated,
# but we should deprecate external Env anyways in favor of a client-only
# approach.
self.cur_truncated_dict = {"__all__": False}
else:
self.new_observation = observation
self.cur_terminated = True
self.cur_truncated = False
self._send()
def _send(self):
if self.multiagent:
if not self.training_enabled:
for agent_id in self.cur_info_dict:
self.cur_info_dict[agent_id]["training_enabled"] = False
item = {
"obs": self.new_observation_dict,
"reward": self.cur_reward_dict,
"terminated": self.cur_terminated_dict,
"truncated": self.cur_truncated_dict,
"info": self.cur_info_dict,
}
if self.new_action_dict is not None:
item["off_policy_action"] = self.new_action_dict
self.new_observation_dict = None
self.new_action_dict = None
self.cur_reward_dict = {}
else:
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"terminated": self.cur_terminated,
"truncated": self.cur_truncated,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
if not self.training_enabled:
item["info"]["training_enabled"] = False
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
@OldAPIStack
class ExternalEnvWrapper(BaseEnv):
"""Internal adapter of ExternalEnv to BaseEnv."""
def __init__(
        self, external_env: "ExternalEnv", preprocessor: Optional["Preprocessor"] = None
):
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
self.external_env = external_env
self.prep = preprocessor
self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
self._action_space = external_env.action_space
if preprocessor:
self._observation_space = preprocessor.observation_space
else:
self._observation_space = external_env.observation_space
external_env.start()
@override(BaseEnv)
def poll(
self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
with self.external_env._results_avail_condition:
results = self._poll()
while len(results[0]) == 0:
self.external_env._results_avail_condition.wait()
results = self._poll()
if not self.external_env.is_alive():
raise Exception("Serving thread has stopped.")
return results
@override(BaseEnv)
def send_actions(self, action_dict: MultiEnvDict) -> None:
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
if self.multiagent:
for env_id, actions in action_dict.items():
self.external_env._episodes[env_id].action_queue.put(actions)
else:
for env_id, action in action_dict.items():
self.external_env._episodes[env_id].action_queue.put(
action[_DUMMY_AGENT_ID]
)
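    # Sketch of how a sampler typically drives this wrapper (illustrative only;
    # the real loop lives in RLlib's sampler code, and `compute_action` below
    # is a hypothetical placeholder for a policy forward pass):
    #
    #     obs, rewards, terminateds, truncateds, infos, off_policy = env.poll()
    #     actions = {
    #         env_id: {aid: compute_action(o) for aid, o in agent_obs.items()}
    #         for env_id, agent_obs in obs.items()
    #     }
    #     env.send_actions(actions)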
def _poll(
self,
) -> Tuple[
MultiEnvDict,
MultiEnvDict,
MultiEnvDict,
MultiEnvDict,
MultiEnvDict,
MultiEnvDict,
]:
from ray.rllib.env.base_env import with_dummy_agent_id
all_obs, all_rewards, all_terminateds, all_truncateds, all_infos = (
{},
{},
{},
{},
{},
)
off_policy_actions = {}
for eid, episode in self.external_env._episodes.copy().items():
data = episode.get_data()
cur_terminated = (
episode.cur_terminated_dict["__all__"]
if self.multiagent
else episode.cur_terminated
)
cur_truncated = (
episode.cur_truncated_dict["__all__"]
if self.multiagent
else episode.cur_truncated
)
if cur_terminated or cur_truncated:
del self.external_env._episodes[eid]
if data:
if self.prep:
all_obs[eid] = self.prep.transform(data["obs"])
else:
all_obs[eid] = data["obs"]
all_rewards[eid] = data["reward"]
all_terminateds[eid] = data["terminated"]
all_truncateds[eid] = data["truncated"]
all_infos[eid] = data["info"]
if "off_policy_action" in data:
off_policy_actions[eid] = data["off_policy_action"]
if self.multiagent:
# Ensure a consistent set of keys
# rely on all_obs having all possible keys for now.
for eid, eid_dict in all_obs.items():
for agent_id in eid_dict.keys():
def fix(d, zero_val):
if agent_id not in d[eid]:
d[eid][agent_id] = zero_val
fix(all_rewards, 0.0)
fix(all_terminateds, False)
fix(all_truncateds, False)
fix(all_infos, {})
return (
all_obs,
all_rewards,
all_terminateds,
all_truncateds,
all_infos,
off_policy_actions,
)
else:
return (
with_dummy_agent_id(all_obs),
with_dummy_agent_id(all_rewards),
with_dummy_agent_id(all_terminateds, "__all__"),
with_dummy_agent_id(all_truncateds, "__all__"),
with_dummy_agent_id(all_infos),
with_dummy_agent_id(off_policy_actions),
)
@property
@override(BaseEnv)
def observation_space(self) -> gym.spaces.Dict:
return self._observation_space
@property
@override(BaseEnv)
def action_space(self) -> gym.Space:
return self._action_space