[RLlib] Introduce Catalog skeleton for RLModules (ray-project#32069)
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst authored Feb 27, 2023
1 parent 3a07221 commit 8624273
Showing 18 changed files with 791 additions and 411 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1845,6 +1845,13 @@ py_test(
# Tag: core
# --------------------------------------------------------------------

py_test(
    name = "test_catalog",
    tags = ["team:rllib", "core"],
    size = "small",
    srcs = ["core/models/tests/test_catalog.py"]
)

py_test(
    name = "test_torch_rl_module",
    tags = ["team:rllib", "core"],
61 changes: 61 additions & 0 deletions rllib/algorithms/ppo/ppo_base_rl_module.py
@@ -0,0 +1,61 @@
"""
This file holds framework-agnostic components for PPO's RLModules.
"""

import abc
from typing import Mapping, Any

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.ppo_rl_module_config import PPOModuleConfig
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
from ray.rllib.core.models.base import ActorCriticEncoder


@ExperimentalAPI
class PPORLModuleBase(RLModule, abc.ABC):
    framework = None

    def __init__(self, config: RLModuleConfig):
        super().__init__()
        self.config = config
        catalog = config.catalog

        assert isinstance(catalog, PPOCatalog), "A PPOCatalog is required for PPO."

        # Build models from catalog
        self.encoder = catalog.build_actor_critic_encoder(framework=self.framework)
        self.pi = catalog.build_pi_head(framework=self.framework)
        self.vf = catalog.build_vf_head(framework=self.framework)

        self._is_discrete = isinstance(
            convert_old_gym_space_to_gymnasium_space(self.config.action_space),
            gym.spaces.Discrete,
        )
        assert isinstance(self.encoder, ActorCriticEncoder)

    @classmethod
    @override(RLModule)
    def from_model_config(
        cls,
        observation_space: gym.Space,
        action_space: gym.Space,
        *,
        model_config_dict: Mapping[str, Any],
    ) -> "PPORLModuleBase":
        catalog = PPOCatalog(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )

        config = PPOModuleConfig(
            observation_space=observation_space,
            action_space=action_space,
            catalog=catalog,
        )

        return config.build(framework=cls.framework)
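
The from_model_config() classmethod above is the glue between the gym spaces, the PPOCatalog, and the framework-specific module. A rough usage sketch follows; it is not part of this diff, and it assumes the PPOTorchRLModule added elsewhere in this PR as well as MODEL_DEFAULTS carrying the keys PPOCatalog reads (vf_share_layers, post_fcnet_hiddens, post_fcnet_activation, free_log_std).

# Hypothetical usage sketch, not taken from this commit.
import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.models.catalog import MODEL_DEFAULTS  # assumed to supply the required keys

env = gym.make("CartPole-v1")

# Builds a PPOCatalog and PPOModuleConfig internally, then the torch module.
module = PPOTorchRLModule.from_model_config(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict=dict(MODEL_DEFAULTS),
)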
155 changes: 155 additions & 0 deletions rllib/algorithms/ppo/ppo_catalog.py
@@ -0,0 +1,155 @@
import gymnasium as gym

from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import ActorCriticEncoderConfig, MLPHeadConfig
from ray.rllib.utils import override


class PPOCatalog(Catalog):
    """The Catalog class used to build models for PPO.

    PPOCatalog provides the following models:
        - ActorCriticEncoder: The encoder used to encode the observations.
        - Pi Head: The head used to compute the policy logits.
        - Value Function Head: The head used to compute the value function.

    The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
    for the policy and value function. See implementations of PPORLModuleBase for
    more details.

    Any custom ActorCriticEncoder can be built by overriding the
    build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
    at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom
    ActorCriticEncoder during RLModule runtime.

    Any custom head can be built by overriding the build_pi_head() and build_vf_head()
    methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
    build custom heads during RLModule runtime.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        """Initializes the PPOCatalog.

        Args:
            observation_space: The observation space of the Encoder.
            action_space: The action space for the Pi Head.
            model_config_dict: The model config to use.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )
        free_log_std = model_config_dict.get("free_log_std")
        assert not free_log_std, "free_log_std not supported yet."

        assert isinstance(
            observation_space, gym.spaces.Box
        ), "This simple PPO Module only supports Box observation space."

        assert len(observation_space.shape) in (
            1,
        ), "This simple PPO Module only supports 1D observation spaces."

        assert isinstance(action_space, (gym.spaces.Discrete, gym.spaces.Box)), (
            "This simple PPO Module only supports Discrete and Box action spaces.",
        )

        # Replace EncoderConfig by ActorCriticEncoderConfig
        self.actor_critic_encoder_config = ActorCriticEncoderConfig(
            base_encoder_config=self.encoder_config,
            shared=self.model_config_dict["vf_share_layers"],
        )

        if isinstance(action_space, gym.spaces.Discrete):
            pi_output_dim = action_space.n
        else:
            pi_output_dim = action_space.shape[0] * 2

        post_fcnet_hiddens = self.model_config_dict["post_fcnet_hiddens"]
        post_fcnet_activation = self.model_config_dict["post_fcnet_activation"]

        self.pi_head_config = MLPHeadConfig(
            input_dim=self.encoder_config.output_dim,
            hidden_layer_dims=post_fcnet_hiddens,
            hidden_layer_activation=post_fcnet_activation,
            output_activation="linear",
            output_dim=pi_output_dim,
        )

        self.vf_head_config = MLPHeadConfig(
            input_dim=self.encoder_config.output_dim,
            hidden_layer_dims=post_fcnet_hiddens,
            hidden_layer_activation=post_fcnet_activation,
            output_activation="linear",
            output_dim=1,
        )

        # Set input- and output dimensions to fit PPO's needs.
        self.encoder_config.input_dim = observation_space.shape[0]
        self.pi_head_config.input_dim = self.encoder_config.output_dim
        if isinstance(action_space, gym.spaces.Discrete):
            self.pi_head_config.output_dim = int(action_space.n)
        else:
            self.pi_head_config.output_dim = int(action_space.shape[0] * 2)
        self.vf_head_config.output_dim = 1

    def build_actor_critic_encoder(self, framework: str):
        """Builds the ActorCriticEncoder.

        The default behavior is to build the encoder from the encoder_config.
        This can be overridden to build a custom ActorCriticEncoder as a means of
        configuring the behavior of a PPORLModuleBase implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf".

        Returns:
            The ActorCriticEncoder.
        """
        return self.actor_critic_encoder_config.build(framework=framework)

    @override(Catalog)
    def build_encoder(self, framework: str):
        """Builds the encoder.

        Since PPO uses an ActorCriticEncoder, this method should not be implemented.
        """
        raise NotImplementedError(
            "Use PPOCatalog.build_actor_critic_encoder() instead."
        )

    def build_pi_head(self, framework: str):
        """Builds the policy head.

        The default behavior is to build the head from the pi_head_config.
        This can be overridden to build a custom policy head as a means of configuring
        the behavior of a PPORLModuleBase implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf".

        Returns:
            The policy head.
        """
        return self.pi_head_config.build(framework=framework)

    def build_vf_head(self, framework: str):
        """Builds the value function head.

        The default behavior is to build the head from the vf_head_config.
        This can be overridden to build a custom value function head as a means of
        configuring the behavior of a PPORLModuleBase implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf".

        Returns:
            The value function head.
        """
        return self.vf_head_config.build(framework=framework)
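
As the PPOCatalog docstring notes, individual models can be swapped out by overriding the build_* methods. A minimal, hypothetical sketch of that hook follows; the torch-only linear policy head is illustrative and not part of this PR, and it reuses the pi_head_config dimensions set in __init__ above.

# Hypothetical override sketch, not taken from this commit.
import torch.nn as nn

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog


class MyPPOCatalog(PPOCatalog):
    def build_pi_head(self, framework: str):
        # Replace the default MLP policy head with a single linear layer,
        # keeping the encoder and value-function head defaults.
        assert framework == "torch", "This sketch only covers the torch path."
        return nn.Linear(
            self.pi_head_config.input_dim, self.pi_head_config.output_dim
        )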
43 changes: 43 additions & 0 deletions rllib/algorithms/ppo/ppo_rl_module_config.py
@@ -0,0 +1,43 @@
from dataclasses import dataclass

import gymnasium as gym

from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.utils.annotations import ExperimentalAPI


@ExperimentalAPI
@dataclass
class PPOModuleConfig(RLModuleConfig):
    """Configuration for the PPORLModule.

    Attributes:
        observation_space: The observation space of the environment.
        action_space: The action space of the environment.
        catalog: The PPOCatalog object to use for building the models.
    """

    observation_space: gym.Space = None
    action_space: gym.Space = None
    catalog: Catalog = None

    def build(self, framework: str):
        """Builds a PPORLModule.

        Args:
            framework: The framework to use for the module.

        Returns:
            PPORLModule: The module.
        """
        if framework == "torch":
            from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
                PPOTorchRLModule,
            )

            return PPOTorchRLModule(self)
        else:
            from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule

            return PPOTfRLModule(self)
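
PPOModuleConfig.build() is the second entry point into the same pipeline: given a pre-built catalog, it returns the framework-specific module. A hypothetical end-to-end sketch follows; the CartPole spaces and the MODEL_DEFAULTS-based model config are assumptions, not part of this commit.

# Hypothetical end-to-end sketch, not taken from this commit.
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.ppo_rl_module_config import PPOModuleConfig
from ray.rllib.models.catalog import MODEL_DEFAULTS  # assumed to supply the required keys

env = gym.make("CartPole-v1")

# Build the catalog from spaces and model config, wrap it in a config, then build.
catalog = PPOCatalog(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict=dict(MODEL_DEFAULTS),
)
config = PPOModuleConfig(
    observation_space=env.observation_space,
    action_space=env.action_space,
    catalog=catalog,
)
module = config.build(framework="torch")  # "tf" would build PPOTfRLModule instead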