diff --git a/rllib/BUILD b/rllib/BUILD index 69d187b6a8fe..3806403a56f9 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1848,7 +1848,7 @@ py_test( py_test( name = "test_catalog", tags = ["team:rllib", "core"], - size = "small", + size = "medium", srcs = ["core/models/tests/test_catalog.py"] ) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 9fcdcdd30ab0..0888cb428950 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -116,6 +116,15 @@ def _resolve_class_path(module) -> Type: return getattr(module, class_name) +def _check_rl_module_spec(module_spec: ModuleSpec) -> None: + if not isinstance(module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)): + raise ValueError( + "rl_module_spec must be an instance of " + "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec." + f"Got {type(module_spec)} instead." + ) + + class AlgorithmConfig(_Config): """A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration. @@ -901,17 +910,26 @@ def validate(self) -> None: # compatibility for now. User only needs to set num_rollout_workers. self.input_config["parallelism"] = self.num_rollout_workers or 1 - # resolve rl_module_spec class - if self._enable_rl_module_api and self.rl_module_spec is None: - self.rl_module_spec = self.get_default_rl_module_spec() - if not isinstance( - self.rl_module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec) - ): - raise ValueError( - "rl_module_spec must be an instance of " - "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec." - f"Got {type(self.rl_module_spec)} instead." - ) + if self._enable_rl_module_api: + default_rl_module_spec = self.get_default_rl_module_spec() + _check_rl_module_spec(default_rl_module_spec) + + if self.rl_module_spec is not None: + # Merge provided RL Module spec class with defaults + _check_rl_module_spec(self.rl_module_spec) + # We can only merge if we have SingleAgentRLModuleSpecs. 
+ # TODO(Artur): Support merging for MultiAgentRLModuleSpecs. + if isinstance(self.rl_module_spec, SingleAgentRLModuleSpec): + if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec): + default_rl_module_spec.update(self.rl_module_spec) + self.rl_module_spec = default_rl_module_spec + elif isinstance(default_rl_module_spec, MultiAgentRLModuleSpec): + raise ValueError( + "Cannot merge MultiAgentRLModuleSpec with " + "SingleAgentRLModuleSpec!" + ) + else: + self.rl_module_spec = default_rl_module_spec # make sure the resource requirements for learner_group is valid if self.num_learner_workers == 0 and self.num_gpus_per_worker > 1: @@ -2283,9 +2301,9 @@ def rl_module( Args: rl_module_spec: The RLModule spec to use for this config. It can be either a SingleAgentRLModuleSpec or a MultiAgentRLModuleSpec. If the - observation_space, action_space, or the model_config is not specified - it will be inferred from the env and other parts of the algorithm - config object. + observation_space, action_space, catalog_class, or the model_config is + not specified it will be inferred from the env and other parts of the + algorithm config object. _enable_rl_module_api: Whether to enable the RLModule API for this config. By default if you call `config.rl_module(...)`, the RLModule API will NOT be enabled. If you want to enable it, you can call diff --git a/rllib/core/models/base.py b/rllib/core/models/base.py index 3768dea5ffe2..6d45a1ea8613 100644 --- a/rllib/core/models/base.py +++ b/rllib/core/models/base.py @@ -36,8 +36,8 @@ class ModelConfig(abc.ABC): output_dims: The output dimensions of the network. 
""" - input_dims: Union[List, Tuple] = None - output_dims: Union[List, Tuple] = None + input_dims: Union[List[int], Tuple[int]] = None + output_dims: Union[List[int], Tuple[int]] = None @abc.abstractmethod def build(self, framework: str): diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py index 11ba79585030..4230d28f5b11 100644 --- a/rllib/core/models/catalog.py +++ b/rllib/core/models/catalog.py @@ -101,7 +101,8 @@ def latent_dims(self): This establishes an agreement between encoder and heads about the latent dimensions. Encoders can be built to output a latent tensor with `latent_dims` dimensions, and heads can be built with tensors of - `latent_dims` dimensions as inputs. + `latent_dims` dimensions as inputs. This can be safely ignored if this + agreement is not needed in case of modifications to the Catalog. Returns: The latent dimensions of the encoder. @@ -159,7 +160,7 @@ def get_action_dist_cls(self, framework: str): """Get the action distribution class. The default behavior is to get the action distribution from the - `Catalog.action_dist_cls_dict`. This can be overridden to build a custom action + `Catalog.action_dist_class_fn`. This can be overridden to build a custom action distribution as a means of configuring the behavior of a PPORLModuleBase implementation. 
diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py index 5efb7a71a018..73668020e581 100644 --- a/rllib/core/models/tests/test_catalog.py +++ b/rllib/core/models/tests/test_catalog.py @@ -1,15 +1,22 @@ import itertools import unittest +import functools +from collections import namedtuple import gymnasium as gym import numpy as np import tree from gymnasium.spaces import Box, Discrete -from collections import namedtuple +from ray.rllib.algorithms.ppo.ppo import PPOConfig +from ray.rllib.core.models.torch.base import TorchModel +from ray.rllib.core.models.base import ModelConfig, Encoder +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, STATE_OUT -from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import MLPEncoderConfig, CNNEncoderConfig +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.models.tf.tf_distributions import ( TfCategorical, @@ -64,9 +71,9 @@ def _check_model_outputs(self, model, framework, model_config_dict, input_space) } outputs = model(inputs) - assert outputs[ENCODER_OUT].shape == (32, latent_dim) + self.assertEqual(outputs[ENCODER_OUT].shape, (32, latent_dim)) tree.map_structure_with_path( - lambda p, v: self.assertTrue(v.shape == states[p].shape), + lambda p, v: self.assertEqual(v.shape, states[p].shape), outputs[STATE_OUT], ) @@ -177,7 +184,7 @@ def test_get_encoder_config(self): model_config = catalog.get_encoder_config( observation_space=input_space, model_config_dict=model_config_dict ) - assert type(model_config) == model_config_type + self.assertEqual(type(model_config), model_config_type) model = model_config.build(framework=framework) # Do a forward pass and check if the output has 
the correct shape @@ -252,9 +259,97 @@ def test_get_dist_cls_from_action_space(self): logits = tf.convert_to_tensor(logits) # We don't need a model if we input tensors dist = dist_cls.from_logits(logits=logits) - assert isinstance(dist, expected_cls_dict[framework]) + self.assertTrue(isinstance(dist, expected_cls_dict[framework])) actions = dist.sample() - assert action_space.contains(actions.numpy()[0]) + self.assertTrue(action_space.contains(actions.numpy()[0])) + + def test_customize_catalog_from_algorithm_config(self): + """Test if we can pass catalog to algorithm config and it ends up inside + RLModule and is used to build models there.""" + + class MyCatalog(PPOCatalog): + def build_vf_head(self, framework): + return torch.nn.Linear(self.latent_dims[0], 1) + + config = ( + PPOConfig() + .rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyCatalog)) + .framework("torch") + ) + + algo = config.build(env="CartPole-v0") + self.assertEqual( + algo.get_policy("default_policy").model.config.catalog_class, MyCatalog + ) + + # Test if we can pass custom catalog to algorithm config and train with it. + + config = ( + PPOConfig() + .rl_module( + rl_module_spec=SingleAgentRLModuleSpec( + module_class=PPOTorchRLModule, catalog_class=MyCatalog + ) + ) + .framework("torch") + ) + + algo = config.build(env="CartPole-v0") + algo.train() + + def test_post_init_overwrite(self): + """Test if we can overwrite post_init method of a catalog class. + + This tests: + - Defines a custom encoder and its config. + - Defines a custom catalog class that uses the custom encoder by + overwriting the __post_init__ method and defining a custom + Catalog.encoder_config. + - Defines a custom RLModule that uses the custom catalog. + - Runs a forward pass through the custom RLModule to check if + everything is working together as expected. 
+ + """ + env = gym.make("CartPole-v0") + + class MyCostumTorchEncoderConfig(ModelConfig): + def build(self, framework): + return MyCostumTorchEncoder() + + class MyCostumTorchEncoder(TorchModel, Encoder): + def __init__(self): + super().__init__({}) + self.net = torch.nn.Linear(env.observation_space.shape[0], 10) + + def _forward(self, input_dict, **kwargs): + return { + ENCODER_OUT: (self.net(input_dict["obs"])), + STATE_OUT: None, + } + + class MyCustomCatalog(PPOCatalog): + def __post_init__(self): + self.action_dist_class_fn = functools.partial( + self.get_dist_cls_from_action_space, action_space=self.action_space + ) + self.latent_dims = (10,) + self.encoder_config = MyCostumTorchEncoderConfig( + input_dims=self.observation_space.shape, + output_dims=self.latent_dims, + ) + + spec = SingleAgentRLModuleSpec( + module_class=PPOTorchRLModule, + observation_space=env.observation_space, + action_space=env.action_space, + model_config_dict=MODEL_DEFAULTS.copy(), + catalog_class=MyCustomCatalog, + ) + module = spec.build() + + module.forward_inference( + input_data={"obs": torch.ones((32, *env.observation_space.shape))} + ) if __name__ == "__main__": diff --git a/rllib/core/models/torch/primitives.py b/rllib/core/models/torch/primitives.py index 134359571115..c7bf9af5fc76 100644 --- a/rllib/core/models/torch/primitives.py +++ b/rllib/core/models/torch/primitives.py @@ -71,7 +71,7 @@ class TorchCNN(nn.Module): def __init__( self, - input_dims: Union[List, Tuple] = None, + input_dims: Union[List[int], Tuple[int]] = None, filter_specifiers: List[List[Union[int, List]]] = None, filter_layer_activation: str = "relu", output_activation: str = "linear", diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py index a2be130b555d..eb5ae972589c 100644 --- a/rllib/core/rl_module/marl_module.py +++ b/rllib/core/rl_module/marl_module.py @@ -413,17 +413,25 @@ def build( return self.marl_module_class(module_config) def add_modules( - self, 
module_specs: Dict[ModuleID, SingleAgentRLModuleSpec] + self, + module_specs: Dict[ModuleID, SingleAgentRLModuleSpec], + overwrite: bool = True, ) -> None: - """Add new module specs to the spec. + """Add new module specs to the spec or update existing ones. Args: module_specs: The mapping for the module_id to the single-agent module specs to be added to this multi-agent module spec. + overwrite: Whether to overwrite the existing module specs if they already + exist. If False, existing specs are updated in place, not replaced. """ if self.module_specs is None: self.module_specs = {} - self.module_specs.update(module_specs) + for module_id, module_spec in module_specs.items(): + if overwrite or module_id not in self.module_specs: + self.module_specs[module_id] = module_spec + else: + self.module_specs[module_id].update(module_spec) @classmethod def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec": @@ -452,6 +460,27 @@ def _check_before_build(self): "SingleAgentRLModuleSpecs for each individual module." ) + def update(self, other: "MultiAgentRLModuleSpec", overwrite=False) -> None: + """Updates this spec with the other spec. + + Traverses this MultiAgentRLModuleSpec's module_specs and updates them with + the module specs from the other MultiAgentRLModuleSpec. + + Args: + other: The other spec to update this spec with. + overwrite: Whether to overwrite the existing module specs if they already + exist. If False, they will be updated only. 
+ """ + assert type(other) is MultiAgentRLModuleSpec + + if isinstance(other.module_specs, dict): + self.add_modules(other.module_specs, overwrite=overwrite) + else: + if not self.module_specs: + self.module_specs = other.module_specs + else: + self.module_specs.update(other.module_specs) + @ExperimentalAPI @dataclass diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 292a3a6b03b6..400c4358fdcc 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -57,6 +57,8 @@ def get_rl_module_config(self) -> "RLModuleConfig": ) def build(self) -> "RLModule": + if self.module_class is None: + raise ValueError("RLModule class is not set.") if self.observation_space is None: raise ValueError("Observation space is not set.") if self.action_space is None: @@ -111,6 +113,19 @@ def from_dict(cls, d): catalog_class=catalog_class, ) + def update(self, other) -> None: + """Updates this spec with the given other spec. Works like dict.update().""" + if not isinstance(other, SingleAgentRLModuleSpec): + raise ValueError("Can only update with another SingleAgentRLModuleSpec.") + + # If the field is None in the other, keep the current field, otherwise update + # with the new value. 
+ self.module_class = other.module_class or self.module_class + self.observation_space = other.observation_space or self.observation_space + self.action_space = other.action_space or self.action_space + self.model_config_dict = other.model_config_dict or self.model_config_dict + self.catalog_class = other.catalog_class or self.catalog_class + @ExperimentalAPI @dataclass diff --git a/rllib/core/rl_module/tests/test_rl_module_specs.py b/rllib/core/rl_module/tests/test_rl_module_specs.py index c55c9ce53fd3..11f94aada99b 100644 --- a/rllib/core/rl_module/tests/test_rl_module_specs.py +++ b/rllib/core/rl_module/tests/test_rl_module_specs.py @@ -153,6 +153,121 @@ def test_get_spec_from_module_single_agent(self): spec_from_module = SingleAgentRLModuleSpec.from_module(module) self.assertEqual(spec, spec_from_module) + def test_update_specs(self): + """Tests whether SingleAgentRLModuleSpec.update() works.""" + env = gym.make("CartPole-v0") + + # Test if SingleAgentRLModuleSpec.update() works. + module_spec_1 = SingleAgentRLModuleSpec( + module_class=DiscreteBCTorchModule, + observation_space=env.observation_space, + action_space=env.action_space, + model_config_dict="Update me!", + ) + module_spec_2 = SingleAgentRLModuleSpec( + model_config_dict={"fcnet_hiddens": [32]} + ) + self.assertEqual(module_spec_1.model_config_dict, "Update me!") + module_spec_1.update(module_spec_2) + self.assertEqual(module_spec_1.model_config_dict, {"fcnet_hiddens": [32]}) + + def test_update_specs_multi_agent(self): + """Test if updating a SingleAgentRLModuleSpec in MultiAgentRLModuleSpec works. + + This tests if we can update a `model_config_dict` field through different + kinds of updates: + - Create a SingleAgentRLModuleSpec and update its model_config_dict. + - Create two MultiAgentRLModuleSpecs and update the first one with the + second one without overwriting it. + - Check if the updated MultiAgentRLModuleSpec does not(!) have the + updated model_config_dict. 
+ - Create two MultiAgentRLModuleSpecs and update the first one with the + second one with overwriting it. + - Check if the updated MultiAgentRLModuleSpec has(!) the updated + model_config_dict. + + """ + env = gym.make("CartPole-v0") + + # Test if SingleAgentRLModuleSpec.update() works. + module_spec_1 = SingleAgentRLModuleSpec( + module_class=DiscreteBCTorchModule, + observation_space="Do not update me!", + action_space=env.action_space, + model_config_dict="Update me!", + ) + module_spec_2 = SingleAgentRLModuleSpec( + model_config_dict={"fcnet_hiddens": [32]}, + ) + + self.assertEqual(module_spec_1.model_config_dict, "Update me!") + module_spec_1.update(module_spec_2) + self.assertEqual(module_spec_1.module_class, DiscreteBCTorchModule) + self.assertEqual(module_spec_1.observation_space, "Do not update me!") + self.assertEqual(module_spec_1.action_space, env.action_space) + self.assertEqual( + module_spec_1.model_config_dict, module_spec_2.model_config_dict + ) + + # Redefine module_spec_1 for following tests. + module_spec_1 = SingleAgentRLModuleSpec( + module_class=DiscreteBCTorchModule, + observation_space="Do not update me!", + action_space=env.action_space, + model_config_dict="Update me!", + ) + + marl_spec_1 = MultiAgentRLModuleSpec( + marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + module_specs={"agent_1": module_spec_1}, + ) + marl_spec_2 = MultiAgentRLModuleSpec( + marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + module_specs={"agent_1": module_spec_2}, + ) + + # Test if updating MultiAgentRLModuleSpec with overwriting works. This means + # that the single agent specs should be overwritten + self.assertEqual( + marl_spec_1.module_specs["agent_1"].model_config_dict, "Update me!" + ) + marl_spec_1.update(marl_spec_2, overwrite=True) + self.assertEqual(marl_spec_1.module_specs["agent_1"], module_spec_2) + + # Test if updating MultiAgentRLModuleSpec without overwriting works. 
This + # means that the single agent specs should not be overwritten + marl_spec_3 = MultiAgentRLModuleSpec( + marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + module_specs={"agent_1": module_spec_1}, + ) + + self.assertEqual( + marl_spec_3.module_specs["agent_1"].observation_space, "Do not update me!" + ) + marl_spec_3.update(marl_spec_2, overwrite=False) + # If we would overwrite, we would replace the observation space even though + # it was None. This is not the case here. + self.assertEqual( + marl_spec_3.module_specs["agent_1"].observation_space, "Do not update me!" + ) + + # Test if updating with an additional SingleAgentRLModuleSpec works. + module_spec_3 = SingleAgentRLModuleSpec( + module_class=DiscreteBCTorchModule, + observation_space=env.observation_space, + action_space=env.action_space, + model_config_dict="I'm new!", + ) + marl_spec_3 = MultiAgentRLModuleSpec( + marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + module_specs={"agent_2": module_spec_3}, + ) + self.assertEqual(marl_spec_1.module_specs.get("agent_2"), None) + marl_spec_1.update(marl_spec_3) + self.assertEqual( + marl_spec_1.module_specs["agent_2"].model_config_dict, "I'm new!" + ) + if __name__ == "__main__": import pytest