forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
170 lines (138 loc) · 6.05 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import logging
from typing import Any, Tuple, TYPE_CHECKING
from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import Connector, ConnectorContext
from ray.rllib.connectors.registry import get_connector
from ray.rllib.connectors.agent.mean_std_filter import (
MeanStdObservationFilterAgentConnector,
ConcurrentMeanStdObservationFilterAgentConnector,
)
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
# These names are only needed for static type checking: at runtime this module
# refers to them solely through string annotations ("AlgorithmConfig", "Policy").
if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
    from ray.rllib.policy.policy import Policy

# Module-level logger; create_connectors_for_policy reports the connectors it
# builds through it.
logger = logging.getLogger(__name__)
def __preprocessing_enabled(config: "AlgorithmConfig"):
    """Decide whether an ObsPreprocessorConnector should be added.

    Mirrors the conditions used in RolloutWorker.__init__. Preprocessing is
    off when any of the following holds:

    - the preprocessor API is disabled (`_disable_preprocessor_api`),
    - the env is Atari and `preprocessor_pref` is "deepmind",
    - `preprocessor_pref` is None.

    Args:
        config: The AlgorithmConfig to inspect.

    Returns:
        True if observation preprocessing is enabled, False otherwise.
    """
    # The `or` chain short-circuits in the same order as the original guard
    # clauses, so `is_atari` / `preprocessor_pref` are only read when the
    # earlier conditions did not already decide the answer.
    return not (
        config._disable_preprocessor_api
        or (config.is_atari and config.preprocessor_pref == "deepmind")
        or config.preprocessor_pref is None
    )
def __clip_rewards(config: "AlgorithmConfig"):
    """Resolve the effective reward-clipping setting for `config`.

    Follows the same logic as RolloutWorker.__init__.

    Args:
        config: The AlgorithmConfig to inspect.

    Returns:
        `config.clip_rewards` when it is truthy. Otherwise `config.is_atari`,
        so Atari games always get reward clipping.
    """
    requested = config.clip_rewards
    if requested:
        # An explicit, truthy user setting wins (and `is_atari` is not read).
        return requested
    # We always clip rewards for Atari games.
    return config.is_atari
@OldAPIStack
def get_agent_connectors_from_config(
    ctx: ConnectorContext,
    config: "AlgorithmConfig",
) -> AgentConnectorPipeline:
    """Default list of agent connectors to use for a new policy.

    Connectors are appended in this order:

    1. ClipRewardAgentConnector, if reward clipping is enabled
       (`config.clip_rewards`, or always for Atari).
    2. ObsPreprocessorConnector, if observation preprocessing is enabled.
    3. The observation filter selected by `observation_filter`, unless it
       is "NoFilter".
    4. StateBufferConnector and ViewRequirementAgentConnector.

    Args:
        ctx: context used to create connectors.
        config: The AlgorithmConfig object.

    Returns:
        An AgentConnectorPipeline wrapping the connectors above.
    """
    connectors = []

    clip_rewards = __clip_rewards(config)
    if clip_rewards is True:
        # A literal `True` requests sign-based clipping.
        connectors.append(ClipRewardAgentConnector(ctx, sign=True))
    elif isinstance(clip_rewards, float):
        # A float is passed through as the clipping `limit` (its magnitude).
        # isinstance() also accepts float subclasses such as numpy.float64,
        # which an exact `type(...) == float` check silently ignored.
        connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))

    if __preprocessing_enabled(config):
        connectors.append(ObsPreprocessorConnector(ctx))

    # Filters should be after observation preprocessing.
    filter_connector = get_synced_filter_connector(ctx)
    # Configuration option "NoFilter" results in `filter_connector` being None.
    if filter_connector:
        connectors.append(filter_connector)

    connectors.extend(
        [
            StateBufferConnector(ctx),
            ViewRequirementAgentConnector(ctx),
        ]
    )

    return AgentConnectorPipeline(ctx, connectors)
@OldAPIStack
def get_action_connectors_from_config(
    ctx: ConnectorContext,
    config: "AlgorithmConfig",
) -> ActionConnectorPipeline:
    """Build the default action-connector pipeline for a new policy.

    The pipeline always starts with ConvertToNumpyConnector and always ends
    with ImmutableActionsConnector. In between, NormalizeActionsConnector and
    then ClipActionsConnector are inserted when the `normalize_actions` /
    `clip_actions` config entries are truthy.

    Args:
        ctx: context used to create connectors.
        config: The AlgorithmConfig object.

    Returns:
        An ActionConnectorPipeline wrapping the selected connectors.
    """
    pipeline = [ConvertToNumpyConnector(ctx)]

    # Optional steps, listed in the order they must run.
    optional_steps = (
        ("normalize_actions", NormalizeActionsConnector),
        ("clip_actions", ClipActionsConnector),
    )
    for config_key, connector_cls in optional_steps:
        if config.get(config_key, False):
            pipeline.append(connector_cls(ctx))

    pipeline.append(ImmutableActionsConnector(ctx))
    return ActionConnectorPipeline(ctx, pipeline)
@OldAPIStack
def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
    """Create default agent and action connectors and attach them to a Policy.

    Both pipelines are logged at INFO level once attached.

    Args:
        policy: Policy instance. It must not have agent or action connectors
            yet; otherwise an AssertionError is raised (when assertions are
            enabled).
        config: The AlgorithmConfig used to choose the connectors.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    # Refuse to overwrite pipelines that were already set up.
    assert (
        policy.agent_connectors is None and policy.action_connectors is None
    ), "Can not create connectors for a policy that already has connectors."

    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)

    logger.info("Using connectors:")
    for attached in (policy.agent_connectors, policy.action_connectors):
        # The pipelines' __str__ takes an `indentation` keyword, which plain
        # str() cannot pass, so the dunder is called directly.
        logger.info(attached.__str__(indentation=4))
@OldAPIStack
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
    """Re-create a single connector for a Policy from its serialized config.

    Args:
        policy: Policy instance; the connector context is derived from it via
            `ConnectorContext.from_policy`.
        connector_config: Serialized connector config, unpacked as a
            `(name, params)` pair.

    Returns:
        The connector produced by `get_connector(name, ctx, params)`.
    """
    context = ConnectorContext.from_policy(policy)
    connector_name, connector_params = connector_config
    return get_connector(connector_name, context, connector_params)
# This filter-selection mechanism is needed temporarily to stay compatible with the old API.
@OldAPIStack
def get_synced_filter_connector(ctx: ConnectorContext):
    """Return the observation-filter agent connector selected by the config.

    Reads `ctx.config.get("observation_filter")` and maps it as follows:

    - "MeanStdFilter" -> MeanStdObservationFilterAgentConnector(ctx, clip=None)
    - "ConcurrentMeanStdFilter" ->
      ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
    - "NoFilter" -> None (no filter connector)

    Args:
        ctx: Connector context whose `config` provides "observation_filter".

    Returns:
        The filter connector, or None for "NoFilter".

    Raises:
        ValueError: If "observation_filter" is missing (read as None) or has
            any other value. ValueError is a subclass of Exception, so handlers
            written against the previous bare `Exception` still catch it.
    """
    filter_specifier = ctx.config.get("observation_filter")
    if filter_specifier == "MeanStdFilter":
        return MeanStdObservationFilterAgentConnector(ctx, clip=None)
    if filter_specifier == "ConcurrentMeanStdFilter":
        return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
    if filter_specifier == "NoFilter":
        return None
    raise ValueError("Unknown observation_filter: " + str(filter_specifier))
@OldAPIStack
def maybe_get_filters_for_syncing(rollout_worker, policy_id):
    """Expose a policy's synced observation filter on its rollout worker.

    Looks up `policy_id` in `rollout_worker.policy_map`. If that policy has a
    non-empty agent-connector pipeline containing a SyncedFilterAgentConnector,
    the connector's `.filter` is stored in `rollout_worker.filters[policy_id]`.
    Otherwise nothing happens. Always returns None.

    Args:
        rollout_worker: Worker providing `policy_map` and `filters`.
        policy_id: Key of the policy within `rollout_worker.policy_map`.
    """
    # As long as the historic filter synchronization mechanism is in place,
    # filters must be put into `rollout_worker.filters` to get synchronized.
    pipeline = rollout_worker.policy_map[policy_id].agent_connectors
    if not pipeline:
        return

    # Indexing by connector type yields a sequence (it is len()-ed and indexed
    # below), presumably of all connectors of that type — TODO confirm.
    synced_filters = pipeline[SyncedFilterAgentConnector]
    if not synced_filters:
        return

    # There can only be one filter at a time.
    assert len(synced_filters) == 1, (
        "ConnectorPipeline has multiple connectors of type "
        "SyncedFilterAgentConnector but can only have one."
    )
    rollout_worker.filters[policy_id] = synced_filters[0].filter