Skip to content

Commit

Permalink
[RLlib] Fix ConnectorPipelineV2 restoring from checkpoint (by writing information about individual connector pieces to the `ctor_args_and_kwargs` file). (ray-project#48213)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Oct 27, 2024
1 parent 0b1d0d8 commit 6878aa1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
24 changes: 23 additions & 1 deletion rllib/connectors/connector_pipeline_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,18 @@ def __init__(
pipeline during construction. Note that you can always add (or remove)
more ConnectorV2 pieces later on the fly.
"""
self.connectors = connectors or []
self.connectors = []

for conn in connectors:
# If we have a `ConnectorV2` instance just append.
if isinstance(conn, ConnectorV2):
self.connectors.append(conn)
# If we have a class with `args` and `kwargs`, build the instance.
# Note that this way of constructing a pipeline should only be
# used internally when restoring the pipeline state from a
# checkpoint.
elif isinstance(conn, tuple) and len(conn) == 3:
self.connectors.append(conn[0](*conn[1], **conn[2]))

super().__init__(input_observation_space, input_action_space, **kwargs)

Expand Down Expand Up @@ -266,6 +277,17 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
# don't have to return the `connectors` c'tor kwarg from there. This is b/c all
# connector pieces in this pipeline are themselves Checkpointable components,
# so they will be properly written into this pipeline's checkpoint.
@override(Checkpointable)
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
    """Return the positional and keyword c'tor arguments of this pipeline.

    Returns:
        A 2-tuple `(args, kwargs)`. `args` holds the input observation space
        and the input action space. `kwargs` holds a single `connectors`
        entry: a list with one tuple per connector piece, made of the
        piece's class followed by whatever that piece's own
        `get_ctor_args_and_kwargs()` returns.
    """
    ctor_args = (self.input_observation_space, self.input_action_space)

    # Describe each connector piece by its class plus its own c'tor args and
    # kwargs, keeping the order in which the pieces appear in the pipeline.
    connector_specs = []
    for piece in self.connectors:
        connector_specs.append((type(piece), *piece.get_ctor_args_and_kwargs()))

    ctor_kwargs = {"connectors": connector_specs}
    return ctor_args, ctor_kwargs

@override(ConnectorV2)
def reset_state(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion rllib/connectors/connector_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
self._action_space = None
self._input_observation_space = None
self._input_action_space = None
self._kwargs = kwargs

self.input_action_space = input_action_space
self.input_observation_space = input_observation_space
Expand Down Expand Up @@ -949,7 +950,7 @@ def set_state(self, state: StateDict) -> None:
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
return (
(self.input_observation_space, self.input_action_space), # *args
{}, # **kwargs
self._kwargs, # **kwargs
)

def reset_state(self) -> None:
Expand Down
1 change: 0 additions & 1 deletion rllib/tuned_examples/ppo/cartpole_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
)


if __name__ == "__main__":
from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

Expand Down

0 comments on commit 6878aa1

Please sign in to comment.