diff --git a/rllib/BUILD b/rllib/BUILD
index 69d187b6a8fe..3806403a56f9 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1848,7 +1848,7 @@ py_test(
 py_test(
     name = "test_catalog",
     tags = ["team:rllib", "core"],
-    size = "small",
+    size = "medium",
     srcs = ["core/models/tests/test_catalog.py"]
 )
 
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 9fcdcdd30ab0..0888cb428950 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -116,6 +116,15 @@ def _resolve_class_path(module) -> Type:
         return getattr(module, class_name)
 
 
+def _check_rl_module_spec(module_spec: ModuleSpec) -> None:
+    if not isinstance(module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)):
+        raise ValueError(
+            "rl_module_spec must be an instance of "
+            "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec."
+            f"Got {type(module_spec)} instead."
+        )
+
+
 class AlgorithmConfig(_Config):
     """A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.
 
@@ -901,17 +910,26 @@ def validate(self) -> None:
                 # compatibility for now. User only needs to set num_rollout_workers.
                 self.input_config["parallelism"] = self.num_rollout_workers or 1
 
-        # resolve rl_module_spec class
-        if self._enable_rl_module_api and self.rl_module_spec is None:
-            self.rl_module_spec = self.get_default_rl_module_spec()
-            if not isinstance(
-                self.rl_module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)
-            ):
-                raise ValueError(
-                    "rl_module_spec must be an instance of "
-                    "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec."
-                    f"Got {type(self.rl_module_spec)} instead."
-                )
+        if self._enable_rl_module_api:
+            default_rl_module_spec = self.get_default_rl_module_spec()
+            _check_rl_module_spec(default_rl_module_spec)
+
+            if self.rl_module_spec is not None:
+                # Merge the provided RLModule spec with the default spec.
+                _check_rl_module_spec(self.rl_module_spec)
+                # We can only merge if we have SingleAgentRLModuleSpecs.
+                # TODO(Artur): Support merging for MultiAgentRLModuleSpecs.
+                if isinstance(self.rl_module_spec, SingleAgentRLModuleSpec):
+                    if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec):
+                        default_rl_module_spec.update(self.rl_module_spec)
+                        self.rl_module_spec = default_rl_module_spec
+                    elif isinstance(default_rl_module_spec, MultiAgentRLModuleSpec):
+                        raise ValueError(
+                            "Cannot merge MultiAgentRLModuleSpec with "
+                            "SingleAgentRLModuleSpec!"
+                        )
+            else:
+                self.rl_module_spec = default_rl_module_spec
 
         # make sure the resource requirements for learner_group is valid
         if self.num_learner_workers == 0 and self.num_gpus_per_worker > 1:
@@ -2283,9 +2301,9 @@ def rl_module(
         Args:
             rl_module_spec: The RLModule spec to use for this config. It can be either
                 a SingleAgentRLModuleSpec or a MultiAgentRLModuleSpec. If the
-                observation_space, action_space, or the model_config is not specified
-                it will be inferred from the env and other parts of the algorithm
-                config object.
+                observation_space, action_space, catalog_class, or model_config is
+                not specified, it will be inferred from the env and other parts of
+                the algorithm config object.
             _enable_rl_module_api: Whether to enable the RLModule API for this config.
                 By default if you call `config.rl_module(...)`, the
                 RLModule API will NOT be enabled. If you want to enable it, you can call
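
The merge logic above lets users hand in a deliberately sparse spec and have
`validate()` fill in the rest from the algorithm's defaults. A minimal sketch of
the resulting workflow, mirroring the test added in this PR (`MyCatalog` is an
illustrative subclass, not part of the PR):

    from ray.rllib.algorithms.ppo.ppo import PPOConfig
    from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


    class MyCatalog(PPOCatalog):
        pass  # Override build_* methods here to customize models.


    config = (
        PPOConfig()
        # Only catalog_class is set; module_class, spaces, and model_config_dict
        # are merged in from get_default_rl_module_spec() during validation.
        .rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyCatalog))
        .framework("torch")
    )
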
diff --git a/rllib/core/models/base.py b/rllib/core/models/base.py
index 3768dea5ffe2..6d45a1ea8613 100644
--- a/rllib/core/models/base.py
+++ b/rllib/core/models/base.py
@@ -36,8 +36,8 @@ class ModelConfig(abc.ABC):
         output_dims: The output dimensions of the network.
     """
 
-    input_dims: Union[List, Tuple] = None
-    output_dims: Union[List, Tuple] = None
+    input_dims: Union[List[int], Tuple[int, ...]] = None
+    output_dims: Union[List[int], Tuple[int, ...]] = None
 
     @abc.abstractmethod
     def build(self, framework: str):
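
For context on the annotation change: `input_dims`/`output_dims` are shape
tuples of arbitrary rank, so the variadic tuple form is what's meant here. A
hypothetical illustration (not RLlib code):

    from typing import Tuple

    vector_dims: Tuple[int, ...] = (4,)        # e.g. a CartPole observation
    image_dims: Tuple[int, ...] = (84, 84, 3)  # e.g. an RGB image observation
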
diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py
index 11ba79585030..4230d28f5b11 100644
--- a/rllib/core/models/catalog.py
+++ b/rllib/core/models/catalog.py
@@ -101,7 +101,8 @@ def latent_dims(self):
         This establishes an agreement between encoder and heads about the latent
         dimensions. Encoders can be built to output a latent tensor with
         `latent_dims` dimensions, and heads can be built with tensors of
-        `latent_dims` dimensions as inputs.
+        `latent_dims` dimensions as inputs. If you modify the Catalog such that
+        this agreement is not needed, you can safely ignore this property.
 
         Returns:
             The latent dimensions of the encoder.
@@ -159,7 +160,7 @@ def get_action_dist_cls(self, framework: str):
         """Get the action distribution class.
 
         The default behavior is to get the action distribution from the
-        `Catalog.action_dist_cls_dict`. This can be overridden to build a custom action
+        `Catalog.action_dist_class_fn`. This can be overridden to build a custom action
         distribution as a means of configuring the behavior of a PPORLModuleBase
         implementation.
 
diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py
index 5efb7a71a018..73668020e581 100644
--- a/rllib/core/models/tests/test_catalog.py
+++ b/rllib/core/models/tests/test_catalog.py
@@ -1,15 +1,22 @@
 import itertools
 import unittest
+import functools
+from collections import namedtuple
 
 import gymnasium as gym
 import numpy as np
 import tree
 from gymnasium.spaces import Box, Discrete
-from collections import namedtuple
 
+from ray.rllib.algorithms.ppo.ppo import PPOConfig
+from ray.rllib.core.models.torch.base import TorchModel
+from ray.rllib.core.models.base import ModelConfig, Encoder
+from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
+from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
 from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, STATE_OUT
-from ray.rllib.core.models.catalog import Catalog
 from ray.rllib.core.models.configs import MLPEncoderConfig, CNNEncoderConfig
+from ray.rllib.core.models.catalog import Catalog
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
 from ray.rllib.models import MODEL_DEFAULTS
 from ray.rllib.models.tf.tf_distributions import (
     TfCategorical,
@@ -64,9 +71,9 @@ def _check_model_outputs(self, model, framework, model_config_dict, input_space)
         }
         outputs = model(inputs)
 
-        assert outputs[ENCODER_OUT].shape == (32, latent_dim)
+        self.assertEqual(outputs[ENCODER_OUT].shape, (32, latent_dim))
         tree.map_structure_with_path(
-            lambda p, v: self.assertTrue(v.shape == states[p].shape),
+            lambda p, v: self.assertEqual(v.shape, states[p].shape),
             outputs[STATE_OUT],
         )
 
@@ -177,7 +184,7 @@ def test_get_encoder_config(self):
             model_config = catalog.get_encoder_config(
                 observation_space=input_space, model_config_dict=model_config_dict
             )
-            assert type(model_config) == model_config_type
+            self.assertEqual(type(model_config), model_config_type)
             model = model_config.build(framework=framework)
 
             # Do a forward pass and check if the output has the correct shape
@@ -252,9 +259,97 @@ def test_get_dist_cls_from_action_space(self):
                     logits = tf.convert_to_tensor(logits)
                 # We don't need a model if we input tensors
                 dist = dist_cls.from_logits(logits=logits)
-                assert isinstance(dist, expected_cls_dict[framework])
+                self.assertIsInstance(dist, expected_cls_dict[framework])
                 actions = dist.sample()
-                assert action_space.contains(actions.numpy()[0])
+                self.assertTrue(action_space.contains(actions.numpy()[0]))
+
+    def test_customize_catalog_from_algorithm_config(self):
+        """Test if we can pass catalog to algorithm config and it ends up inside
+        RLModule and is used to build models there."""
+
+        class MyCatalog(PPOCatalog):
+            def build_vf_head(self, framework):
+                return torch.nn.Linear(self.latent_dims[0], 1)
+
+        config = (
+            PPOConfig()
+            .rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyCatalog))
+            .framework("torch")
+        )
+
+        algo = config.build(env="CartPole-v0")
+        self.assertEqual(
+            algo.get_policy("default_policy").model.config.catalog_class, MyCatalog
+        )
+
+        # Test if we can pass a custom catalog to the algorithm config and train.
+
+        config = (
+            PPOConfig()
+            .rl_module(
+                rl_module_spec=SingleAgentRLModuleSpec(
+                    module_class=PPOTorchRLModule, catalog_class=MyCatalog
+                )
+            )
+            .framework("torch")
+        )
+
+        algo = config.build(env="CartPole-v0")
+        algo.train()
+
+    def test_post_init_overwrite(self):
+        """Test if we can overwrite post_init method of a catalog class.
+
+        This tests:
+            - Defines a custom encoder and its config.
+            - Defines a custom catalog class that uses the custom encoder by
+                overwriting the __post_init__ method and defining a custom
+                Catalog.encoder_config.
+            - Defines a custom RLModule that uses the custom catalog.
+            - Runs a forward pass through the custom RLModule to check if
+                everything is working together as expected.
+
+        """
+        env = gym.make("CartPole-v0")
+
+        class MyCustomTorchEncoderConfig(ModelConfig):
+            def build(self, framework):
+                return MyCustomTorchEncoder()
+
+        class MyCustomTorchEncoder(TorchModel, Encoder):
+            def __init__(self):
+                super().__init__({})
+                self.net = torch.nn.Linear(env.observation_space.shape[0], 10)
+
+            def _forward(self, input_dict, **kwargs):
+                return {
+                    ENCODER_OUT: (self.net(input_dict["obs"])),
+                    STATE_OUT: None,
+                }
+
+        class MyCustomCatalog(PPOCatalog):
+            def __post_init__(self):
+                self.action_dist_class_fn = functools.partial(
+                    self.get_dist_cls_from_action_space, action_space=self.action_space
+                )
+                self.latent_dims = (10,)
+                self.encoder_config = MyCustomTorchEncoderConfig(
+                    input_dims=self.observation_space.shape,
+                    output_dims=self.latent_dims,
+                )
+
+        spec = SingleAgentRLModuleSpec(
+            module_class=PPOTorchRLModule,
+            observation_space=env.observation_space,
+            action_space=env.action_space,
+            model_config_dict=MODEL_DEFAULTS.copy(),
+            catalog_class=MyCustomCatalog,
+        )
+        module = spec.build()
+
+        module.forward_inference(
+            input_data={"obs": torch.ones((32, *env.observation_space.shape))}
+        )
 
 
 if __name__ == "__main__":
diff --git a/rllib/core/models/torch/primitives.py b/rllib/core/models/torch/primitives.py
index 134359571115..c7bf9af5fc76 100644
--- a/rllib/core/models/torch/primitives.py
+++ b/rllib/core/models/torch/primitives.py
@@ -71,7 +71,7 @@ class TorchCNN(nn.Module):
 
     def __init__(
         self,
-        input_dims: Union[List, Tuple] = None,
+        input_dims: Union[List[int], Tuple[int, ...]] = None,
         filter_specifiers: List[List[Union[int, List]]] = None,
         filter_layer_activation: str = "relu",
         output_activation: str = "linear",
diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py
index a2be130b555d..eb5ae972589c 100644
--- a/rllib/core/rl_module/marl_module.py
+++ b/rllib/core/rl_module/marl_module.py
@@ -413,17 +413,25 @@ def build(
         return self.marl_module_class(module_config)
 
     def add_modules(
-        self, module_specs: Dict[ModuleID, SingleAgentRLModuleSpec]
+        self,
+        module_specs: Dict[ModuleID, SingleAgentRLModuleSpec],
+        overwrite: bool = True,
     ) -> None:
-        """Add new module specs to the spec.
+        """Add new module specs to the spec or updates existing ones.
 
         Args:
             module_specs: The mapping for the module_id to the single-agent module
                 specs to be added to this multi-agent module spec.
+            overwrite: Whether to completely replace module specs that already
+                exist. If False, existing specs are merged via their update() method.
         """
         if self.module_specs is None:
             self.module_specs = {}
-        self.module_specs.update(module_specs)
+        for module_id, module_spec in module_specs.items():
+            if overwrite or module_id not in self.module_specs:
+                self.module_specs[module_id] = module_spec
+            else:
+                self.module_specs[module_id].update(module_spec)
 
     @classmethod
     def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec":
@@ -452,6 +460,27 @@ def _check_before_build(self):
                 "SingleAgentRLModuleSpecs for each individual module."
             )
 
+    def update(self, other: "MultiAgentRLModuleSpec", overwrite=False) -> None:
+        """Updates this spec with the other spec.
+
+        Traverses this MultiAgentRLModuleSpec's module_specs and updates them with
+        the module specs from the other MultiAgentRLModuleSpec.
+
+        Args:
+            other: The other spec to update this spec with.
+            overwrite: Whether to completely replace module specs that already
+                exist. If False, existing specs are merged via their update() method.
+        """
+        assert type(other) is MultiAgentRLModuleSpec
+
+        if isinstance(other.module_specs, dict):
+            self.add_modules(other.module_specs, overwrite=overwrite)
+        else:
+            if not self.module_specs:
+                self.module_specs = other.module_specs
+            else:
+                self.module_specs.update(other.module_specs)
+
 
 @ExperimentalAPI
 @dataclass
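
The two update modes in a nutshell; a sketch assuming `MultiAgentRLModuleSpec`
can be constructed from `module_specs` alone, with other fields left at their
defaults:

    from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

    base = MultiAgentRLModuleSpec(
        module_specs={
            "agent_1": SingleAgentRLModuleSpec(model_config_dict={"fcnet_hiddens": [64]})
        }
    )
    patch = MultiAgentRLModuleSpec(
        module_specs={
            "agent_1": SingleAgentRLModuleSpec(model_config_dict={"fcnet_hiddens": [32]})
        }
    )
    # overwrite=False merges field by field: only non-None fields of the incoming
    # single-agent spec replace existing values, so e.g. spaces set on `base`
    # would survive. overwrite=True would replace the stored spec wholesale.
    base.update(patch, overwrite=False)
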
diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py
index 292a3a6b03b6..400c4358fdcc 100644
--- a/rllib/core/rl_module/rl_module.py
+++ b/rllib/core/rl_module/rl_module.py
@@ -57,6 +57,8 @@ def get_rl_module_config(self) -> "RLModuleConfig":
         )
 
     def build(self) -> "RLModule":
+        if self.module_class is None:
+            raise ValueError("RLModule class is not set.")
         if self.observation_space is None:
             raise ValueError("Observation space is not set.")
         if self.action_space is None:
@@ -111,6 +113,19 @@ def from_dict(cls, d):
             catalog_class=catalog_class,
         )
 
+    def update(self, other: "SingleAgentRLModuleSpec") -> None:
+        """Updates this spec with the given other spec. Works like dict.update()."""
+        if not isinstance(other, SingleAgentRLModuleSpec):
+            raise ValueError("Can only update with another SingleAgentRLModuleSpec.")
+
+        # If the field is None in the other, keep the current field, otherwise update
+        # with the new value.
+        self.module_class = other.module_class or self.module_class
+        self.observation_space = other.observation_space or self.observation_space
+        self.action_space = other.action_space or self.action_space
+        self.model_config_dict = other.model_config_dict or self.model_config_dict
+        self.catalog_class = other.catalog_class or self.catalog_class
+
 
 @ExperimentalAPI
 @dataclass
diff --git a/rllib/core/rl_module/tests/test_rl_module_specs.py b/rllib/core/rl_module/tests/test_rl_module_specs.py
index c55c9ce53fd3..11f94aada99b 100644
--- a/rllib/core/rl_module/tests/test_rl_module_specs.py
+++ b/rllib/core/rl_module/tests/test_rl_module_specs.py
@@ -153,6 +153,121 @@ def test_get_spec_from_module_single_agent(self):
             spec_from_module = SingleAgentRLModuleSpec.from_module(module)
             self.assertEqual(spec, spec_from_module)
 
+    def test_update_specs(self):
+        """Tests wether SingleAgentRLModuleSpec.update() works."""
+        env = gym.make("CartPole-v0")
+
+        # Test if SingleAgentRLModuleSpec.update() works.
+        module_spec_1 = SingleAgentRLModuleSpec(
+            module_class=DiscreteBCTorchModule,
+            observation_space=env.observation_space,
+            action_space=env.action_space,
+            model_config_dict="Update me!",
+        )
+        module_spec_2 = SingleAgentRLModuleSpec(
+            model_config_dict={"fcnet_hiddens": [32]}
+        )
+        self.assertEqual(module_spec_1.model_config_dict, "Update me!")
+        module_spec_1.update(module_spec_2)
+        self.assertEqual(module_spec_1.model_config_dict, {"fcnet_hiddens": [32]})
+
+    def test_update_specs_multi_agent(self):
+        """Test if updating a SingleAgentRLModuleSpec in MultiAgentRLModuleSpec works.
+
+        This tests if we can update the fields of specs through different kinds
+        of updates:
+            - Create a SingleAgentRLModuleSpec and update its model_config_dict.
+            - Create two MultiAgentRLModuleSpecs and update the first one with
+                the second one with overwriting. Check that the wrapped
+                single-agent spec was completely replaced.
+            - Update the first one with the second one without overwriting.
+                Check that fields that are None in the incoming spec (here the
+                observation_space) are not(!) overwritten.
+            - Update with a spec that holds a new module_id and check that the
+                new single-agent spec was added to the multi-agent spec.
+
+        """
+        env = gym.make("CartPole-v0")
+
+        # Test if SingleAgentRLModuleSpec.update() works.
+        module_spec_1 = SingleAgentRLModuleSpec(
+            module_class=DiscreteBCTorchModule,
+            observation_space="Do not update me!",
+            action_space=env.action_space,
+            model_config_dict="Update me!",
+        )
+        module_spec_2 = SingleAgentRLModuleSpec(
+            model_config_dict={"fcnet_hiddens": [32]},
+        )
+
+        self.assertEqual(module_spec_1.model_config_dict, "Update me!")
+        module_spec_1.update(module_spec_2)
+        self.assertEqual(module_spec_1.module_class, DiscreteBCTorchModule)
+        self.assertEqual(module_spec_1.observation_space, "Do not update me!")
+        self.assertEqual(module_spec_1.action_space, env.action_space)
+        self.assertEqual(
+            module_spec_1.model_config_dict, module_spec_2.model_config_dict
+        )
+
+        # Redefine module_spec_1 for the following tests.
+        module_spec_1 = SingleAgentRLModuleSpec(
+            module_class=DiscreteBCTorchModule,
+            observation_space="Do not update me!",
+            action_space=env.action_space,
+            model_config_dict="Update me!",
+        )
+
+        marl_spec_1 = MultiAgentRLModuleSpec(
+            marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder,
+            module_specs={"agent_1": module_spec_1},
+        )
+        marl_spec_2 = MultiAgentRLModuleSpec(
+            marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder,
+            module_specs={"agent_1": module_spec_2},
+        )
+
+        # Test if updating MultiAgentRLModuleSpec with overwriting works. This means
+        # that the single-agent specs should be completely replaced.
+        self.assertEqual(
+            marl_spec_1.module_specs["agent_1"].model_config_dict, "Update me!"
+        )
+        marl_spec_1.update(marl_spec_2, overwrite=True)
+        self.assertEqual(marl_spec_1.module_specs["agent_1"], module_spec_2)
+
+        # Test if updating MultiAgentRLModuleSpec without overwriting works. This
+        # means that the single-agent specs should be merged, not replaced.
+        marl_spec_3 = MultiAgentRLModuleSpec(
+            marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder,
+            module_specs={"agent_1": module_spec_1},
+        )
+
+        self.assertEqual(
+            marl_spec_3.module_specs["agent_1"].observation_space, "Do not update me!"
+        )
+        marl_spec_3.update(marl_spec_2, overwrite=False)
+        # Had we overwritten, the observation space would have been replaced by
+        # None from the incoming spec. The merging update preserves it instead.
+        self.assertEqual(
+            marl_spec_3.module_specs["agent_1"].observation_space, "Do not update me!"
+        )
+
+        # Test if updating with an additional SingleAgentRLModuleSpec works.
+        module_spec_3 = SingleAgentRLModuleSpec(
+            module_class=DiscreteBCTorchModule,
+            observation_space=env.observation_space,
+            action_space=env.action_space,
+            model_config_dict="I'm new!",
+        )
+        marl_spec_3 = MultiAgentRLModuleSpec(
+            marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder,
+            module_specs={"agent_2": module_spec_3},
+        )
+        self.assertEqual(marl_spec_1.module_specs.get("agent_2"), None)
+        marl_spec_1.update(marl_spec_3)
+        self.assertEqual(
+            marl_spec_1.module_specs["agent_2"].model_config_dict, "I'm new!"
+        )
+
 
 if __name__ == "__main__":
     import pytest