forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregistry.py
46 lines (34 loc) · 1.3 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Registry of connector names for global access."""
from typing import Any
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.connectors.connector import Connector, ConnectorContext
# Module-level registry mapping a connector name to its registered connector
# class. Written by register_connector() and read by get_connector().
ALL_CONNECTORS = {}
@OldAPIStack
def register_connector(name: str, cls: Connector):
    """Register a connector class under a name for use with RLlib.

    If `name` is already registered, the call is a no-op: the existing entry
    is kept and `cls` is not validated.

    Args:
        name: Name to register the connector class under.
        cls: The connector class to register. Must be a subclass of
            `Connector` (the class itself, not an instance).

    Raises:
        TypeError: If `cls` is not a class or is not a subclass of
            `Connector`.
    """
    if name in ALL_CONNECTORS:
        return

    # Check `isinstance(cls, type)` first: `issubclass` itself raises a
    # generic "arg 1 must be a class" TypeError for non-class arguments
    # (e.g. a Connector instance), which would hide the message below.
    if not (isinstance(cls, type) and issubclass(cls, Connector)):
        raise TypeError("Can only register Connector type.", cls)

    # Record it in local registry in case we need to register everything
    # again in the global registry, for example in the event of cluster
    # restarts.
    ALL_CONNECTORS[name] = cls
@OldAPIStack
def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
    """Build a connector from its registered name and serialized parameters.

    Looks up the class registered under `name` in `ALL_CONNECTORS` and
    delegates construction to that class's `from_state(ctx, params)`.

    Args:
        name: Name the connector class was registered under.
        ctx: Connector context, passed through to `from_state`.
        params: Serialized parameters of the connector, passed through to
            `from_state`.

    Returns:
        Whatever the registered class's `from_state(ctx, params)` returns.

    Raises:
        NameError: If no connector is registered under `name`.
    """
    # TODO(jungong) : switch the order of parameters man!!
    if name in ALL_CONNECTORS:
        connector_cls = ALL_CONNECTORS[name]
        return connector_cls.from_state(ctx, params)

    raise NameError("connector not found.", name)