forked from sjtu-marl/malib
* init inference worker set
* recover default parameter list
* add test case for inference server
* easy test
* fix: no weights update
* add comments
* fix: env info collect
* dockerfile (sjtu-marl#3) Co-authored-by: zhihao lyu <[email protected]>
* temp save
* temp save
* fix: vector env reset
* test passed
* fix test
* implemented impala
* tmp save
* fix loss compute
* update: apply grad norm
* apply masked categorical (see the sketch below)
* update kl computation
* fix vtrace loss
* use full block size
* multiple learner support
* enable wrapping switch for grfootball
* shareable learning mechanism
* add hint
* update env config
* simplify the initialization of server
* update test for managers
* simplify the implementation of rolloutworker
* add some comments
* fix: do not repeatedly create connections
* update
* independent learning test passed
* tmp save
* fix dependencies
* add unified poker psro case
* reformat and docs
* reverb backend offline dataset
* delete deprecated workers
* init monitor module for distributed logging
* rollout worker test passed
* delete deprecated examples
* delete deprecated interface
* independent rollout: tmp save
* clean env returns
* independent rollout: remove episodekey
* independent rollout: jumpy dead env
* independent rollout: inference test passed
* remove deprecated implementation
* new backend
* migrate envs
* remove deprecated rollouts
* update rollouts
* fix logger path
* import tianshou model
* format
* rollout test passed
* update algorithm module
* some algorithm
* refactor base trainer
* rename inference test
* fix: inference not compatible with tianshou
* fix: version conflicts
* reformatted
* tmp save
* change env path
* task scenario mode
* tmp save
* fix: typing import
* tmp save
* tmp save
* introduce tianshou batch
* batch done
* fix data length alignment
* delete deprecated tests
* support fully async inference server
* add ray-based inference cs
* mute warning of tianshou
* add eval scripts
* rename inference type and support local mode
* improve FPS via lazy update
* fix training error
* local mode inference server test pass
* psro scenario done
* done for basic case

Co-authored-by: victor lyu <[email protected]>
Co-authored-by: zhihao lyu <[email protected]>
Commit ed687d3 (1 parent: 5be07ac). Showing 346 changed files with 15,296 additions and 27,993 deletions.
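One change called out in the commit message above, "apply masked categorical", refers to sampling actions only from the legal subset when an environment provides an action mask. The sketch below shows one common way to do this in PyTorch; the class name and masking value are illustrative and are not taken from MALib's code.

```python
# Hedged sketch: masking invalid actions before categorical sampling.
# Not MALib's implementation; names and details are illustrative only.
import torch
from torch.distributions import Categorical


class MaskedCategorical(Categorical):
    """Categorical distribution that gives (near-)zero probability to masked-out actions."""

    def __init__(self, logits: torch.Tensor, mask: torch.Tensor = None):
        if mask is not None:
            # Push logits of illegal actions toward -inf so softmax assigns them ~0 probability.
            logits = logits.masked_fill(~mask.bool(), torch.finfo(logits.dtype).min)
        super().__init__(logits=logits)


if __name__ == "__main__":
    logits = torch.randn(2, 4)                          # two states, four candidate actions
    mask = torch.tensor([[1, 1, 0, 1], [0, 1, 1, 1]])   # 0 marks an illegal action
    dist = MaskedCategorical(logits=logits, mask=mask)
    actions = dist.sample()                             # masked actions are effectively never sampled
    print(actions, dist.log_prob(actions))
```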
.gitmodules:
@@ -1,3 +1,3 @@
-[submodule "malib/envs/smarts/_env"]
-	path = malib/envs/smarts/_env
+[submodule "malib/rollout/envs/smarts/_env"]
+	path = malib/rollout/envs/smarts/_env
 	url = https://github.com/huawei-noah/SMARTS.git
@@ -0,0 +1,24 @@
# Implemented Algorithms In MALib

## Population-based Learning Algorithms

- [x] PSRO: Lanctot, Marc, et al. "A unified game-theoretic approach to multiagent reinforcement learning." Advances in Neural Information Processing Systems 30 (2017). [[arXiv](https://arxiv.org/pdf/1711.00832.pdf)] | [[official code](https://github.com/deepmind/open_spiel)] (see the control-flow sketch below)
- [ ] P2SRO: McAleer, Stephen, et al. "Pipeline PSRO: A scalable approach for finding approximate Nash equilibria in large games." Advances in Neural Information Processing Systems 33 (2020): 20238-20248. [[paper](https://proceedings.neurips.cc/paper/2020/file/e9bcd1b063077573285ae1a41025f5dc-Paper.pdf)] | [[official code](https://github.com/JBLanier/pipeline-psro)]
- [ ] EPSRO: Zhou, Ming, et al. "Efficient Policy Space Response Oracles." arXiv preprint arXiv:2202.00633 (2022). [[arXiv](https://arxiv.org/pdf/2202.00633)] | [official code]
- [ ] ODO: Dinh, Le Cong, et al. "Online Double Oracle." arXiv preprint arXiv:2103.07780 (2021). [[arXiv](https://arxiv.org/abs/2103.07780)] | [[official code](https://github.com/npvoid/OnlineDoubleOracle)]
- [ ] XDO: McAleer, Stephen, et al. "XDO: A double oracle algorithm for extensive-form games." Advances in Neural Information Processing Systems 34 (2021): 23128-23139. [[paper](https://proceedings.neurips.cc/paper/2021/file/c2e06e9a80370952f6ec5463c77cbace-Paper.pdf)] | [[official code](https://github.com/indylab/nxdo)]
- [ ] NeuPL: Liu, Siqi, et al. "NeuPL: Neural Population Learning." International Conference on Learning Representations. 2022. [[arXiv](https://arxiv.org/abs/2202.07415)] | [official code]

## Multi-agent Reinforcement Learning Algorithms

- [x] MADDPG: Lowe, Ryan, et al. "Multi-agent actor-critic for mixed cooperative-competitive environments." Advances in Neural Information Processing Systems 30 (2017). [[arXiv](https://arxiv.org/abs/1706.02275)]
- [x] MAPPO: Yu, Chao, et al. "The surprising effectiveness of PPO in cooperative, multi-agent games." arXiv preprint arXiv:2103.01955 (2021). [[arXiv](https://arxiv.org/abs/2103.01955)]
- [x] QMIX: Rashid, Tabish, et al. "QMIX: Monotonic value function factorisation for deep multi-agent reinforcement learning." International Conference on Machine Learning. PMLR, 2018. [[paper](http://proceedings.mlr.press/v80/rashid18a/rashid18a.pdf)]

## Single-agent Reinforcement Learning Algorithms

- [x] A3C: Mnih, Volodymyr, et al. "Asynchronous methods for deep reinforcement learning." International Conference on Machine Learning. PMLR, 2016. [[arXiv](https://arxiv.org/pdf/1602.01783.pdf)]
- [x] DDPG: Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning." arXiv preprint arXiv:1509.02971 (2015). [[arXiv](https://arxiv.org/abs/1509.02971)]
- [x] SAC: Haarnoja, Tuomas, et al. "Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor." International Conference on Machine Learning. PMLR, 2018. [[arXiv](https://arxiv.org/abs/1801.01290)]
- [x] DQN: Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. [[arXiv](https://arxiv.org/abs/1312.5602)]
- [x] PG: Sutton, Richard S., et al. "Policy gradient methods for reinforcement learning with function approximation." Advances in Neural Information Processing Systems 12 (1999). [[paper](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf)]
- [x] PPO: Schulman, John, et al. "Proximal policy optimization algorithms." arXiv preprint arXiv:1707.06347 (2017). [[arXiv](https://arxiv.org/abs/1707.06347)]
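Of the population-based entries above, only PSRO is currently checked off. For orientation, its control flow is roughly the loop below; `train_best_response`, `evaluate_matchup`, and `solve_meta_game` are hypothetical placeholder callbacks supplied by the caller, not MALib APIs.

```python
# Bare-bones two-player PSRO loop; all heavy lifting is delegated to the callbacks.
# This is a sketch of the generic algorithm, not MALib's implementation.
import numpy as np


def psro(train_best_response, evaluate_matchup, solve_meta_game, n_iters: int = 10):
    # Start each player's population with a single (e.g. random or warm-started) policy.
    populations = {
        0: [train_best_response(player=0, opponents=None)],
        1: [train_best_response(player=1, opponents=None)],
    }
    meta_strategies = None

    for _ in range(n_iters):
        # 1) Fill the empirical payoff matrix over the current populations.
        payoffs = np.array(
            [[evaluate_matchup(p0, p1) for p1 in populations[1]] for p0 in populations[0]]
        )
        # 2) Solve the meta-game (e.g. Nash or fictitious play) for a mixture per player.
        meta_strategies = solve_meta_game(payoffs)
        # 3) Each player adds an approximate best response to the opponent's mixture.
        for player in (0, 1):
            opponent = 1 - player
            best_response = train_best_response(
                player=player,
                opponents=(populations[opponent], meta_strategies[opponent]),
            )
            populations[player].append(best_response)

    return populations, meta_strategies
```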
@@ -0,0 +1,95 @@
from argparse import ArgumentParser

import random
import time
import pprint
import ray

from malib.utils.logging import Logger
from malib.utils.timing import Timing
from malib.remote.interface import RemoteInterface
from malib.algorithm.pg import PGPolicy
from malib.rollout.envs.env import Environment
from malib.rollout.envs.gym import env_desc_gen


class PolicyServer(RemoteInterface):
    """Ray-remote wrapper that hosts a PGPolicy and serves action computation."""

    def __init__(self, observation_space, action_space, model_config, custom_config):
        self.policy = PGPolicy(
            observation_space=observation_space,
            action_space=action_space,
            model_config=model_config,
            custom_config=custom_config,
        )

    def compute_action(self, observation, action_mask, evaluate):
        return self.policy.compute_action(
            observation, action_mask=action_mask, evaluate=evaluate
        )

    def get_preprocessor(self):
        return self.policy.preprocessor


if __name__ == "__main__":
    parser = ArgumentParser("Multi-agent reinforcement learning.")
    parser.add_argument("--log_dir", default="./logs/", help="Log directory.")
    parser.add_argument("--env_id", default="CartPole-v1", help="Gym environment id.")

    args = parser.parse_args()

    # Build the environment description (creator, config and per-agent spaces).
    env_description = env_desc_gen(env_id=args.env_id, scenario_configs={})

    obs_spaces = env_description["observation_spaces"]
    act_spaces = env_description["action_spaces"]

    # policy = PGPolicy(observation_space=obs_spaces['agent'], action_space=act_spaces["agent"], model_config={}, custom_config={})
    # Launch the policy server as a Ray actor and fetch its observation preprocessor.
    policy = PolicyServer.as_remote().remote(
        observation_space=obs_spaces["agent"],
        action_space=act_spaces["agent"],
        model_config={},
        custom_config={},
    )
    preprocessor = ray.get(policy.get_preprocessor.remote())

    timer = Timing()

    try:
        env: Environment = env_description["creator"](**env_description["config"])
        cnt = 0
        Logger.info(
            "Performance evaluation started. You can press Ctrl-C to stop evaluation and get performance result..."
        )
        start = time.time()

        # Roll out episodes indefinitely, timing each stage, until interrupted.
        while True:
            with timer.time_avg("reset"):
                raw_obs = env.reset()[0]["agent"]

            done = False
            while not done:
                with timer.time_avg("obs_transform"):
                    obs = preprocessor.transform(raw_obs)

                with timer.time_avg("action_compute"):
                    action, action_dist, logits, state = ray.get(
                        policy.compute_action.remote(
                            obs, action_mask=None, evaluate=random.choice([False, True])
                        )
                    )

                with timer.time_avg("env_step"):
                    raw_obs, act_mask, rew, done, info = env.step({"agent": action[0]})

                done = done["agent"]
                raw_obs = raw_obs["agent"]

                cnt += 1

    except KeyboardInterrupt:
        # Report average frames per second and per-stage timing on Ctrl-C.
        fps = cnt / (time.time() - start)
        Logger.warning(
            f"Keyboard interrupt detected, end evaluation. Average performance evaluation:\nFPS = {fps}\nAVG_TIMER={pprint.pformat(timer.todict())}"
        )
    finally:
        env.close()
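The script above leans on `Timing.time_avg(...)` as a context manager and `Timing.todict()` for reporting. As a rough mental model only (an assumption; the actual `malib.utils.timing.Timing` may be implemented differently), such an averaging timer can be written like this:

```python
# Simplified stand-in for an averaging timer; illustrative, not MALib's Timing class.
import time
from collections import defaultdict
from contextlib import contextmanager


class AvgTimer:
    """Context-manager timer that keeps a running average per named stage."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._counts = defaultdict(int)

    @contextmanager
    def time_avg(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self._sums[name] += time.perf_counter() - start
            self._counts[name] += 1

    def todict(self):
        # Average seconds per call for every timed stage.
        return {name: self._sums[name] / self._counts[name] for name in self._sums}
```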