forked from facebookresearch/mbrl-lib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
86 lines (67 loc) · 2.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Iterable

# ``collections.Iterable`` was removed in Python 3.10; import the ABC from
# ``collections.abc`` and keep the legacy location only as a fallback.
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable

import hydra
import numpy as np
import omegaconf
import pandas as pd
import torch
import wandb

import mbrl.algorithms.mbpo as mbpo
import mbrl.algorithms.pets as pets
import mbrl.algorithms.planet as planet
import mbrl.algorithms.dreamer as dreamer #added April 2022 for project
import mbrl.util.env
def flatten_config(cfg, curr_nested_key):
    """Flatten a nested (Hydra) config into a single-level dictionary.

    The nested config provided by Hydra cannot be parsed by wandb. This
    recursive function flattens it, joining every nested key to its parents
    with an underscore, which allows for easier configuration using wandb.

    Every key is built as ``f"{curr_nested_key}_{key}"``, so calling with an
    empty prefix yields keys that start with an underscore (e.g. ``"_seed"``).
    List-like values are expanded element by element using the element index
    as the key (``"_sizes_0"``, ``"_sizes_1"``, ...). Containers found inside
    a list are flattened the same way.

    Args:
        cfg (Hydra config): The nested, mapping-like config used by Hydra.
            Only ``cfg.keys()`` and ``cfg[key]`` are used.
        curr_nested_key (str): The current parent key (used for recursive
            calls). Pass ``""`` at the top level.

    Returns:
        (dict): A flat configuration dictionary.
    """
    flat_cfg = {}
    for curr_key in cfg.keys():
        try:
            curr_item = cfg[curr_key]
        except Exception:
            # Deliberately best-effort: a value that cannot be read
            # (e.g. a missing mandatory value) is logged as 'NA' instead of
            # aborting the whole run.
            curr_item = 'NA'
        flat_cfg.update(_flatten_value(curr_item, f"{curr_nested_key}_{curr_key}"))
    return flat_cfg


def _flatten_value(value, flat_key):
    """Return the flat ``{key: leaf}`` entries contributed by a single value."""
    # Mappings (plain dicts, OmegaConf DictConfig) are nested configs: recurse.
    if isinstance(value, Mapping):
        return flatten_config(value, flat_key)
    # Sequences (list, tuple, OmegaConf ListConfig) are expanded by index.
    # Strings and byte strings are sequences too, but must stay scalar leaves.
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        flat_items = {}
        for nested_idx, nested_item in enumerate(value):
            flat_items.update(_flatten_value(nested_item, f"{flat_key}_{nested_idx}"))
        return flat_items
    # Anything else (numbers, None, sets, ...) is stored as-is.
    return {flat_key: value}
@hydra.main(config_path="conf", config_name="main")
def run(cfg: omegaconf.DictConfig):
    """Build the environment, log the config to wandb, seed, and train.

    Args:
        cfg (omegaconf.DictConfig): The Hydra config. ``cfg.seed`` and
            ``cfg.algorithm.name`` are read here; the full config is passed
            on to the environment factory and to the training function.

    Returns:
        Whatever the selected algorithm's ``train`` function returns, or
        ``None`` when ``cfg.algorithm.name`` matches none of the supported
        names ("pets", "mbpo", "planet", "dreamer").
    """
    env, term_fn, reward_fn = mbrl.util.env.EnvHandler.make_env(cfg)

    # Log the raw top-level entries first ...
    for top_level_key in cfg:
        wandb.config[top_level_key] = cfg[top_level_key]
    # ... then the flattened view of the whole config.
    for flat_key, flat_value in flatten_config(cfg, "").items():
        wandb.config[flat_key] = flat_value

    # Seeding happens after the environment has been created.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    algorithm_name = cfg.algorithm.name
    if algorithm_name == "pets":
        return pets.train(env, term_fn, reward_fn, cfg)
    if algorithm_name == "mbpo":
        # MBPO evaluates on a second environment built from the same config.
        test_env, *_ = mbrl.util.env.EnvHandler.make_env(cfg)
        return mbpo.train(env, test_env, term_fn, cfg)
    if algorithm_name == "planet":
        return planet.train(env, cfg)
    if algorithm_name == "dreamer":  # added for project
        return dreamer.train(env, cfg)
    return None
if __name__ == "__main__":
    # Initialise the wandb run first: run() writes into wandb.config.
    wandb.init(
        project="MBRL_Duckyt",
        entity="mbrl_ducky",
        monitor_gym=True,
    )
    run()