Commit

in progress: pb-rollout-worker in test
Ming Zhou committed Nov 27, 2023
1 parent 207bcae commit 40c5074
Showing 17 changed files with 170 additions and 189 deletions.
98 changes: 0 additions & 98 deletions examples/run_gym.py

This file was deleted.

6 changes: 4 additions & 2 deletions examples/run_psro.py
@@ -26,7 +26,7 @@
import os
import time

from malib.runner import run
from malib.scenarios import psro_scenario
from malib.learner import IndependentAgent
from malib.scenarios.psro_scenario import PSROScenario
from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG
@@ -99,4 +99,6 @@
},
)

run(scenario)
results = psro_scenario.execution_plan(scenario=scenario, verbose=True)

print(results)
39 changes: 35 additions & 4 deletions examples/sarl/ppo_gym.py
@@ -3,13 +3,44 @@

from argparse import ArgumentParser

from gym import spaces

import numpy as np

from malib.utils.episode import Episode
from malib.learner import IndependentAgent
from malib.scenarios import sarl_scenario
from malib.rl.config import Algorithm
from malib.rl.ppo import PPOPolicy, PPOTrainer, DEFAULT_CONFIG
from malib.learner.config import LearnerConfig
from malib.rollout.config import RolloutConfig
from malib.rollout.envs.gym import env_desc_gen
from malib.backend.dataset_server.feature import BaseFeature


class FeatureHandler(BaseFeature):
pass


def feature_handler_meta_gen(env_desc, agent_id):
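# f(device) builds the feature handler that defines this agent's replay-buffer schema on the given device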
def f(device):
# define the data schema
_spaces = {
Episode.DONE: spaces.Discrete(1),
Episode.CUR_OBS: env_desc["observation_spaces"][agent_id],
Episode.ACTION: env_desc["action_spaces"][agent_id],
Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32),
Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id],
}

# the maximum replay-buffer size must be known before training (100 entries here)
np_memory = {
k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
}

return FeatureHandler(_spaces, np_memory, device)

return f


if __name__ == "__main__":
@@ -43,7 +74,7 @@
),
learner_config=LearnerConfig(
learner_type=IndependentAgent,
feature_handler_meta_gen=None,
feature_handler_meta_gen=feature_handler_meta_gen,
custom_config={},
),
rollout_config=RolloutConfig(
@@ -56,6 +87,6 @@
},
)

results = sarl_scenario.execution_plan(
experiment_tag=scenario.name, scenario=scenario, verbose=True
)
results = sarl_scenario.execution_plan(scenario=scenario, verbose=True)

print(results)
6 changes: 4 additions & 2 deletions malib/backend/dataset_server/data_loader.py
@@ -28,8 +28,10 @@ def __init__(
super().__init__()

# start a service as thread
self.feature_handler: BaseFeature = feature_handler or feature_handler_cls(
**feature_handler_kwargs
self.feature_handler: BaseFeature = (
feature_handler
if feature_handler is not None
else feature_handler_cls(**feature_handler_kwargs)
)
self.grpc_thread_num_workers = grpc_thread_num_workers
self.max_message_length = max_message_length
2 changes: 2 additions & 0 deletions malib/learner/config.py
@@ -11,6 +11,8 @@
class LearnerConfig:
learner_type: Type[Learner]
feature_handler_meta_gen: Callable[["EnvDesc", str], Callable[[str], BaseFeature]]
"""what is it?"""

custom_config: Dict[str, Any] = field(default_factory=dict)

@classmethod
5 changes: 4 additions & 1 deletion malib/learner/learner.py
@@ -114,6 +114,10 @@ def __init__(
algorithm.trainer_config, self._policy
)

# the trainer may fill in defaults when it receives its training config,
# so propagate the updated config back to the algorithm
algorithm.trainer_config = self._trainer.training_config

if dataset is None:
dataset = DynamicDataset(
grpc_thread_num_workers=2,
@@ -126,7 +130,6 @@ def __init__(
dataset.feature_handler = feature_handler_gen(device)

dataset.start_server()

self._data_loader = DataLoader(
dataset, batch_size=algorithm.trainer_config["batch_size"]
)
12 changes: 4 additions & 8 deletions malib/learner/manager.py
@@ -105,9 +105,9 @@ def __init__(
learner_cls = learner_cls.as_remote(**resource_config)
learners: Dict[str, ray.ObjectRef] = {}

assert (
"training" in stopping_conditions
), f"Stopping conditions should contain `training` stopping conditions: {stopping_conditions}"
# assert (
# "training" in stopping_conditions
# ), f"Stopping conditions should contain `training` stopping conditions: {stopping_conditions}"

ready_check = []

@@ -123,7 +123,6 @@ def __init__(
algorithm=algorithm,
agent_mapping_func=agent_mapping_func,
governed_agents=agents,
trainer_config=algorithm.trainer_config,
custom_config=learner_config.custom_config,
feature_handler_gen=learner_config.feature_handler_meta_gen(
env_desc, agents[0]
@@ -236,10 +235,7 @@ def add_policies(
policy_nums = dict.fromkeys(interface_ids, n) if isinstance(n, int) else n

strategy_spec_list: List[StrategySpec] = ray.get(
[
self._learners[k].add_policies.remote(n=policy_nums[k])
for k in interface_ids
]
[self._learners[k].get_strategy_spec.remote() for k in interface_ids]
)
strategy_spec_dict: Dict[str, StrategySpec] = dict(
zip(interface_ids, strategy_spec_list)
6 changes: 4 additions & 2 deletions malib/models/model_client.py
@@ -1,4 +1,4 @@
from typing import Dict, Any
from typing import Dict, Any, Union
from concurrent import futures

import threading
@@ -19,7 +19,9 @@ def load_state_dict(client, timeout=10):


class ModelClient:
def __init__(self, entry_point: str, model_config: ModelConfig):
def __init__(
self, entry_point: str, model_config: Union[ModelConfig, Dict[str, Any]]
):
"""Construct a model client for mantaining a model instance and its update.
Args:
23 changes: 14 additions & 9 deletions malib/rl/common/policy.py
@@ -108,7 +108,10 @@ def __init__(
self._model = kwargs.get("model_client")
if self._model is None:
if kwargs.get("model_entry_point"):
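# pass the policy's own model factory through ModelConfig so the model client can build the network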
self._model = ModelClient(kwargs["model_entry_point"], model_config)
self._model = ModelClient(
kwargs["model_entry_point"],
ModelConfig(lambda **x: self.create_model(), model_config),
)
else:
self._model = self.create_model().to(self._device)

@@ -147,7 +150,7 @@ def preprocessor(self) -> Preprocessor:
return self._preprocessor

@property
def device(self) -> str:
def device(self) -> torch.device:
return self._device

@property
@@ -186,7 +189,7 @@ def state_dict(
res = self.model.state_dict()
else:
res = {}
for k, v in self.model.state_dict():
for k, v in self.model.state_dict().items():
res[k] = v.to(device)

return res
@@ -249,16 +252,18 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy":
Policy: A policy instance
"""

if isinstance(device, torch.device):
device = device.type
if isinstance(device, str):
device = torch.device(device)

if device is None:
device = "cpu" if "cuda" not in self.device else "cuda"
device = (
torch.device("cpu") if "cuda" not in self.device.type else self.device
)

cond1 = "cpu" in device and "cuda" in self.device
cond2 = "cuda" in device and "cuda" not in self.device
cond1 = "cpu" in device.type and "cuda" in self.device.type
cond2 = "cuda" in device.type and "cuda" not in self.device.type

if "cpu" in device:
if "cpu" in device.type:
_device = device
else:
_device = self.device
1 change: 1 addition & 0 deletions malib/rl/pg/config.py
@@ -29,6 +29,7 @@
"reward_norm": None,
"n_repeat": 2,
"minibatch": 2,
"batch_size": 32,
"gamma": 0.99,
},
"model_config": {
1 change: 1 addition & 0 deletions malib/rl/pg/policy.py
@@ -80,6 +80,7 @@ def create_model(self):
self.model_config["preprocess_net"].get("net_type", None),
**self.model_config["preprocess_net"]["config"]
)

if isinstance(self.action_space, spaces.Discrete):
return discrete.Actor(
preprocess_net=preprocess_net,
12 changes: 11 additions & 1 deletion malib/rl/pg/trainer.py
@@ -29,18 +29,28 @@
import numpy as np

from torch import optim
from malib.rl.common.policy import Policy

from malib.rl.common.trainer import Trainer
from malib.utils.data import Postprocessor
from malib.utils.general import merge_dicts
from malib.utils.typing import AgentID
from malib.utils.tianshou_batch import Batch
from .config import DEFAULT_CONFIG


class PGTrainer(Trainer):
def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None):
# merge from default
training_config = merge_dicts(
DEFAULT_CONFIG["training_config"], training_config or {}
)
super().__init__(training_config, policy_instance)

def setup(self):
self.optimizer: Type[optim.Optimizer] = getattr(
optim, self.training_config["optimizer"]
)(self.policy.parameters()["actor"], lr=self.training_config["lr"])
)(self.policy.actor.parameters(), lr=self.training_config["lr"])
self.lr_scheduler: torch.optim.lr_scheduler.LambdaLR = None
self.ret_rms = None

18 changes: 16 additions & 2 deletions malib/rl/random/random_trainer.py
@@ -1,6 +1,20 @@
from typing import Any, Dict
from typing import Any, Dict, Type

import torch

from torch import optim

from malib.rl.common.policy import Policy
from malib.rl.pg.trainer import PGTrainer


class RandomTrainer(PGTrainer):
pass
def __init__(self, training_config: Dict[str, Any], policy_instance: Policy = None):
super().__init__(training_config, policy_instance)

def setup(self):
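# note: unlike PGTrainer.setup above, the optimizer here is built over self.policy.parameters() directly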
self.optimizer: Type[optim.Optimizer] = getattr(
optim, self.training_config["optimizer"]
)(self.policy.parameters(), lr=self.training_config["lr"])
self.lr_scheduler: torch.optim.lr_scheduler.LambdaLR = None
self.ret_rms = None