# env_runner_group.py (forked from ray-project/ray)
import functools
import gymnasium as gym
import logging
import importlib.util
import os
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
import ray
from ray.actor import ActorHandle
from ray.exceptions import RayActorError
from ray.rllib.core import (
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_LEARNER,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
COMPONENT_RL_MODULE,
)
from ray.rllib.core.learner import LearnerGroup
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.offline import get_dataset_and_shards
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
from ray.rllib.utils.deprecation import (
Deprecated,
deprecation_warning,
DEPRECATED_VALUE,
)
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
from ray.rllib.utils.typing import (
AgentID,
EnvCreator,
EnvType,
EpisodeID,
PartialAlgorithmConfigDict,
PolicyID,
SampleBatchType,
TensorType,
)
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
# Generic type var for foreach_* methods.
T = TypeVar("T")
@DeveloperAPI
class EnvRunnerGroup:
"""Set of EnvRunners with n @ray.remote workers and zero or one local worker.
Where: n >= 0.
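    Example (an illustrative sketch only; the concrete `PPOConfig` settings and
    the CartPole env used here are assumptions, not requirements):

    .. code-block:: python

        import gymnasium as gym
        from ray.rllib.algorithms.ppo import PPOConfig
        from ray.rllib.env.env_runner_group import EnvRunnerGroup

        config = (
            PPOConfig()
            .environment("CartPole-v1")
            .env_runners(num_env_runners=2)
        )
        # One local EnvRunner plus two remote EnvRunner actors.
        env_runner_group = EnvRunnerGroup(
            env_creator=lambda ctx: gym.make("CartPole-v1"),
            config=config,
            num_env_runners=2,
        )
        # Sample on all remote EnvRunners (not on the local one).
        episodes = env_runner_group.foreach_worker(
            lambda env_runner: env_runner.sample(),
            local_env_runner=False,
        )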
"""
def __init__(
self,
*,
env_creator: Optional[EnvCreator] = None,
validate_env: Optional[Callable[[EnvType], None]] = None,
default_policy_class: Optional[Type[Policy]] = None,
config: Optional["AlgorithmConfig"] = None,
num_env_runners: int = 0,
local_env_runner: bool = True,
logdir: Optional[str] = None,
_setup: bool = True,
tune_trial_id: Optional[str] = None,
# Deprecated args.
num_workers=DEPRECATED_VALUE,
local_worker=DEPRECATED_VALUE,
):
"""Initializes a EnvRunnerGroup instance.
Args:
env_creator: Function that returns env given env config.
validate_env: Optional callable to validate the generated
environment (only on worker=0). This callable should raise
an exception if the environment is invalid.
default_policy_class: An optional default Policy class to use inside
the (multi-agent) `policies` dict. In case the PolicySpecs in there
have no class defined, use this `default_policy_class`.
If None, PolicySpecs will be using the Algorithm's default Policy
class.
config: Optional AlgorithmConfig (or config dict).
num_env_runners: Number of remote EnvRunners to create.
local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
in the returned set as well (default: True). If `num_env_runners`
is 0, always create a local EnvRunner.
logdir: Optional logging directory for workers.
_setup: Whether to actually set up workers. This is only for testing.
"""
if num_workers != DEPRECATED_VALUE or local_worker != DEPRECATED_VALUE:
deprecation_warning(
old="WorkerSet(num_workers=..., local_worker=...)",
new="EnvRunnerGroup(num_env_runners=..., local_env_runner=...)",
error=True,
)
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
# Make sure `config` is an AlgorithmConfig object.
if not config:
config = AlgorithmConfig()
elif isinstance(config, dict):
config = AlgorithmConfig.from_dict(config)
self._env_creator = env_creator
self._policy_class = default_policy_class
self._remote_config = config
self._remote_args = {
"num_cpus": self._remote_config.num_cpus_per_env_runner,
"num_gpus": self._remote_config.num_gpus_per_env_runner,
"resources": self._remote_config.custom_resources_per_env_runner,
"max_restarts": (
config.max_num_env_runner_restarts
if config.restart_failed_env_runners
else 0
),
}
self._tune_trial_id = tune_trial_id
# Set the EnvRunner subclass to be used as "workers". Default: RolloutWorker.
self.env_runner_cls = config.env_runner_cls
if self.env_runner_cls is None:
if config.enable_env_runner_and_connector_v2:
# If experiences should be recorded, use the `
# OfflineSingleAgentEnvRunner`.
if config.output:
# No multi-agent support.
if config.is_multi_agent():
raise ValueError("Multi-agent recording is not supported, yet.")
# Otherwise, load the single-agent env runner for
# recording.
else:
from ray.rllib.offline.offline_env_runner import (
OfflineSingleAgentEnvRunner,
)
self.env_runner_cls = OfflineSingleAgentEnvRunner
else:
if config.is_multi_agent():
from ray.rllib.env.multi_agent_env_runner import (
MultiAgentEnvRunner,
)
self.env_runner_cls = MultiAgentEnvRunner
else:
from ray.rllib.env.single_agent_env_runner import (
SingleAgentEnvRunner,
)
self.env_runner_cls = SingleAgentEnvRunner
else:
self.env_runner_cls = RolloutWorker
self._cls = ray.remote(**self._remote_args)(self.env_runner_cls).remote
self._logdir = logdir
self._ignore_ray_errors_on_env_runners = (
config.ignore_env_runner_failures or config.restart_failed_env_runners
)
# Create remote worker manager.
# ID=0 is used by the local worker.
# Starting remote workers from ID=1 to avoid conflicts.
self._worker_manager = FaultTolerantActorManager(
max_remote_requests_in_flight_per_actor=(
config.max_requests_in_flight_per_env_runner
),
init_id=1,
)
if _setup:
try:
self._setup(
validate_env=validate_env,
config=config,
num_env_runners=num_env_runners,
local_env_runner=local_env_runner,
)
            # EnvRunnerGroup creation may fail if some (remote) workers cannot
            # be initialized properly (due to errors in the EnvRunners' constructors).
except RayActorError as e:
                # In case of an actor (remote worker) init failure, the remote worker
                # may still exist and be accessible. However, calling e.g. its
                # `sample.remote()` would result in strange "property not found"
                # errors.
if e.actor_init_failed:
                    # Raise the original error here that the EnvRunner raised
                    # during its construction process. This is to enforce transparency
# for the user (better to understand the real reason behind the
# failure).
# - e.args[0]: The RayTaskError (inside the caught RayActorError).
# - e.args[0].args[2]: The original Exception (e.g. a ValueError due
# to a config mismatch) thrown inside the actor.
raise e.args[0].args[2]
# In any other case, raise the RayActorError as-is.
else:
raise e
def _setup(
self,
*,
validate_env: Optional[Callable[[EnvType], None]] = None,
config: Optional["AlgorithmConfig"] = None,
num_env_runners: int = 0,
local_env_runner: bool = True,
):
"""Sets up an EnvRunnerGroup instance.
Args:
validate_env: Optional callable to validate the generated
environment (only on worker=0).
config: Optional dict that extends the common config of
the Algorithm class.
num_env_runners: Number of remote EnvRunner workers to create.
local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
in the returned set as well (default: True). If `num_env_runners`
is 0, always create a local EnvRunner.
"""
# Force a local worker if num_env_runners == 0 (no remote workers).
# Otherwise, this EnvRunnerGroup would be empty.
self._local_env_runner = None
if num_env_runners == 0:
local_env_runner = True
# Create a local (learner) version of the config for the local worker.
# The only difference is the tf_session_args, which - for the local worker -
# will be `config.tf_session_args` updated/overridden with
# `config.local_tf_session_args`.
local_tf_session_args = config.tf_session_args.copy()
local_tf_session_args.update(config.local_tf_session_args)
self._local_config = config.copy(copy_frozen=False).framework(
tf_session_args=local_tf_session_args
)
if config.input_ == "dataset":
# Create the set of dataset readers to be shared by all the
# rollout workers.
self._ds, self._ds_shards = get_dataset_and_shards(config, num_env_runners)
else:
self._ds = None
self._ds_shards = None
# Create a number of @ray.remote workers.
self.add_workers(
num_env_runners,
validate=config.validate_env_runners_after_construction,
)
# If num_env_runners > 0 and we don't have an env on the local worker,
# get the observation- and action spaces for each policy from
# the first remote worker (which does have an env).
if (
local_env_runner
and self._worker_manager.num_actors() > 0
and not config.enable_env_runner_and_connector_v2
and not config.create_env_on_local_worker
and (not config.observation_space or not config.action_space)
):
spaces = self.get_spaces()
else:
spaces = None
# Create a local worker, if needed.
if local_env_runner:
self._local_env_runner = self._make_worker(
cls=self.env_runner_cls,
env_creator=self._env_creator,
validate_env=validate_env,
worker_index=0,
num_workers=num_env_runners,
config=self._local_config,
spaces=spaces,
)
def get_spaces(self):
"""Infer observation and action spaces from one (local or remote) EnvRunner.
Returns:
A dict mapping from ModuleID to a 2-tuple containing obs- and action-space.
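        For example (illustrative only; actual spaces depend on the environment
        and the RLModules in use), the returned dict may look like:

        .. code-block:: python

            {
                "default_policy": (
                    gym.spaces.Box(-1.0, 1.0, (4,)),
                    gym.spaces.Discrete(2),
                ),
            }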
"""
# Get ID of the first remote worker.
remote_worker_ids = (
[self._worker_manager.actor_ids()[0]]
if self._worker_manager.actor_ids()
else []
)
spaces = self.foreach_worker(
lambda env_runner: env_runner.get_spaces(),
remote_worker_ids=remote_worker_ids,
local_env_runner=not remote_worker_ids,
)[0]
logger.info(
"Inferred observation/action spaces from remote "
f"worker (local worker has no env): {spaces}"
)
return spaces
@property
def local_env_runner(self) -> EnvRunner:
"""Returns the local EnvRunner."""
return self._local_env_runner
@DeveloperAPI
def healthy_worker_ids(self) -> List[int]:
"""Returns the list of remote worker IDs."""
return self._worker_manager.healthy_actor_ids()
@DeveloperAPI
def num_remote_env_runners(self) -> int:
"""Returns the number of remote EnvRunners."""
return self._worker_manager.num_actors()
@DeveloperAPI
def num_remote_workers(self) -> int:
"""Returns the number of remote EnvRunners."""
return self._worker_manager.num_actors()
@DeveloperAPI
def num_healthy_remote_workers(self) -> int:
"""Returns the number of healthy remote workers."""
return self._worker_manager.num_healthy_actors()
@DeveloperAPI
def num_healthy_workers(self) -> int:
"""Returns the number of all healthy workers, including the local worker."""
return int(bool(self._local_env_runner)) + self.num_healthy_remote_workers()
@DeveloperAPI
def num_in_flight_async_reqs(self) -> int:
"""Returns the number of in-flight async requests."""
return self._worker_manager.num_outstanding_async_reqs()
@DeveloperAPI
def num_remote_worker_restarts(self) -> int:
"""Total number of times managed remote workers have been restarted."""
return self._worker_manager.total_num_restarts()
@DeveloperAPI
def sync_env_runner_states(
self,
*,
config: "AlgorithmConfig",
from_worker: Optional[EnvRunner] = None,
env_steps_sampled: Optional[int] = None,
connector_states: Optional[List[Dict[str, Any]]] = None,
rl_module_state: Optional[Dict[str, Any]] = None,
env_runner_indices_to_update: Optional[List[int]] = None,
) -> None:
"""Synchronizes the connectors of this EnvRunnerGroup's EnvRunners.
The exact procedure works as follows:
- If `from_worker` is None, set `from_worker=self.local_env_runner`.
- If `config.use_worker_filter_stats` is True, gather all remote EnvRunners'
ConnectorV2 states. Otherwise, only use the ConnectorV2 states of `from_worker`.
- Merge all gathered states into one resulting state.
- Broadcast the resulting state back to all remote EnvRunners AND the local
EnvRunner.
Args:
config: The AlgorithmConfig object to use to determine, in which
direction(s) we need to synch and what the timeouts are.
from_worker: The EnvRunner from which to synch. If None, will use the local
worker of this EnvRunnerGroup.
env_steps_sampled: The total number of env steps taken thus far by all
workers combined. Used to broadcast this number to all remote workers
if `update_worker_filter_stats` is True in `config`.
env_runner_indices_to_update: The indices of those EnvRunners to update
with the merged state. Use None (default) to update all remote
EnvRunners.
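        Example (an illustrative sketch; `env_runner_group`, `config`, and
        `total_env_steps_sampled` are assumed to come from the surrounding
        training loop):

        .. code-block:: python

            # After a sampling round: merge all EnvRunners' ConnectorV2 states and
            # broadcast the merged state (plus the global env-step count) back to
            # all remote EnvRunners and the local EnvRunner.
            env_runner_group.sync_env_runner_states(
                config=config,
                env_steps_sampled=total_env_steps_sampled,
            )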
"""
from_worker = from_worker or self.local_env_runner
# Early out if the number of (healthy) remote workers is 0. In this case, the
# local worker is the only operating worker and thus of course always holds
# the reference connector state.
if self.num_healthy_remote_workers() == 0:
self.local_env_runner.set_state(
{
**(
{NUM_ENV_STEPS_SAMPLED_LIFETIME: env_steps_sampled}
if env_steps_sampled is not None
else {}
),
**(rl_module_state if rl_module_state is not None else {}),
}
)
return
# Also early out, if we a) don't use the remote states AND b) don't want to
# broadcast back from `from_worker` to all remote workers.
# TODO (sven): Rename these to proper "..env_runner_states.." containing names.
if not config.update_worker_filter_stats and not config.use_worker_filter_stats:
return
# Use states from all remote EnvRunners.
if config.use_worker_filter_stats:
if connector_states == []:
env_runner_states = {}
else:
if connector_states is None:
connector_states = self.foreach_worker(
lambda w: w.get_state(
components=[
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
]
),
local_env_runner=False,
timeout_seconds=(
config.sync_filters_on_rollout_workers_timeout_s
),
)
env_to_module_states = [
s[COMPONENT_ENV_TO_MODULE_CONNECTOR] for s in connector_states
]
module_to_env_states = [
s[COMPONENT_MODULE_TO_ENV_CONNECTOR] for s in connector_states
]
env_runner_states = {
COMPONENT_ENV_TO_MODULE_CONNECTOR: (
self.local_env_runner._env_to_module.merge_states(
env_to_module_states
)
),
COMPONENT_MODULE_TO_ENV_CONNECTOR: (
self.local_env_runner._module_to_env.merge_states(
module_to_env_states
)
),
}
# Ignore states from remote EnvRunners (use the current `from_worker` states
# only).
else:
env_runner_states = from_worker.get_state(
components=[
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
]
)
# Update the global number of environment steps, if necessary.
# Make sure to divide by the number of env runners (such that each EnvRunner
# knows (roughly) its own(!) lifetime count and can infer the global lifetime
# count from it).
if env_steps_sampled is not None:
env_runner_states[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_steps_sampled // (
config.num_env_runners or 1
)
# Update the rl_module component of the EnvRunner states, if necessary:
if rl_module_state:
env_runner_states.update(rl_module_state)
# If we do NOT want remote EnvRunners to get their Connector states updated,
# only update the local worker here (with all state components) and then remove
# the connector components.
if not config.update_worker_filter_stats:
self.local_env_runner.set_state(env_runner_states)
del env_runner_states[COMPONENT_ENV_TO_MODULE_CONNECTOR]
del env_runner_states[COMPONENT_MODULE_TO_ENV_CONNECTOR]
# If there are components in the state left -> Update remote workers with these
# state components (and maybe the local worker, if it hasn't been updated yet).
if env_runner_states:
# Put the state dictionary into Ray's object store to avoid having to make n
# pickled copies of the state dict.
ref_env_runner_states = ray.put(env_runner_states)
def _update(_env_runner: EnvRunner) -> None:
_env_runner.set_state(ray.get(ref_env_runner_states))
# Broadcast updated states back to all workers.
self.foreach_worker(
_update,
remote_worker_ids=env_runner_indices_to_update,
local_env_runner=config.update_worker_filter_stats,
timeout_seconds=0.0, # This is a state update -> Fire-and-forget.
)
@DeveloperAPI
def sync_weights(
self,
policies: Optional[List[PolicyID]] = None,
from_worker_or_learner_group: Optional[Union[EnvRunner, "LearnerGroup"]] = None,
to_worker_indices: Optional[List[int]] = None,
global_vars: Optional[Dict[str, TensorType]] = None,
timeout_seconds: Optional[float] = 0.0,
inference_only: Optional[bool] = False,
) -> None:
"""Syncs model weights from the given weight source to all remote workers.
Weight source can be either a (local) rollout worker or a learner_group. It
should just implement a `get_weights` method.
Args:
policies: Optional list of PolicyIDs to sync weights for.
If None (default), sync weights to/from all policies.
from_worker_or_learner_group: Optional (local) EnvRunner instance or
LearnerGroup instance to sync from. If None (default),
sync from this EnvRunnerGroup's local worker.
to_worker_indices: Optional list of worker indices to sync the
weights to. If None (default), sync to all remote workers.
global_vars: An optional global vars dict to set this
worker to. If None, do not update the global_vars.
timeout_seconds: Timeout in seconds to wait for the sync weights
calls to complete. Default is 0.0 (fire-and-forget, do not wait
for any sync calls to finish). Setting this to 0.0 might significantly
improve algorithm performance, depending on the algo's `training_step`
logic.
inference_only: Sync weights with workers that keep inference-only
modules. This is needed for algorithms in the new stack that
use inference-only modules. In this case only a part of the
parameters are synced to the workers. Default is False.
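        Example (an illustrative sketch; `env_runner_group` and `learner_group`
        are assumed to come from the surrounding Algorithm setup):

        .. code-block:: python

            # Push the current (inference-only) weights from the LearnerGroup to
            # all remote EnvRunners (fire-and-forget) and to the local EnvRunner.
            env_runner_group.sync_weights(
                from_worker_or_learner_group=learner_group,
                inference_only=True,
            )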
"""
if self.local_env_runner is None and from_worker_or_learner_group is None:
raise TypeError(
"No `local_env_runner` in EnvRunnerGroup! Must provide "
"`from_worker_or_learner_group` arg in `sync_weights()`!"
)
        # Only sync if we have remote workers or `from_worker_or_learner_group` is
        # provided.
rl_module_state = None
if self.num_remote_workers() or from_worker_or_learner_group is not None:
weights_src = from_worker_or_learner_group or self.local_env_runner
if weights_src is None:
raise ValueError(
"`from_worker_or_trainer` is None. In this case, EnvRunnerGroup "
"should have local_env_runner. But local_env_runner is also None."
)
modules = (
[COMPONENT_RL_MODULE + "/" + p for p in policies]
if policies is not None
else [COMPONENT_RL_MODULE]
)
# LearnerGroup has-a Learner has-a RLModule.
if isinstance(weights_src, LearnerGroup):
rl_module_state = weights_src.get_state(
components=[COMPONENT_LEARNER + "/" + m for m in modules],
inference_only=inference_only,
)[COMPONENT_LEARNER]
# EnvRunner has-a RLModule.
elif self._remote_config.enable_env_runner_and_connector_v2:
rl_module_state = weights_src.get_state(
components=modules,
inference_only=inference_only,
)
else:
rl_module_state = weights_src.get_weights(
policies=policies,
inference_only=inference_only,
)
if self._remote_config.enable_env_runner_and_connector_v2:
# Make sure `rl_module_state` only contains the weights and the
# weight seq no, nothing else.
rl_module_state = {
k: v
for k, v in rl_module_state.items()
if k in [COMPONENT_RL_MODULE, WEIGHTS_SEQ_NO]
}
# Move weights to the object store to avoid having to make n pickled
# copies of the weights dict for each worker.
rl_module_state_ref = ray.put(rl_module_state)
def _set_weights(env_runner):
env_runner.set_state(ray.get(rl_module_state_ref))
else:
rl_module_state_ref = ray.put(rl_module_state)
def _set_weights(env_runner):
env_runner.set_weights(ray.get(rl_module_state_ref), global_vars)
# Sync to specified remote workers in this EnvRunnerGroup.
self.foreach_worker(
func=_set_weights,
local_env_runner=False, # Do not sync back to local worker.
remote_worker_ids=to_worker_indices,
timeout_seconds=timeout_seconds,
)
# If `from_worker_or_learner_group` is provided, also sync to this
# EnvRunnerGroup's local worker.
if self.local_env_runner is not None:
if from_worker_or_learner_group is not None:
if self._remote_config.enable_env_runner_and_connector_v2:
self.local_env_runner.set_state(rl_module_state)
else:
self.local_env_runner.set_weights(rl_module_state)
# If `global_vars` is provided and local worker exists -> Update its
# global_vars.
if global_vars is not None:
self.local_env_runner.set_global_vars(global_vars)
@DeveloperAPI
def add_policy(
self,
policy_id: PolicyID,
policy_cls: Optional[Type[Policy]] = None,
policy: Optional[Policy] = None,
*,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None,
policy_state: Optional[PolicyState] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
policies_to_train: Optional[
Union[
Collection[PolicyID],
Callable[[PolicyID, Optional[SampleBatchType]], bool],
]
] = None,
module_spec: Optional[RLModuleSpec] = None,
# Deprecated.
workers: Optional[List[Union[EnvRunner, ActorHandle]]] = DEPRECATED_VALUE,
) -> None:
"""Adds a policy to this EnvRunnerGroup's workers or a specific list of workers.
Args:
policy_id: ID of the policy to add.
policy_cls: The Policy class to use for constructing the new Policy.
Note: Only one of `policy_cls` or `policy` must be provided.
policy: The Policy instance to add to this EnvRunnerGroup. If not None, the
given Policy object will be directly inserted into the
local worker and clones of that Policy will be created on all remote
workers.
Note: Only one of `policy_cls` or `policy` must be provided.
observation_space: The observation space of the policy to add.
If None, try to infer this space from the environment.
action_space: The action space of the policy to add.
If None, try to infer this space from the environment.
config: The config object or overrides for the policy to add.
policy_state: Optional state dict to apply to the new
policy instance, right after its construction.
policy_mapping_fn: An optional (updated) policy mapping function
to use from here on. Note that already ongoing episodes will
not change their mapping but will use the old mapping till
the end of the episode.
policies_to_train: An optional list of policy IDs to be trained
or a callable taking PolicyID and SampleBatchType and
returning a bool (trainable or not?).
                If None, will keep the existing setup in place. Policies
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.
module_spec: In the new RLModule API we need to pass in the module_spec for
the new module that is supposed to be added. Knowing the policy spec is
not sufficient.
workers: A list of EnvRunner/ActorHandles (remote
EnvRunners) to add this policy to. If defined, will only
add the given policy to these workers.
Raises:
KeyError: If the given `policy_id` already exists in this EnvRunnerGroup.
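        Example (an illustrative sketch for the old API stack; `MyPolicy` is a
        stand-in for any custom `Policy` subclass, and the spaces used are
        assumptions):

        .. code-block:: python

            import gymnasium as gym

            env_runner_group.add_policy(
                policy_id="new_policy",
                policy_cls=MyPolicy,  # hypothetical Policy subclass
                observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
                action_space=gym.spaces.Discrete(2),
                # From here on, map every agent to the new policy.
                policy_mapping_fn=lambda agent_id, episode, **kw: "new_policy",
            )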
"""
if self.local_env_runner and policy_id in self.local_env_runner.policy_map:
raise KeyError(
f"Policy ID '{policy_id}' already exists in policy map! "
"Make sure you use a Policy ID that has not been taken yet."
" Policy IDs that are already in your policy map: "
f"{list(self.local_env_runner.policy_map.keys())}"
)
if workers is not DEPRECATED_VALUE:
deprecation_warning(
old="EnvRunnerGroup.add_policy(.., workers=..)",
help=(
"The `workers` argument to `EnvRunnerGroup.add_policy()` is "
"deprecated! Please do not use it anymore."
),
error=True,
)
if (policy_cls is None) == (policy is None):
raise ValueError(
"Only one of `policy_cls` or `policy` must be provided to "
"staticmethod: `EnvRunnerGroup.add_policy()`!"
)
validate_module_id(policy_id, error=False)
# Policy instance not provided: Use the information given here.
if policy_cls is not None:
new_policy_instance_kwargs = dict(
policy_id=policy_id,
policy_cls=policy_cls,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_state=policy_state,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=list(policies_to_train)
if policies_to_train
else None,
module_spec=module_spec,
)
# Policy instance provided: Create clones of this very policy on the different
# workers (copy all its properties here for the calls to add_policy on the
# remote workers).
else:
new_policy_instance_kwargs = dict(
policy_id=policy_id,
policy_cls=type(policy),
observation_space=policy.observation_space,
action_space=policy.action_space,
config=policy.config,
policy_state=policy.get_state(),
policy_mapping_fn=policy_mapping_fn,
policies_to_train=list(policies_to_train)
if policies_to_train
else None,
module_spec=module_spec,
)
def _create_new_policy_fn(worker):
            # `foreach_worker` function: Adds the policy to the worker (and
# maybe changes its policy_mapping_fn - if provided here).
worker.add_policy(**new_policy_instance_kwargs)
if self.local_env_runner is not None:
# Add policy directly by (already instantiated) object.
if policy is not None:
self.local_env_runner.add_policy(
policy_id=policy_id,
policy=policy,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
module_spec=module_spec,
)
# Add policy by constructor kwargs.
else:
self.local_env_runner.add_policy(**new_policy_instance_kwargs)
# Add the policy to all remote workers.
self.foreach_worker(_create_new_policy_fn, local_env_runner=False)
@DeveloperAPI
def add_workers(self, num_workers: int, validate: bool = False) -> None:
"""Creates and adds a number of remote workers to this worker set.
Can be called several times on the same EnvRunnerGroup to add more
EnvRunners to the set.
Args:
num_workers: The number of remote Workers to add to this
EnvRunnerGroup.
validate: Whether to validate remote workers after their construction
process.
Raises:
RayError: If any of the constructed remote workers is not up and running
properly.
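        Example (an illustrative sketch; assumes the cluster has resources for
        two additional EnvRunner actors):

        .. code-block:: python

            # Scale the group up by two more remote EnvRunners and fail fast if
            # any of them does not come up healthy.
            env_runner_group.add_workers(2, validate=True)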
"""
old_num_workers = self._worker_manager.num_actors()
new_workers = [
self._make_worker(
cls=self._cls,
env_creator=self._env_creator,
validate_env=None,
worker_index=old_num_workers + i + 1,
num_workers=old_num_workers + num_workers,
config=self._remote_config,
)
for i in range(num_workers)
]
self._worker_manager.add_actors(new_workers)
# Validate here, whether all remote workers have been constructed properly
# and are "up and running". Establish initial states.
if validate:
for result in self._worker_manager.foreach_actor(
lambda w: w.assert_healthy()
):
                # Simply raise the error, which will get handled by the try-except
# clause around the _setup().
if not result.ok:
e = result.get()
if self._ignore_ray_errors_on_env_runners:
logger.error(f"Validation of EnvRunner failed! Error={str(e)}")
else:
raise e
@DeveloperAPI
def reset(self, new_remote_workers: List[ActorHandle]) -> None:
"""Hard overrides the remote EnvRunners in this set with the provided ones.
Args:
new_remote_workers: A list of new EnvRunners (as `ActorHandles`) to use as
new remote workers.
"""
self._worker_manager.clear()
self._worker_manager.add_actors(new_remote_workers)
@DeveloperAPI
def stop(self) -> None:
"""Calls `stop` on all rollout workers (including the local one)."""
try:
            # Make sure we stop all workers, including the ones that were just
            # restarted / recovered or that are tagged unhealthy (at least, we should
            # try).
self.foreach_worker(
lambda w: w.stop(), healthy_only=False, local_env_runner=True
)
except Exception:
logger.exception("Failed to stop workers!")
finally:
self._worker_manager.clear()
@DeveloperAPI
def is_policy_to_train(
self, policy_id: PolicyID, batch: Optional[SampleBatchType] = None
) -> bool:
"""Whether given PolicyID (optionally inside some batch) is trainable."""
if self.local_env_runner:
if self.local_env_runner.is_policy_to_train is None:
return True
return self.local_env_runner.is_policy_to_train(policy_id, batch)
else:
raise NotImplementedError
@DeveloperAPI
def foreach_worker(
self,
func: Callable[[EnvRunner], T],
*,
local_env_runner: bool = True,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
timeout_seconds: Optional[float] = None,
return_obj_refs: bool = False,
mark_healthy: bool = False,
# Deprecated args.
local_worker=DEPRECATED_VALUE,
) -> List[T]:
"""Calls the given function with each EnvRunner as its argument.
Args:
func: The function to call for each worker (as only arg).
local_env_runner: Whether to apply `func` to local EnvRunner, too.
Default is True.
healthy_only: Apply `func` on known-to-be healthy workers only.
remote_worker_ids: Apply `func` on a selected set of remote workers. Use
None (default) for all remote EnvRunners.
timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
synchronous execution).
return_obj_refs: whether to return ObjectRef instead of actual results.
Note, for fault tolerance reasons, these returned ObjectRefs should
never be resolved with ray.get() outside of this EnvRunnerGroup.
mark_healthy: Whether to mark all those workers healthy again that are
currently marked unhealthy AND that returned results from the remote
call (within the given `timeout_seconds`).
Note that workers are NOT set unhealthy, if they simply time out
(only if they return a RayActorError).
Also note that this setting is ignored if `healthy_only=True` (b/c
`mark_healthy` only affects workers that are currently tagged as
unhealthy).
Returns:
The list of return values of all calls to `func([worker])`.
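        Example (an illustrative sketch; assumes `env_runner_group` holds healthy
        remote EnvRunners):

        .. code-block:: python

            # Collect one round of samples from the remote EnvRunners only,
            # waiting at most 60s for results.
            sample_results = env_runner_group.foreach_worker(
                lambda env_runner: env_runner.sample(),
                local_env_runner=False,
                timeout_seconds=60.0,
            )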
"""
if local_worker != DEPRECATED_VALUE:
deprecation_warning(
old="foreach_worker(local_worker=..)",
new="foreach_worker(local_env_runner=..)",
error=True,
)
assert (
not return_obj_refs or not local_env_runner
), "Can not return ObjectRef from local worker."
local_result = []
if local_env_runner and self.local_env_runner is not None:
local_result = [func(self.local_env_runner)]
if not self._worker_manager.actor_ids():
return local_result
remote_results = self._worker_manager.foreach_actor(
func,
healthy_only=healthy_only,
remote_actor_ids=remote_worker_ids,
timeout_seconds=timeout_seconds,
return_obj_refs=return_obj_refs,
mark_healthy=mark_healthy,
)
FaultTolerantActorManager.handle_remote_call_result_errors(
remote_results, ignore_ray_errors=self._ignore_ray_errors_on_env_runners
)
# With application errors handled, return good results.
remote_results = [r.get() for r in remote_results.ignore_errors()]
return local_result + remote_results
@DeveloperAPI
def foreach_worker_with_id(
self,
func: Callable[[int, EnvRunner], T],
*,
local_env_runner: bool = True,
healthy_only: bool = True,
remote_worker_ids: List[int] = None,
timeout_seconds: Optional[float] = None,
return_obj_refs: bool = False,
mark_healthy: bool = False,
# Deprecated args.
local_worker=DEPRECATED_VALUE,
) -> List[T]:
"""Calls the given function with each EnvRunner and its ID as its arguments.
Args:
func: The function to call for each worker (as only arg).
            local_env_runner: Whether to apply `func` to the local EnvRunner, too.
Default is True.
healthy_only: Apply `func` on known-to-be healthy workers only.
remote_worker_ids: Apply `func` on a selected set of remote workers.
timeout_seconds: Time to wait for results. Default is None.
return_obj_refs: whether to return ObjectRef instead of actual results.
Note, for fault tolerance reasons, these returned ObjectRefs should
never be resolved with ray.get() outside of this EnvRunnerGroup.
mark_healthy: Whether to mark all those workers healthy again that are
currently marked unhealthy AND that returned results from the remote
call (within the given `timeout_seconds`).
Note that workers are NOT set unhealthy, if they simply time out
(only if they return a RayActorError).
Also note that this setting is ignored if `healthy_only=True` (b/c
`mark_healthy` only affects workers that are currently tagged as
unhealthy).
Returns:
The list of return values of all calls to `func([worker, id])`.
"""
if local_worker != DEPRECATED_VALUE:
deprecation_warning(
old="foreach_worker_with_id(local_worker=...)",
new="foreach_worker_with_id(local_env_runner=...)",
error=True,
)
local_result = []
if local_env_runner and self.local_env_runner is not None:
local_result = [func(0, self.local_env_runner)]
if not remote_worker_ids:
remote_worker_ids = self._worker_manager.actor_ids()
funcs = [functools.partial(func, i) for i in remote_worker_ids]
remote_results = self._worker_manager.foreach_actor(
funcs,
healthy_only=healthy_only,
remote_actor_ids=remote_worker_ids,
timeout_seconds=timeout_seconds,
return_obj_refs=return_obj_refs,
mark_healthy=mark_healthy,
)
FaultTolerantActorManager.handle_remote_call_result_errors(
remote_results,
ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
)
remote_results = [r.get() for r in remote_results.ignore_errors()]
return local_result + remote_results
@DeveloperAPI
def foreach_worker_async(
self,
func: Callable[[EnvRunner], T],
*,
healthy_only: bool = True,