Skip to content

Commit

Permalink
update example
Browse files Browse the repository at this point in the history
  • Loading branch information
KornbergFresnel committed May 11, 2021
1 parent 22a8fec commit e393e59
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 37 deletions.
4 changes: 2 additions & 2 deletions examples/appo_simple.py → examples/async_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from malib.runner import run


parser = argparse.ArgumentParser("Async PPO training on mpe environments.")
parser = argparse.ArgumentParser("Async training on mpe environments.")

parser.add_argument("--num_learner", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=64)
Expand All @@ -26,7 +26,7 @@

run(
group="MPE/simple",
name="async_ppo",
name="async_dqn",
worker_config={"worker_num": args.num_learner},
env_description={
"creator": simple_v2.env,
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/maddpg_push_ball_nips.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ rollout:
evaluation:
fragment_length: 25
num_episodes: 100
callback: "simultaneous"
callback: "sequential"

env_description:
# scenario_name: "simple_spread"
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/maddpg_simple_spread.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ rollout:
evaluation:
fragment_length: 100
num_episodes: 100
callback: "simultaneous"
callback: "sequential"

env_description:
# scenario_name: "simple_spread"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from malib.rollout.rollout_worker import RolloutWorker
from malib.rollout.rollout_func import rollout_wrapper

from malib.backend.datapool.offline_dataset_server import Episode
from malib.backend.datapool.offline_dataset_server import MultiAgentEpisode
from malib.utils.logger import Log, get_logger


Expand Down Expand Up @@ -113,38 +113,30 @@
rollout_handler.update_population(agent, pid, policy)
trainable_policy_mapping[agent] = pid

agent_episodes = {
agent: Episode(
env_desc["id"],
policy_id=trainable_policy_mapping[agent],
capacity=config["dataset_config"]["episode_capacity"],
other_columns=["times"],
)
for agent in env.possible_agents
}
agent_episode = MultiAgentEpisode(
env_desc["id"],
trainable_policy_mapping,
capacity=config["dataset_config"]["episode_capacity"],
)
# ======================================================================================================

# ==================================== Main loop =======================================================
min_size = 0
while min_size < config["dataset_config"]["learning_start"]:
while agent_episode.size < config["dataset_config"]["learning_start"]:
statistics, _ = rollout_handler.sample(
callback=rollout_wrapper(
agent_episodes, rollout_type=config["rollout"]["callback"]
),
callback=rollout_wrapper(agent_episode),
trainable_pairs=trainable_policy_mapping,
fragment_length=config["rollout"]["fragment_length"],
behavior_policy_mapping=trainable_policy_mapping,
num_episodes=[config["rollout"]["episode_seg"]],
role="rollout",
threaded=False,
)
min_size = min([e.size for e in agent_episodes.values()])

for epoch in range(1000):
print(f"==================== epoch #{epoch} ===============")
_ = rollout_handler.sample(
callback=rollout_wrapper(
agent_episodes, rollout_type=config["rollout"]["callback"]
),
callback=rollout_wrapper(agent_episode),
fragment_length=config["rollout"]["fragment_length"],
behavior_policy_mapping=trainable_policy_mapping,
num_episodes=[config["rollout"]["episode_seg"]],
Expand All @@ -155,9 +147,7 @@
log=True, logger=logger, worker_idx="Rollout", global_step=epoch
) as (statistic_seq, processed_statistics):
statistics, _ = rollout_handler.sample(
callback=rollout_wrapper(
None, rollout_type=config["evaluation"]["callback"]
),
callback=rollout_wrapper(None),
fragment_length=config["rollout"]["fragment_length"],
behavior_policy_mapping=trainable_policy_mapping,
num_episodes=[config["evaluation"]["num_episodes"]],
Expand All @@ -170,15 +160,12 @@
print("-------------- rollout --------")
pprint.pprint(processed_statistics[0])

idxes = np.random.choice(min_size, config["training"]["config"]["batch_size"])
batches = {}
for agent in env.possible_agents:
batch = agent_episodes[agent].sample(idxes)
batch = agent_episode.sample(
size=config["training"]["config"]["batch_size"]
)
batches[agent] = batch
# check timestamp
for i in range(config["training"]["config"]["batch_size"]):
times = [b["times"][i] for b in batches.values()]
assert max(times) == min(times), (max(times), min(times))

        print("-------------- training -------")
for iid, interface in learners.items():
Expand Down
1 change: 1 addition & 0 deletions examples/single_instance_psro.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def training_workflow(trainable_policy_mapping: Dict[AgentID, PolicyID]):
num_episodes=[rollout_config["num_episodes"]],
threaded=False,
role="rollout",
trainable_pairs=trainable_policy_mapping,
)
batch = agent_episodes[agent].sample(size=args.batch_size)
res = learners[agent].optimize(
Expand Down
3 changes: 3 additions & 0 deletions malib/backend/coordinator/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(
BaseCoordinator.__init__(self)

self._configs = kwargs
self._configs["rollout"]["worker_num"] = self._configs["worker_config"].get(
"worker_num", -1
)

self._terminate = False
self._pending_trainable_pairs = {}
Expand Down
13 changes: 8 additions & 5 deletions malib/rollout/rollout_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import numpy as np

from malib.backend.datapool.offline_dataset_server import Episode
from malib.backend.datapool.offline_dataset_server import Episode, MultiAgentEpisode
from malib.utils.metrics import get_metric
from malib.utils.typing import AgentID, Dict
from malib.utils.typing import AgentID, Dict, Union
from malib.utils.preprocessor import get_preprocessor


Expand Down Expand Up @@ -231,11 +231,12 @@ def simultaneous(


def rollout_wrapper(
agent_episodes: Dict[AgentID, Episode] = None, rollout_type="sequential"
agent_episodes: Union[MultiAgentEpisode, Dict[AgentID, Episode]] = None,
rollout_type="sequential",
):
"""Rollout wrapper accept a dict of episodes outside.
:param Dict[AgentID,Episode] agent_episodes: A dict of agent episodes.
    :param Union[MultiAgentEpisode,Dict[AgentID,Episode]] agent_episodes: A dict of agent episodes or a MultiAgentEpisode instance.
:param str rollout_type: Specify rollout styles. Default to `sequential`, choices={sequential, simultaneous}.
:return: A function
"""
Expand All @@ -258,7 +259,9 @@ def func(
max_iter,
behavior_policy_mapping=behavior_policy_mapping,
)
if agent_episodes is not None:
if isinstance(agent_episodes, MultiAgentEpisode):
agent_episodes.insert(**episodes)
elif isinstance(agent_episodes, Dict):
for agent, episode in episodes.items():
agent_episodes[agent].insert(**episode.data)
return statistic, episodes
Expand Down
3 changes: 2 additions & 1 deletion malib/rollout/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,11 @@ def _rollout(
aid: Episode.concatenate(*merged_data[aid], capacity=merged_capacity[aid])
for aid in merged_data
}
ap_mapping = {k: v.policy_id for k, v in agent_episode.items()}
data2send = {
aid: MultiAgentEpisode(
e.env_id,
kwargs["trainable_pairs"],
ap_mapping,
merged_capacity[aid],
e.other_columns,
)
Expand Down

0 comments on commit e393e59

Please sign in to comment.