[RLlib] Test if we can modify Catalog and handle Specs defaults in AlgorithmConfig (ray-project#33128)

* Move input_dims and output_dims and minor cleanups
* Be able to pass in any Catalog
* Make encoder config part of __post_init__
* Fix more occurrences of input_dim and output_dim
* Update get_action_dist method to reflect output_dim changes

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst authored Mar 9, 2023
1 parent 73685de commit 753ce33
Showing 9 changed files with 303 additions and 30 deletions.
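
Taken together, the changes below let a user hand the config a partially specified SingleAgentRLModuleSpec (for example, only a custom catalog_class and module_class) and have AlgorithmConfig.validate() merge in the algorithm's defaults. A condensed sketch of that workflow, following the test added in this commit (the MyCatalog subclass is illustrative):

import torch

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


class MyCatalog(PPOCatalog):
    # Customize a single model-building method; everything else is inherited.
    def build_vf_head(self, framework):
        return torch.nn.Linear(self.latent_dims[0], 1)


config = (
    PPOConfig()
    # Only module_class and catalog_class are given; spaces and model config
    # are filled in from the env and the algorithm's default spec.
    .rl_module(
        rl_module_spec=SingleAgentRLModuleSpec(
            module_class=PPOTorchRLModule, catalog_class=MyCatalog
        )
    )
    .framework("torch")
)
algo = config.build(env="CartPole-v0")
algo.train()
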
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1848,7 +1848,7 @@ py_test(
py_test(
name = "test_catalog",
tags = ["team:rllib", "core"],
size = "small",
size = "medium",
srcs = ["core/models/tests/test_catalog.py"]
)

46 changes: 32 additions & 14 deletions rllib/algorithms/algorithm_config.py
@@ -116,6 +116,15 @@ def _resolve_class_path(module) -> Type:
return getattr(module, class_name)


def _check_rl_module_spec(module_spec: ModuleSpec) -> None:
if not isinstance(module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)):
raise ValueError(
"rl_module_spec must be an instance of "
"SingleAgentRLModuleSpec or MultiAgentRLModuleSpec."
f"Got {type(module_spec)} instead."
)


class AlgorithmConfig(_Config):
"""A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.
@@ -901,17 +910,26 @@ def validate(self) -> None:
# compatibility for now. User only needs to set num_rollout_workers.
self.input_config["parallelism"] = self.num_rollout_workers or 1

# resolve rl_module_spec class
if self._enable_rl_module_api and self.rl_module_spec is None:
self.rl_module_spec = self.get_default_rl_module_spec()
if not isinstance(
self.rl_module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)
):
raise ValueError(
"rl_module_spec must be an instance of "
"SingleAgentRLModuleSpec or MultiAgentRLModuleSpec."
f"Got {type(self.rl_module_spec)} instead."
)
if self._enable_rl_module_api:
default_rl_module_spec = self.get_default_rl_module_spec()
_check_rl_module_spec(default_rl_module_spec)

if self.rl_module_spec is not None:
# Merge provided RL Module spec class with defaults
_check_rl_module_spec(self.rl_module_spec)
# We can only merge if we have SingleAgentRLModuleSpecs.
# TODO(Artur): Support merging for MultiAgentRLModuleSpecs.
if isinstance(self.rl_module_spec, SingleAgentRLModuleSpec):
if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec):
default_rl_module_spec.update(self.rl_module_spec)
self.rl_module_spec = default_rl_module_spec
elif isinstance(default_rl_module_spec, MultiAgentRLModuleSpec):
raise ValueError(
"Cannot merge MultiAgentRLModuleSpec with "
"SingleAgentRLModuleSpec!"
)
else:
self.rl_module_spec = default_rl_module_spec

# make sure the resource requirements for learner_group is valid
if self.num_learner_workers == 0 and self.num_gpus_per_worker > 1:
@@ -2283,9 +2301,9 @@ def rl_module(
Args:
rl_module_spec: The RLModule spec to use for this config. It can be either
a SingleAgentRLModuleSpec or a MultiAgentRLModuleSpec. If the
observation_space, action_space, or the model_config is not specified
it will be inferred from the env and other parts of the algorithm
config object.
observation_space, action_space, catalog_class, or the model_config is
not specified it will be inferred from the env and other parts of the
algorithm config object.
_enable_rl_module_api: Whether to enable the RLModule API for this config.
By default if you call `config.rl_module(...)`, the
RLModule API will NOT be enabled. If you want to enable it, you can call
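
The merge in validate() above relies on the new SingleAgentRLModuleSpec.update() (added in rl_module.py further down): each field of the user-provided spec takes effect only if it is not None; otherwise the default is kept. A standalone sketch of that behavior, assuming the spec's fields all default to None and using an illustrative MyCatalog subclass:

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


class MyCatalog(PPOCatalog):
    pass  # stand-in for a user-customized catalog


# What the algorithm would provide by default (simplified).
default_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule, catalog_class=PPOCatalog
)
# What the user provides: only the piece they want to change.
user_spec = SingleAgentRLModuleSpec(catalog_class=MyCatalog)

default_spec.update(user_spec)
assert default_spec.catalog_class is MyCatalog        # taken from the user spec
assert default_spec.module_class is PPOTorchRLModule  # kept from the default
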
4 changes: 2 additions & 2 deletions rllib/core/models/base.py
@@ -36,8 +36,8 @@ class ModelConfig(abc.ABC):
output_dims: The output dimensions of the network.
"""

input_dims: Union[List, Tuple] = None
output_dims: Union[List, Tuple] = None
input_dims: Union[List[int], Tuple[int]] = None
output_dims: Union[List[int], Tuple[int]] = None

@abc.abstractmethod
def build(self, framework: str):
5 changes: 3 additions & 2 deletions rllib/core/models/catalog.py
@@ -101,7 +101,8 @@ def latent_dims(self):
This establishes an agreement between encoder and heads about the latent
dimensions. Encoders can be built to output a latent tensor with
`latent_dims` dimensions, and heads can be built with tensors of
`latent_dims` dimensions as inputs.
`latent_dims` dimensions as inputs. If a Catalog is modified such that this
agreement is not needed, it can safely be ignored.
Returns:
The latent dimensions of the encoder.
Expand Down Expand Up @@ -159,7 +160,7 @@ def get_action_dist_cls(self, framework: str):
"""Get the action distribution class.
The default behavior is to get the action distribution from the
`Catalog.action_dist_cls_dict`. This can be overridden to build a custom action
`Catalog.action_dist_class_fn`. This can be overridden to build a custom action
distribution as a means of configuring the behavior of a PPORLModuleBase
implementation.
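
The `latent_dims` agreement described in the docstring above boils down to: the encoder config's output_dims and each head config's input_dims must both equal Catalog.latent_dims. A minimal illustration with a dummy ModelConfig subclass (purely illustrative, not part of this commit; it assumes ModelConfig is the dataclass shown in base.py above):

from dataclasses import dataclass

from ray.rllib.core.models.base import ModelConfig


@dataclass
class DummyConfig(ModelConfig):
    """Illustrative stand-in for a concrete encoder or head config."""

    def build(self, framework: str):
        raise NotImplementedError  # not needed for this shape check


latent_dims = (64,)
encoder_config = DummyConfig(input_dims=(4,), output_dims=latent_dims)  # obs -> latent
pi_head_config = DummyConfig(input_dims=latent_dims, output_dims=(2,))  # latent -> logits

# The agreement: what the encoder emits is exactly what the heads consume.
assert encoder_config.output_dims == pi_head_config.input_dims == latent_dims
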
109 changes: 102 additions & 7 deletions rllib/core/models/tests/test_catalog.py
@@ -1,15 +1,22 @@
import itertools
import unittest
import functools
from collections import namedtuple

import gymnasium as gym
import numpy as np
import tree
from gymnasium.spaces import Box, Discrete
from collections import namedtuple

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.models.torch.base import TorchModel
from ray.rllib.core.models.base import ModelConfig, Encoder
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, STATE_OUT
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import MLPEncoderConfig, CNNEncoderConfig
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.models.tf.tf_distributions import (
TfCategorical,
@@ -64,9 +71,9 @@ def _check_model_outputs(self, model, framework, model_config_dict, input_space)
}
outputs = model(inputs)

assert outputs[ENCODER_OUT].shape == (32, latent_dim)
self.assertEqual(outputs[ENCODER_OUT].shape, (32, latent_dim))
tree.map_structure_with_path(
lambda p, v: self.assertTrue(v.shape == states[p].shape),
lambda p, v: self.assertEqual(v.shape, states[p].shape),
outputs[STATE_OUT],
)

@@ -177,7 +184,7 @@ def test_get_encoder_config(self):
model_config = catalog.get_encoder_config(
observation_space=input_space, model_config_dict=model_config_dict
)
assert type(model_config) == model_config_type
self.assertEqual(type(model_config), model_config_type)
model = model_config.build(framework=framework)

# Do a forward pass and check if the output has the correct shape
@@ -252,9 +259,97 @@ def test_get_dist_cls_from_action_space(self):
logits = tf.convert_to_tensor(logits)
# We don't need a model if we input tensors
dist = dist_cls.from_logits(logits=logits)
assert isinstance(dist, expected_cls_dict[framework])
self.assertTrue(isinstance(dist, expected_cls_dict[framework]))
actions = dist.sample()
assert action_space.contains(actions.numpy()[0])
self.assertTrue(action_space.contains(actions.numpy()[0]))

def test_customize_catalog_from_algorithm_config(self):
"""Test if we can pass catalog to algorithm config and it ends up inside
RLModule and is used to build models there."""

class MyCatalog(PPOCatalog):
def build_vf_head(self, framework):
return torch.nn.Linear(self.latent_dims[0], 1)

config = (
PPOConfig()
.rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyCatalog))
.framework("torch")
)

algo = config.build(env="CartPole-v0")
self.assertEqual(
algo.get_policy("default_policy").model.config.catalog_class, MyCatalog
)

# Test if we can pass custom catalog to algorithm config and train with it.

config = (
PPOConfig()
.rl_module(
rl_module_spec=SingleAgentRLModuleSpec(
module_class=PPOTorchRLModule, catalog_class=MyCatalog
)
)
.framework("torch")
)

algo = config.build(env="CartPole-v0")
algo.train()

def test_post_init_overwrite(self):
"""Test if we can overwrite post_init method of a catalog class.
This tests:
- Defines a custom encoder and its config.
- Defines a custom catalog class that uses the custom encoder by
overwriting the __post_init__ method and defining a custom
Catalog.encoder_config.
- Defines a custom RLModule that uses the custom catalog.
- Runs a forward pass through the custom RLModule to check if
everything is working together as expected.
"""
env = gym.make("CartPole-v0")

class MyCostumTorchEncoderConfig(ModelConfig):
def build(self, framework):
return MyCostumTorchEncoder()

class MyCostumTorchEncoder(TorchModel, Encoder):
def __init__(self):
super().__init__({})
self.net = torch.nn.Linear(env.observation_space.shape[0], 10)

def _forward(self, input_dict, **kwargs):
return {
ENCODER_OUT: (self.net(input_dict["obs"])),
STATE_OUT: None,
}

class MyCustomCatalog(PPOCatalog):
def __post_init__(self):
self.action_dist_class_fn = functools.partial(
self.get_dist_cls_from_action_space, action_space=self.action_space
)
self.latent_dims = (10,)
self.encoder_config = MyCostumTorchEncoderConfig(
input_dims=self.observation_space.shape,
output_dims=self.latent_dims,
)

spec = SingleAgentRLModuleSpec(
module_class=PPOTorchRLModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config_dict=MODEL_DEFAULTS.copy(),
catalog_class=MyCustomCatalog,
)
module = spec.build()

module.forward_inference(
input_data={"obs": torch.ones((32, *env.observation_space.shape))}
)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion rllib/core/models/torch/primitives.py
@@ -71,7 +71,7 @@ class TorchCNN(nn.Module):

def __init__(
self,
input_dims: Union[List, Tuple] = None,
input_dims: Union[List[int], Tuple[int]] = None,
filter_specifiers: List[List[Union[int, List]]] = None,
filter_layer_activation: str = "relu",
output_activation: str = "linear",
35 changes: 32 additions & 3 deletions rllib/core/rl_module/marl_module.py
@@ -413,17 +413,25 @@ def build(
return self.marl_module_class(module_config)

def add_modules(
self, module_specs: Dict[ModuleID, SingleAgentRLModuleSpec]
self,
module_specs: Dict[ModuleID, SingleAgentRLModuleSpec],
overwrite: bool = True,
) -> None:
"""Add new module specs to the spec.
"""Add new module specs to the spec or updates existing ones.
Args:
module_specs: The mapping for the module_id to the single-agent module
specs to be added to this multi-agent module spec.
overwrite: Whether to overwrite existing module specs if they already
exist. If False, existing specs are updated in place instead of replaced.
"""
if self.module_specs is None:
self.module_specs = {}
self.module_specs.update(module_specs)
for module_id, module_spec in module_specs.items():
if overwrite or module_id not in self.module_specs:
self.module_specs[module_id] = module_spec
else:
self.module_specs[module_id].update(module_spec)

@classmethod
def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec":
@@ -452,6 +460,27 @@ def _check_before_build(self):
"SingleAgentRLModuleSpecs for each individual module."
)

def update(self, other: "MultiAgentRLModuleSpec", overwrite=False) -> None:
"""Updates this spec with the other spec.
Traverses this MultiAgentRLModuleSpec's module_specs and updates them with
the module specs from the other MultiAgentRLModuleSpec.
Args:
other: The other spec to update this spec with.
overwrite: Whether to overwrite existing module specs if they already
exist. If False, existing specs are updated in place instead of replaced.
"""
assert type(other) is MultiAgentRLModuleSpec

if isinstance(other.module_specs, dict):
self.add_modules(other.module_specs, overwrite=overwrite)
else:
if not self.module_specs:
self.module_specs = other.module_specs
else:
self.module_specs.update(other.module_specs)


@ExperimentalAPI
@dataclass
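
A short sketch of the two overwrite modes added above: with overwrite=True an existing entry is replaced wholesale, while overwrite=False merges the new spec into the existing one field-wise via SingleAgentRLModuleSpec.update(). (Sketch only; it assumes MultiAgentRLModuleSpec can be constructed with just module_specs.)

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

marl_spec = MultiAgentRLModuleSpec(
    module_specs={"agent_0": SingleAgentRLModuleSpec(catalog_class=PPOCatalog)}
)
new_specs = {"agent_0": SingleAgentRLModuleSpec(module_class=PPOTorchRLModule)}

# overwrite=False: merge field-wise; agent_0 keeps its catalog_class and
# additionally gains the module_class.
marl_spec.add_modules(new_specs, overwrite=False)
assert marl_spec.module_specs["agent_0"].catalog_class is PPOCatalog
assert marl_spec.module_specs["agent_0"].module_class is PPOTorchRLModule

# overwrite=True (the default): replace the existing spec entirely; the
# previously set catalog_class is gone.
marl_spec.add_modules(new_specs, overwrite=True)
assert marl_spec.module_specs["agent_0"].catalog_class is None
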
15 changes: 15 additions & 0 deletions rllib/core/rl_module/rl_module.py
@@ -57,6 +57,8 @@ def get_rl_module_config(self) -> "RLModuleConfig":
)

def build(self) -> "RLModule":
if self.module_class is None:
raise ValueError("RLModule class is not set.")
if self.observation_space is None:
raise ValueError("Observation space is not set.")
if self.action_space is None:
@@ -111,6 +113,19 @@ def from_dict(cls, d):
catalog_class=catalog_class,
)

def update(self, other) -> None:
"""Updates this spec with the given other spec. Works like dict.update()."""
if not isinstance(other, SingleAgentRLModuleSpec):
raise ValueError("Can only update with another SingleAgentRLModuleSpec.")

# If the field is None in the other, keep the current field, otherwise update
# with the new value.
self.module_class = other.module_class or self.module_class
self.observation_space = other.observation_space or self.observation_space
self.action_space = other.action_space or self.action_space
self.model_config_dict = other.model_config_dict or self.model_config_dict
self.catalog_class = other.catalog_class or self.catalog_class


@ExperimentalAPI
@dataclass
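
The new guard in build() above turns a missing module_class into an explicit error rather than a later failure. A tiny sketch (assuming an empty spec can be constructed):

from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

spec = SingleAgentRLModuleSpec()  # nothing set
try:
    spec.build()
except ValueError as e:
    print(e)  # "RLModule class is not set."
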