-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimpala_agent.py
30 lines (28 loc) · 819 Bytes
/
impala_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse
import pprint
from ray import tune
import ray
from ray.rllib.agents.impala.impala import (
DEFAULT_CONFIG,
ImpalaTrainer as trainer)
def parse_args(argv=None):
    """Parse the command-line options for an IMPALA training run.

    Every option except ``--env`` defaults to the value this script used to
    hard-code, so passing ``--env NAME`` alone reproduces the original run.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``env``, ``num_gpus``,
        ``num_workers``, ``evaluation_num_workers``, ``evaluation_interval``
        and ``timesteps_total``.

    Raises:
        SystemExit: Raised by argparse when ``--env`` is missing or an option
            value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(
        description="Train an RLlib IMPALA agent with Ray Tune.")
    # Required: if omitted, the config would carry "env": None and the run
    # would only fail later, once the trainer needs an environment.
    parser.add_argument('--env', required=True,
                        help='Gym env name.')
    # float because RLlib documents fractional GPU counts. argparse does not
    # apply `type` to a non-string default, so the default stays the int 1.
    parser.add_argument('--num-gpus', type=float, default=1,
                        help='Value for the "num_gpus" config key '
                             '(default: 1).')
    parser.add_argument('--num-workers', type=int, default=50,
                        help='Value for the "num_workers" config key '
                             '(default: 50).')
    parser.add_argument('--evaluation-num-workers', type=int, default=10,
                        help='Value for the "evaluation_num_workers" config '
                             'key (default: 10).')
    parser.add_argument('--evaluation-interval', type=int, default=1,
                        help='Value for the "evaluation_interval" config key '
                             '(default: 1).')
    parser.add_argument('--timesteps-total', type=int, default=2_000_000,
                        help='Tune stop criterion on "timesteps_total" '
                             '(default: 2000000).')
    return parser.parse_args(argv)


def build_config(args, base_config=None):
    """Build the trainer config: ``base_config`` overlaid with CLI values.

    Args:
        args: Namespace as returned by ``parse_args``.
        base_config: Config dict to start from. ``None`` (the default) uses
            the IMPALA ``DEFAULT_CONFIG`` imported at the top of this file.

    Returns:
        A new dict. ``base_config`` itself is not modified. The copy is
        shallow, so nested values (dicts, lists) are shared with the base,
        as with the ``DEFAULT_CONFIG.copy()`` this replaces.
    """
    if base_config is None:
        base_config = DEFAULT_CONFIG
    overrides = {
        "env": args.env,
        "num_gpus": args.num_gpus,
        "num_workers": args.num_workers,
        "evaluation_num_workers": args.evaluation_num_workers,
        "evaluation_interval": args.evaluation_interval,
    }
    # Same semantics as copy() + update(): existing keys keep their position
    # and take the override value, and new keys are appended.
    return {**base_config, **overrides}


def main(argv=None):
    """Parse options, print the resulting config, and launch ``tune.run``.

    Args:
        argv: Forwarded to ``parse_args``. ``None`` reads ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    config = build_config(args)
    pprint.PrettyPrinter(indent=4).pprint(config)
    ray.init()
    tune.run(trainer,
             stop={"timesteps_total": args.timesteps_total},
             config=config)


if __name__ == "__main__":
    main()