[RLlib; ConnectorV2] Minor (backward compatible) ConnectorV2 API changes (in preparation for upcoming ConnectorV2 docs). (ray-project#47334)
sven1977 authored Aug 27, 2024
1 parent 78402bc commit 70f7083
Showing 33 changed files with 409 additions and 271 deletions.
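The core of the commit: the `data` keyword argument of `ConnectorV2.__call__` (and of all call sites into connector pipelines) is renamed to `batch`, and the space-recomputation hooks get new names with explicit input-space parameters. A minimal sketch of a call site after this commit; `my_connector`, `module`, and `episodes` are hypothetical placeholders, only the keyword names come from the diffs below:

    # Hypothetical call site; the keyword names (rl_module, batch,
    # episodes, explore, shared_data) are taken from the diffs below.
    output_batch = my_connector(
        rl_module=module,   # the RLModule the pipeline operates on (or None)
        batch={},           # was `data={}` before this commit
        episodes=episodes,  # list of episode objects to pull data from
        explore=True,
        shared_data={},     # dict shared across the pieces of one pipeline
    )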
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala_learner.py
@@ -130,7 +130,7 @@ def update_from_episodes(
         with self.metrics.log_time((ALL_MODULES, EPISODES_TO_BATCH_TIMER)):
             batch = self._learner_connector(
                 rl_module=self.module,
-                data={},
+                batch={},
                 episodes=episodes,
                 shared_data={},
             )
4 changes: 2 additions & 2 deletions rllib/algorithms/marwil/marwil_offline_prelearner.py
@@ -45,7 +45,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, MultiAgentBatch]:
         # Run the `Learner`'s connector pipeline.
         batch = self._learner_connector(
             rl_module=self._module,
-            data=batch,
+            batch=batch,
             episodes=episodes,
             shared_data={},
         )
@@ -116,7 +116,7 @@ def _compute_gae_from_episodes(
         # bootstrapped vf) computations.
         batch_for_vf = self._learner_connector(
             rl_module=self._module,
-            data={},
+            batch={},
             episodes=episodes,
             shared_data={},
         )
2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/ppo_learner.py
@@ -121,7 +121,7 @@ def _compute_gae_from_episodes(
         # bootstrapped vf) computations.
         batch_for_vf = self._learner_connector(
             rl_module=self.module,
-            data={},
+            batch={},
             episodes=episodes,
             shared_data={},
         )
20 changes: 10 additions & 10 deletions rllib/connectors/common/add_observations_from_episodes_to_batch.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

 import gymnasium as gym

@@ -78,17 +78,17 @@ class AddObservationsFromEpisodesToBatch(ConnectorV2):
         # Call the connector with the two created episodes.
         # Note that this particular connector works without an RLModule, so we
         # simplify here for the sake of this example.
-        output_data = connector(
+        output_batch = connector(
             rl_module=None,
             data={},
+            batch={},
             episodes=episodes,
             explore=True,
             shared_data={},
         )
         # The output data should now contain the last observations of both episodes,
         # in a "per-episode organized" fashion.
         check(
-            output_data,
+            output_batch,
             {
                 "obs": {
                     (episodes[0].id_,): [eps_1_last_obs],
@@ -127,15 +127,15 @@ def __call__(
         self,
         *,
         rl_module: RLModule,
-        data: Optional[Any],
+        batch: Dict[str, Any],
         episodes: List[EpisodeType],
         explore: Optional[bool] = None,
         shared_data: Optional[dict] = None,
         **kwargs,
     ) -> Any:
         # If "obs" already in data, early out.
-        if Columns.OBS in data:
-            return data
+        if Columns.OBS in batch:
+            return batch

         for sa_episode in self.single_agent_episode_iterator(
             episodes,
@@ -146,7 +146,7 @@ def __call__(
         ):
             if self._as_learner_connector:
                 self.add_n_batch_items(
-                    data,
+                    batch,
                     Columns.OBS,
                     items_to_add=sa_episode.get_observations(slice(0, len(sa_episode))),
                     num_items=len(sa_episode),
@@ -155,9 +155,9 @@ def __call__(
             else:
                 assert not sa_episode.is_finalized
                 self.add_batch_item(
-                    data,
+                    batch,
                     Columns.OBS,
                     item_to_add=sa_episode.get_observations(-1),
                     single_agent_episode=sa_episode,
                 )
-        return data
+        return batch
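The same rename applies to user-defined connector pieces that override `__call__`. A hedged sketch of a custom piece under the new signature; the class and its no-op body are illustrative, only the signature change is from this commit:

    from typing import Any, Dict, List, Optional

    from ray.rllib.connectors.connector_v2 import ConnectorV2
    from ray.rllib.core.rl_module.rl_module import RLModule
    from ray.rllib.utils.typing import EpisodeType


    class MyConnectorPiece(ConnectorV2):
        # Hypothetical example piece, not part of this commit.
        def __call__(
            self,
            *,
            rl_module: RLModule,
            batch: Dict[str, Any],  # was `data: Optional[Any]` before this commit
            episodes: List[EpisodeType],
            explore: Optional[bool] = None,
            shared_data: Optional[dict] = None,
            **kwargs,
        ) -> Any:
            # Read from `episodes`, add items to `batch` in place, and return it.
            return batch

The commit title advertises the change as backward compatible, so existing pieces presumably keep working during a transition period; new code should pass `batch=`.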
48 changes: 24 additions & 24 deletions rllib/connectors/common/add_states_from_episodes_to_batch.py
@@ -1,6 +1,6 @@
 from collections import deque
 import math
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

 import gymnasium as gym
 import numpy as np
@@ -13,7 +13,7 @@
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.numpy import convert_to_numpy
-from ray.rllib.utils.spaces.space_utils import batch, BatchedNdArray
+from ray.rllib.utils.spaces.space_utils import batch as batch_fn, BatchedNdArray
 from ray.rllib.utils.typing import EpisodeType
 from ray.util.annotations import PublicAPI

@@ -95,17 +95,17 @@ def get_initial_state(self):
         connector = AddStatesFromEpisodesToBatch(as_learner_connector=False)
         # Call the connector.
-        output_data = connector(
+        output_batch = connector(
             rl_module=rl_module,
-            data={},
+            batch={},
             episodes=[episode],
             shared_data={},
         )
         # The output data's STATE_IN key should now contain the RLModule's initial state
         # plus the one state out found in the episode in a "per-episode organized"
         # fashion.
         check(
-            output_data[Columns.STATE_IN],
+            output_batch[Columns.STATE_IN],
             {
                 (episode.id_,): [rl_module_init_state],
             },
@@ -127,17 +127,17 @@ def get_initial_state(self):
         )
         # Call the connector.
-        output_data = connector(
+        output_batch = connector(
             rl_module=rl_module,
-            data={},
+            batch={},
             episodes=[episode],
             shared_data={},
         )
         # The output data's STATE_IN key should now contain the episode's last
         # STATE_OUT, NOT the RLModule's initial state in a "per-episode organized"
         # fashion.
         check(
-            output_data[Columns.STATE_IN],
+            output_batch[Columns.STATE_IN],
             {
                 # Expect the episode's last STATE_OUT.
                 (episode.id_,): [-1.0],
@@ -154,14 +154,14 @@ def get_initial_state(self):
         )
         # Call the connector.
-        output_data = connector(
+        output_batch = connector(
             rl_module=rl_module,
-            data={},
+            batch={},
             episodes=[episode.finalize()],
             shared_data={},
         )
         check(
-            output_data[Columns.STATE_IN],
+            output_batch[Columns.STATE_IN],
             {
                 # Expect initial module state + every 2nd STATE_OUT from episode, but
                 # not the very last one (just like the very last observation, this data
@@ -211,15 +211,15 @@ def __call__(
         self,
         *,
         rl_module: RLModule,
-        data: Optional[Any],
+        batch: Dict[str, Any],
         episodes: List[EpisodeType],
         explore: Optional[bool] = None,
         shared_data: Optional[dict] = None,
         **kwargs,
     ) -> Any:
         # If not stateful OR STATE_IN already in data, early out.
-        if not rl_module.is_stateful() or Columns.STATE_IN in data:
-            return data
+        if not rl_module.is_stateful() or Columns.STATE_IN in batch:
+            return batch

         # Make all inputs (other than STATE_IN) have an additional T-axis.
         # Since data has not been batched yet (we are still operating on lists in the
@@ -228,9 +228,9 @@ def __call__(
         # Also, let module-to-env pipeline know that we had added a single timestep
         # time rank to the data (to remove it again).
         if not self._as_learner_connector:
-            for column in data.keys():
+            for column in batch.keys():
                 self.foreach_batch_item_change_in_place(
-                    batch=data,
+                    batch=batch,
                     column=column,
                     func=lambda item, eps_id, aid, mid: (
                         item
@@ -249,7 +249,7 @@ def __call__(
         else:
             # Before adding STATE_IN to the `data`, zero-pad existing data and batch
             # into max_seq_len chunks.
-            for column, column_data in data.copy().items():
+            for column, column_data in batch.copy().items():
                 # Do not zero-pad INFOS column.
                 if column == Columns.INFOS:
                     continue
@@ -317,7 +317,7 @@ def __call__(
                 # state_outs.shape=(T,[state-dim]) T=episode len
                 state_outs = sa_episode.get_extra_model_outputs(key=Columns.STATE_OUT)
                 self.add_n_batch_items(
-                    batch=data,
+                    batch=batch,
                     column=Columns.STATE_IN,
                     # items_to_add.shape=(B,[state-dim]) # B=episode len // max_seq_len
                     items_to_add=tree.map_structure(
@@ -340,15 +340,15 @@ def __call__(
                     len(sa_episode), self.max_seq_len
                 )
                 self.add_n_batch_items(
-                    batch=data,
+                    batch=batch,
                     column=Columns.SEQ_LENS,
                     items_to_add=seq_lens,
                     num_items=len(seq_lens),
                     single_agent_episode=sa_episode,
                 )
                 if not shared_data.get("_added_loss_mask_for_valid_episode_ts"):
                     self.add_n_batch_items(
-                        batch=data,
+                        batch=batch,
                         column=Columns.LOSS_MASK,
                         items_to_add=mask,
                         num_items=len(mask),
@@ -375,13 +375,13 @@ def __call__(
                     key=Columns.STATE_OUT, indices=-1
                 )
                 self.add_batch_item(
-                    data,
+                    batch,
                     Columns.STATE_IN,
                     item_to_add=state,
                     single_agent_episode=sa_episode,
                 )

-        return data
+        return batch


 def split_and_zero_pad_list(item_list, T: int):
@@ -417,7 +417,7 @@ def split_and_zero_pad_list(item_list, T: int):

         if current_t == T:
             ret.append(
-                batch(
+                batch_fn(
                     current_time_row,
                     individual_items_already_have_batch_dim="auto",
                 )
@@ -428,7 +428,7 @@ def split_and_zero_pad_list(item_list, T: int):
     if current_t > 0 and current_t < T:
         current_time_row.extend([zero_element] * (T - current_t))
     ret.append(
-        batch(current_time_row, individual_items_already_have_batch_dim="auto")
+        batch_fn(current_time_row, individual_items_already_have_batch_dim="auto")
     )

     return ret
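Two things happen in this file: the `batch` utility from `space_utils` is imported as `batch_fn` so it no longer collides with the new `batch` argument of `__call__`, and the batch-building calls inside `split_and_zero_pad_list` are updated accordingly. For intuition, a hypothetical illustration of what `split_and_zero_pad_list` does; the exact return format is an assumption based on the surrounding code, not verified output:

    # Assumed behavior: chop a list of per-timestep items into rows of
    # length T, zero-padding the final partial row.
    rows = split_and_zero_pad_list([1, 2, 3, 4, 5], T=3)
    # -> roughly: [batched([1, 2, 3]), batched([4, 5, 0])]  (0 = zero padding)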
30 changes: 19 additions & 11 deletions rllib/connectors/common/agent_to_module_mapping.py
@@ -105,17 +105,17 @@ class AgentToModuleMapping(ConnectorV2):
         # Call the connector (and thereby flip from AgentID based to ModuleID based
         # structure..
-        output_data = connector(
+        output_batch = connector(
             rl_module=None,  # This particular connector works without an RLModule.
-            data=batch,
+            batch=batch,
             episodes=[],  # This particular connector works without a list of episodes.
             explore=True,
             shared_data={},
         )
         # `data` should now be mapped from ModuleIDs to module data.
         check(
-            output_data,
+            output_batch,
             {
                 "module0": {
                     "obs": [0, 1, 2],
@@ -130,12 +130,20 @@ class AgentToModuleMapping(ConnectorV2):
     """

     @override(ConnectorV2)
-    def recompute_observation_space_from_input_spaces(self):
-        return self._map_space_if_necessary(self.input_observation_space, "obs")
+    def recompute_output_observation_space(
+        self,
+        input_observation_space: gym.Space,
+        input_action_space: gym.Space,
+    ) -> gym.Space:
+        return self._map_space_if_necessary(input_observation_space, "obs")

     @override(ConnectorV2)
-    def recompute_action_space_from_input_spaces(self):
-        return self._map_space_if_necessary(self.input_action_space, "act")
+    def recompute_output_action_space(
+        self,
+        input_observation_space: gym.Space,
+        input_action_space: gym.Space,
+    ) -> gym.Space:
+        return self._map_space_if_necessary(input_action_space, "act")

     def __init__(
         self,
@@ -155,7 +163,7 @@ def __call__(
         self,
         *,
         rl_module: RLModule,
-        data: Optional[Any],
+        batch: Dict[str, Any],
         episodes: List[EpisodeType],
         explore: Optional[bool] = None,
         shared_data: Optional[dict] = None,
@@ -166,7 +174,7 @@
         # Store in shared data, which module IDs map to which episode/agent, such
         # that the module-to-env pipeline can map the data back to agents.
         memorized_map_structure = defaultdict(list)
-        for column, agent_data in data.items():
+        for column, agent_data in batch.items():
             if rl_module is not None and column in rl_module:
                 continue
             for eps_id, agent_id, module_id in agent_data.keys():
@@ -180,7 +188,7 @@
         data_by_module = {}

         # Iterating over each column in the original data:
-        for column, agent_data in data.items():
+        for column, agent_data in batch.items():
             if rl_module is not None and column in rl_module:
                 if column in data_by_module:
                     data_by_module[column].update(agent_data)
@@ -204,7 +212,7 @@

         return data_by_module

-    def _map_space_if_necessary(self, space, which: str = "obs"):
+    def _map_space_if_necessary(self, space: gym.Space, which: str = "obs"):
         # Analyze input observation space to check, whether the user has already taken
         # care of the agent to module mapping.
         if set(self._module_specs) == set(space.spaces.keys()):
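Beyond the `batch` rename, this file migrates from the old zero-argument hooks (`recompute_observation_space_from_input_spaces` / `recompute_action_space_from_input_spaces`, which read `self.input_observation_space` and `self.input_action_space`) to the new `recompute_output_observation_space` / `recompute_output_action_space` hooks, which receive both input spaces as explicit parameters. A sketch of a custom piece overriding the new hook; the normalization example itself is hypothetical, only the hook name and signature come from this diff:

    import gymnasium as gym
    import numpy as np

    from ray.rllib.connectors.connector_v2 import ConnectorV2
    from ray.rllib.utils.annotations import override


    class NormalizeObs(ConnectorV2):
        # Hypothetical piece that declares normalized observations in [-1, 1].
        @override(ConnectorV2)
        def recompute_output_observation_space(
            self,
            input_observation_space: gym.Space,
            input_action_space: gym.Space,
        ) -> gym.Space:
            return gym.spaces.Box(
                -1.0, 1.0, shape=input_observation_space.shape, dtype=np.float32
            )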