Commit

Fix bug: batch data key error
KornbergFresnel committed Oct 16, 2021
1 parent 9efda1b commit 579fd07
Showing 8 changed files with 144 additions and 777 deletions.
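
Note: the three diffs rendered below only comment out two example scripts and trim one config line; the key-error fix itself appears to live in the changed files that are not rendered on this page. As context for the commit message, here is a minimal and purely hypothetical sketch of this class of bug (none of these names are taken from the commit):

    # Hypothetical illustration of a "batch data key error": training code
    # indexes a sampled batch with a key the buffer never filled in.
    batch = {"obs": [0.1, 0.2], "action": [1, 0]}  # no "reward" entry
    expected = {"obs", "action", "reward"}
    missing = expected - batch.keys()
    if missing:
        # Fail fast with a clear message instead of a bare KeyError downstream.
        raise KeyError(f"batch is missing keys: {missing}")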
158 changes: 79 additions & 79 deletions examples/async_simple.py
@@ -1,87 +1,87 @@
"""
Async optimizer for single-agent RL algorithms running simple scenario from MPE enviroments. In this case, there will be more than one agent training interfaces
used to do policy learning in async mode. Users can specify the number with `--num_learner`.
"""
# """
# Async optimizer for single-agent RL algorithms running simple scenario from MPE enviroments. In this case, there will be more than one agent training interfaces
# used to do policy learning in async mode. Users can specify the number with `--num_learner`.
# """

import argparse
# import argparse

from malib.envs import MPE
from malib.runner import run
# from malib.envs import MPE
# from malib.runner import run


parser = argparse.ArgumentParser("Async training on mpe environments.")
# parser = argparse.ArgumentParser("Async training on mpe environments.")

parser.add_argument(
"--num_learner",
type=int,
default=3,
help="The number of agent training interfaces. Default by 3.",
)
parser.add_argument(
"--batch_size", type=int, default=64, help="Trianing batch size. Default by 64."
)
parser.add_argument(
"--num_epoch", type=int, default=100, help="Training epoch. Default by 100."
)
parser.add_argument(
"--algorithm",
type=str,
default="DQN",
help="The single-agent RL algortihm registered in MALib. Default by DQN",
)
# parser.add_argument(
# "--num_learner",
# type=int,
# default=3,
# help="The number of agent training interfaces. Default by 3.",
# )
# parser.add_argument(
# "--batch_size", type=int, default=64, help="Trianing batch size. Default by 64."
# )
# parser.add_argument(
# "--num_epoch", type=int, default=100, help="Training epoch. Default by 100."
# )
# parser.add_argument(
# "--algorithm",
# type=str,
# default="DQN",
# help="The single-agent RL algortihm registered in MALib. Default by DQN",
# )


if __name__ == "__main__":
args = parser.parse_args()
env_config = {
"scenario_configs": {"max_cycles": 25},
"env_id": "simple_v2",
}
env = MPE(**env_config)
possible_agents = env.possible_agents
observation_spaces = env.observation_spaces
action_spaces = env.action_spaces
# if __name__ == "__main__":
# args = parser.parse_args()
# env_config = {
# "scenario_configs": {"max_cycles": 25},
# "env_id": "simple_v2",
# }
# env = MPE(**env_config)
# possible_agents = env.possible_agents
# observation_spaces = env.observation_spaces
# action_spaces = env.action_spaces

run(
group="MPE/simple",
name="async_dqn",
env_description={
"creator": MPE,
"config": env_config,
"possible_agents": possible_agents,
},
agent_mapping_func=lambda agent: [
f"{agent}_async_{i}" for i in range(args.num_learner)
],
training={
"interface": {
"type": "async",
"observation_spaces": observation_spaces,
"action_spaces": action_spaces,
"population_size": -1,
},
"config": {
"update_interval": 1,
"saving_interval": 10,
"batch_size": args.batch_size,
"num_epoch": 100,
"return_gradients": True,
},
},
algorithms={
"Async": {"name": args.algorithm},
},
rollout={
"type": "async",
"stopper": "simple_rollout",
"stopper_config" "metric_type": "simple",
"fragment_length": env_config["scenario_configs"]["max_cycles"],
"num_episodes": 100, # episode for each evaluation/training epoch
"terminate": "any",
},
global_evaluator={
"name": "generic",
"config": {"stop_metrics": {}},
},
dataset_config={"episode_capacity": 30000},
)
# run(
# group="MPE/simple",
# name="async_dqn",
# env_description={
# "creator": MPE,
# "config": env_config,
# "possible_agents": possible_agents,
# },
# agent_mapping_func=lambda agent: [
# f"{agent}_async_{i}" for i in range(args.num_learner)
# ],
# training={
# "interface": {
# "type": "async",
# "observation_spaces": observation_spaces,
# "action_spaces": action_spaces,
# "population_size": -1,
# },
# "config": {
# "update_interval": 1,
# "saving_interval": 10,
# "batch_size": args.batch_size,
# "num_epoch": 100,
# "return_gradients": True,
# },
# },
# algorithms={
# "Async": {"name": args.algorithm},
# },
# rollout={
# "type": "async",
# "stopper": "simple_rollout",
# "stopper_config" "metric_type": "simple",
# "fragment_length": env_config["scenario_configs"]["max_cycles"],
# "num_episodes": 100, # episode for each evaluation/training epoch
# "terminate": "any",
# },
# global_evaluator={
# "name": "generic",
# "config": {"stop_metrics": {}},
# },
# dataset_config={"episode_capacity": 30000},
# )
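
An aside on the (now commented-out) rollout block above: "stopper_config" "metric_type" is two adjacent string literals, which Python silently concatenates into the single dictionary key "stopper_configmetric_type". Plain Python demonstrates the effect; nothing from MALib is needed:

    # Adjacent string literals concatenate, so this dict has ONE key, not two.
    rollout = {
        "stopper": "simple_rollout",
        "stopper_config" "metric_type": "simple",
    }
    print(list(rollout))  # ['stopper', 'stopper_configmetric_type']
    assert "metric_type" not in rollout

Note that independent_marl.py below uses the intended spelling, "metric_type": "simple".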
1 change: 0 additions & 1 deletion examples/configs/mpe/ddpg_simple_nips.yaml
@@ -29,7 +29,6 @@ rollout:
   episode_seg: 25
   terminate: "any"
   # can be sequential or simultaneous
-  callback: "sequential"
 
 evaluation:
   fragment_length: 25
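
Removing callback: "sequential" only works if the consumer falls back to a default when the key is absent; the adjacent comment indicates "sequential" and "simultaneous" are the valid choices. A minimal sketch of that pattern (hypothetical loader code, not MALib's actual config handling):

    # Hypothetical: resolve the rollout callback with a default when the
    # "callback" key has been dropped from the YAML.
    rollout_cfg = {"episode_seg": 25, "terminate": "any"}
    callback = rollout_cfg.get("callback", "sequential")
    print(callback)  # -> "sequential"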
120 changes: 60 additions & 60 deletions examples/independent_marl.py
@@ -1,67 +1,67 @@
"""
Implementation of independent learning applied to MARL cases.
"""
# """
# Implementation of independent learning applied to MARL cases.
# """

import argparse
# import argparse

from malib.envs.mpe.env import MPE
from malib.runner import run
# from malib.envs.mpe.env import MPE
# from malib.runner import run

parser = argparse.ArgumentParser(
"Independent multi-agent learning on mpe environments."
)
# parser = argparse.ArgumentParser(
# "Independent multi-agent learning on mpe environments."
# )

parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_epoch", type=int, default=100)
parser.add_argument("--fragment_length", type=int, default=25)
parser.add_argument("--worker_num", type=int, default=6)
parser.add_argument("--algorithm", type=str, default="PPO")
# parser.add_argument("--batch_size", type=int, default=64)
# parser.add_argument("--num_epoch", type=int, default=100)
# parser.add_argument("--fragment_length", type=int, default=25)
# parser.add_argument("--worker_num", type=int, default=6)
# parser.add_argument("--algorithm", type=str, default="PPO")


if __name__ == "__main__":
args = parser.parse_args()
env_description = {
"creator": MPE,
"config": {
"env_id": "simple_tag_v2",
"scenario_configs": {
"num_good": 2,
"num_adversaries": 2,
"num_obstacles": 2,
"max_cycles": 25,
},
},
}
env = MPE(**env_description["config"])
env_description["possible_agents"] = env.possible_agents
# if __name__ == "__main__":
# args = parser.parse_args()
# env_description = {
# "creator": MPE,
# "config": {
# "env_id": "simple_tag_v2",
# "scenario_configs": {
# "num_good": 2,
# "num_adversaries": 2,
# "num_obstacles": 2,
# "max_cycles": 25,
# },
# },
# }
# env = MPE(**env_description["config"])
# env_description["possible_agents"] = env.possible_agents

run(
env_description=env_description,
training={
"interface": {
"type": "independent",
"observation_spaces": env.observation_spaces,
"action_spaces": env.action_spaces,
},
"config": {
"agent": {
"observation_spaces": env.observation_spaces,
"action_spaces": env.action_spaces,
},
"batch_size": args.batch_size,
"grad_norm_clipping": 0.5,
},
},
algorithms={"PPO": {"name": "PPO"}},
rollout={
"type": "async",
"stopper": "simple_rollout",
"metric_type": "simple",
"fragment_length": 75,
"num_episodes": 100,
},
global_evaluator={
"name": "generic",
"config": {"stop_metrics": {}},
},
)
# run(
# env_description=env_description,
# training={
# "interface": {
# "type": "independent",
# "observation_spaces": env.observation_spaces,
# "action_spaces": env.action_spaces,
# },
# "config": {
# "agent": {
# "observation_spaces": env.observation_spaces,
# "action_spaces": env.action_spaces,
# },
# "batch_size": args.batch_size,
# "grad_norm_clipping": 0.5,
# },
# },
# algorithms={"PPO": {"name": "PPO"}},
# rollout={
# "type": "async",
# "stopper": "simple_rollout",
# "metric_type": "simple",
# "fragment_length": 75,
# "num_episodes": 100,
# },
# global_evaluator={
# "name": "generic",
# "config": {"stop_metrics": {}},
# },
# )
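
For contrast with examples/async_simple.py above: that script fans each environment agent out to --num_learner async training interfaces through its agent_mapping_func, while this script passes no mapping and presumably relies on the framework default. An illustrative comparison (the identity mapping for the independent case is an assumption, not something this commit shows):

    num_learner = 3  # the async example's default

    def async_mapping(agent):
        # "agent_0" -> ["agent_0_async_0", "agent_0_async_1", "agent_0_async_2"]
        return [f"{agent}_async_{i}" for i in range(num_learner)]

    def independent_mapping(agent):
        return [agent]  # assumed: one dedicated learner per agent

    print(async_mapping("agent_0"))
    print(independent_mapping("agent_0"))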