forked from sjtu-marl/malib
Commit 579fd07 (1 parent: 9efda1b). Showing 8 changed files with 144 additions and 777 deletions.
@@ -1,87 +1,87 @@
-"""
-Async optimizer for single-agent RL algorithms running the simple scenario from MPE environments. In this case, there will be more than one agent training interface
-used to do policy learning in async mode. Users can specify the number with `--num_learner`.
-"""
+# """
+# Async optimizer for single-agent RL algorithms running the simple scenario from MPE environments. In this case, there will be more than one agent training interface
+# used to do policy learning in async mode. Users can specify the number with `--num_learner`.
+# """

-import argparse
+# import argparse

-from malib.envs import MPE
-from malib.runner import run
+# from malib.envs import MPE
+# from malib.runner import run


-parser = argparse.ArgumentParser("Async training on mpe environments.")
+# parser = argparse.ArgumentParser("Async training on mpe environments.")

-parser.add_argument(
-    "--num_learner",
-    type=int,
-    default=3,
-    help="The number of agent training interfaces. Defaults to 3.",
-)
-parser.add_argument(
-    "--batch_size", type=int, default=64, help="Training batch size. Defaults to 64."
-)
-parser.add_argument(
-    "--num_epoch", type=int, default=100, help="Training epochs. Defaults to 100."
-)
-parser.add_argument(
-    "--algorithm",
-    type=str,
-    default="DQN",
-    help="The single-agent RL algorithm registered in MALib. Defaults to DQN.",
-)
+# parser.add_argument(
+#     "--num_learner",
+#     type=int,
+#     default=3,
+#     help="The number of agent training interfaces. Defaults to 3.",
+# )
+# parser.add_argument(
+#     "--batch_size", type=int, default=64, help="Training batch size. Defaults to 64."
+# )
+# parser.add_argument(
+#     "--num_epoch", type=int, default=100, help="Training epochs. Defaults to 100."
+# )
+# parser.add_argument(
+#     "--algorithm",
+#     type=str,
+#     default="DQN",
+#     help="The single-agent RL algorithm registered in MALib. Defaults to DQN.",
+# )


-if __name__ == "__main__":
-    args = parser.parse_args()
-    env_config = {
-        "scenario_configs": {"max_cycles": 25},
-        "env_id": "simple_v2",
-    }
-    env = MPE(**env_config)
-    possible_agents = env.possible_agents
-    observation_spaces = env.observation_spaces
-    action_spaces = env.action_spaces
+# if __name__ == "__main__":
+#     args = parser.parse_args()
+#     env_config = {
+#         "scenario_configs": {"max_cycles": 25},
+#         "env_id": "simple_v2",
+#     }
+#     env = MPE(**env_config)
+#     possible_agents = env.possible_agents
+#     observation_spaces = env.observation_spaces
+#     action_spaces = env.action_spaces

-    run(
-        group="MPE/simple",
-        name="async_dqn",
-        env_description={
-            "creator": MPE,
-            "config": env_config,
-            "possible_agents": possible_agents,
-        },
-        agent_mapping_func=lambda agent: [
-            f"{agent}_async_{i}" for i in range(args.num_learner)
-        ],
-        training={
-            "interface": {
-                "type": "async",
-                "observation_spaces": observation_spaces,
-                "action_spaces": action_spaces,
-                "population_size": -1,
-            },
-            "config": {
-                "update_interval": 1,
-                "saving_interval": 10,
-                "batch_size": args.batch_size,
-                "num_epoch": args.num_epoch,
-                "return_gradients": True,
-            },
-        },
-        algorithms={
-            "Async": {"name": args.algorithm},
-        },
-        rollout={
-            "type": "async",
-            "stopper": "simple_rollout",
-            "stopper_config": {},
-            "metric_type": "simple",
-            "fragment_length": env_config["scenario_configs"]["max_cycles"],
-            "num_episodes": 100,  # episodes per evaluation/training epoch
-            "terminate": "any",
-        },
-        global_evaluator={
-            "name": "generic",
-            "config": {"stop_metrics": {}},
-        },
-        dataset_config={"episode_capacity": 30000},
-    )
+#     run(
+#         group="MPE/simple",
+#         name="async_dqn",
+#         env_description={
+#             "creator": MPE,
+#             "config": env_config,
+#             "possible_agents": possible_agents,
+#         },
+#         agent_mapping_func=lambda agent: [
+#             f"{agent}_async_{i}" for i in range(args.num_learner)
+#         ],
+#         training={
+#             "interface": {
+#                 "type": "async",
+#                 "observation_spaces": observation_spaces,
+#                 "action_spaces": action_spaces,
+#                 "population_size": -1,
+#             },
+#             "config": {
+#                 "update_interval": 1,
+#                 "saving_interval": 10,
+#                 "batch_size": args.batch_size,
+#                 "num_epoch": args.num_epoch,
+#                 "return_gradients": True,
+#             },
+#         },
+#         algorithms={
+#             "Async": {"name": args.algorithm},
+#         },
+#         rollout={
+#             "type": "async",
+#             "stopper": "simple_rollout",
+#             "stopper_config": {},
+#             "metric_type": "simple",
+#             "fragment_length": env_config["scenario_configs"]["max_cycles"],
+#             "num_episodes": 100,  # episodes per evaluation/training epoch
+#             "terminate": "any",
+#         },
+#         global_evaluator={
+#             "name": "generic",
+#             "config": {"stop_metrics": {}},
+#         },
+#         dataset_config={"episode_capacity": 30000},
+#     )
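The hunk above comments out an async-DQN example, and the `agent_mapping_func` it disables is the core of async mode: it fans each environment agent out to `--num_learner` training interfaces. As a standalone sketch in plain Python (no MALib imports; `num_learner` here stands in for the command-line flag):

num_learner = 3  # mirrors the --num_learner default above

def agent_mapping_func(agent: str) -> list:
    # Fan each environment agent out to num_learner training
    # interfaces, which MALib then updates asynchronously.
    return [f"{agent}_async_{i}" for i in range(num_learner)]

print(agent_mapping_func("agent_0"))
# ['agent_0_async_0', 'agent_0_async_1', 'agent_0_async_2']

Note also that the rollout config ties `fragment_length` to `max_cycles`, so each rollout fragment spans exactly one full episode.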
@@ -1,67 +1,67 @@
-"""
-Implementation of independent learning applied to MARL cases.
-"""
+# """
+# Implementation of independent learning applied to MARL cases.
+# """

-import argparse
+# import argparse

-from malib.envs.mpe.env import MPE
-from malib.runner import run
+# from malib.envs.mpe.env import MPE
+# from malib.runner import run

-parser = argparse.ArgumentParser(
-    "Independent multi-agent learning on mpe environments."
-)
+# parser = argparse.ArgumentParser(
+#     "Independent multi-agent learning on mpe environments."
+# )

-parser.add_argument("--batch_size", type=int, default=64)
-parser.add_argument("--num_epoch", type=int, default=100)
-parser.add_argument("--fragment_length", type=int, default=25)
-parser.add_argument("--worker_num", type=int, default=6)
-parser.add_argument("--algorithm", type=str, default="PPO")
+# parser.add_argument("--batch_size", type=int, default=64)
+# parser.add_argument("--num_epoch", type=int, default=100)
+# parser.add_argument("--fragment_length", type=int, default=25)
+# parser.add_argument("--worker_num", type=int, default=6)
+# parser.add_argument("--algorithm", type=str, default="PPO")


-if __name__ == "__main__":
-    args = parser.parse_args()
-    env_description = {
-        "creator": MPE,
-        "config": {
-            "env_id": "simple_tag_v2",
-            "scenario_configs": {
-                "num_good": 2,
-                "num_adversaries": 2,
-                "num_obstacles": 2,
-                "max_cycles": 25,
-            },
-        },
-    }
-    env = MPE(**env_description["config"])
-    env_description["possible_agents"] = env.possible_agents
+# if __name__ == "__main__":
+#     args = parser.parse_args()
+#     env_description = {
+#         "creator": MPE,
+#         "config": {
+#             "env_id": "simple_tag_v2",
+#             "scenario_configs": {
+#                 "num_good": 2,
+#                 "num_adversaries": 2,
+#                 "num_obstacles": 2,
+#                 "max_cycles": 25,
+#             },
+#         },
+#     }
+#     env = MPE(**env_description["config"])
+#     env_description["possible_agents"] = env.possible_agents

-    run(
-        env_description=env_description,
-        training={
-            "interface": {
-                "type": "independent",
-                "observation_spaces": env.observation_spaces,
-                "action_spaces": env.action_spaces,
-            },
-            "config": {
-                "agent": {
-                    "observation_spaces": env.observation_spaces,
-                    "action_spaces": env.action_spaces,
-                },
-                "batch_size": args.batch_size,
-                "grad_norm_clipping": 0.5,
-            },
-        },
-        algorithms={"PPO": {"name": "PPO"}},
-        rollout={
-            "type": "async",
-            "stopper": "simple_rollout",
-            "metric_type": "simple",
-            "fragment_length": 75,
-            "num_episodes": 100,
-        },
-        global_evaluator={
-            "name": "generic",
-            "config": {"stop_metrics": {}},
-        },
-    )
+#     run(
+#         env_description=env_description,
+#         training={
+#             "interface": {
+#                 "type": "independent",
+#                 "observation_spaces": env.observation_spaces,
+#                 "action_spaces": env.action_spaces,
+#             },
+#             "config": {
+#                 "agent": {
+#                     "observation_spaces": env.observation_spaces,
+#                     "action_spaces": env.action_spaces,
+#                 },
+#                 "batch_size": args.batch_size,
+#                 "grad_norm_clipping": 0.5,
+#             },
+#         },
+#         algorithms={"PPO": {"name": "PPO"}},
+#         rollout={
+#             "type": "async",
+#             "stopper": "simple_rollout",
+#             "metric_type": "simple",
+#             "fragment_length": 75,
+#             "num_episodes": 100,
+#         },
+#         global_evaluator={
+#             "name": "generic",
+#             "config": {"stop_metrics": {}},
+#         },
+#     )
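This second hunk disables the independent-learning example, in which every agent in `simple_tag_v2` trains its own PPO policy. Here is a standalone sketch of the agent set that scenario config would produce, assuming PettingZoo's MPE naming convention (`adversary_i` for taggers, `agent_i` for good agents; the names are an assumption, not read from this commit):

scenario_configs = {"num_good": 2, "num_adversaries": 2}

# Assumed PettingZoo-style naming; adversaries listed before good agents.
possible_agents = [
    f"adversary_{i}" for i in range(scenario_configs["num_adversaries"])
] + [f"agent_{i}" for i in range(scenario_configs["num_good"])]

print(possible_agents)  # ['adversary_0', 'adversary_1', 'agent_0', 'agent_1']

Under the `"independent"` training interface, each of these four agents gets its own learner with its own observation and action space.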