train_agent.py
import os
import signal
from typing import Callable
import gymnasium as gym
import hydra
import numpy as np
import torch as th
from omegaconf import DictConfig, OmegaConf
from src import BaseRLAgent, create_agent
from src.utils.drls.env import get_env_info, make_env, reset_env_fn
from src.utils.exp.prepare import set_random_seed
from src.utils.logger import TBLogger
from src.utils.net.ptu import (
save_torch_model,
set_eval_mode,
set_torch,
set_train_mode,
tensor2ndarray,
)


@th.no_grad()
def eval_policy(
eval_env: gym.Env,
reset_env_fn: Callable,
policy: BaseRLAgent,
seed: int,
episodes=10,
):
"""Evaluate Policy"""
set_eval_mode(policy.models)
returns = []
for _ in range(episodes):
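        # reset the environment and accumulate reward over one full episode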
(state, _), terminated, truncated = reset_env_fn(eval_env, seed), False, False
return_ = 0.0
while not (terminated or truncated):
action = policy.select_action(
state,
deterministic=True,
return_log_prob=False,
**{"action_space": eval_env.action_space},
)
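            # convert the action tensor to a numpy array before stepping the env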
state, reward, terminated, truncated, _ = eval_env.step(
tensor2ndarray((action,))[0]
)
return_ += reward
returns.append(return_)
set_train_mode(policy.models)
    # average return across evaluation episodes
return np.mean(returns)


@hydra.main(config_path="./conf", config_name="train_agent", version_base="1.3.2")
def main(cfg: DictConfig):
cfg.work_dir = os.getcwd()
# prepare experiment
set_torch()
set_random_seed(cfg.seed)
# setup logger
logger = TBLogger(
args=OmegaConf.to_object(cfg),
record_param=cfg.log.record_param,
console_output=cfg.log.console_output,
)
# setup environment
train_env, eval_env = (make_env(cfg.env.id), make_env(cfg.env.id))
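    # store environment metadata (observation/action space info) in the config before building the agent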
OmegaConf.update(cfg, "env[info]", get_env_info(eval_env), merge=False)
# create agent
agent = create_agent(cfg)
# train agent
    def ctrl_c_handler(_signum, _frame):
        """If the program is interrupted with Ctrl+C, save the model before exiting."""
logger.console.warning("The program is stopped...")
logger.console.info(
save_torch_model(agent.models, logger.ckpt_dir, "stopped_model")
) # save model
exit(1)
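    # install the handler so an interrupted run still checkpoints the model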
    signal.signal(signal.SIGINT, ctrl_c_handler)
agent.learn(train_env, eval_env, reset_env_fn, eval_policy, logger)
# save model
logger.console.info(save_torch_model(agent.models, logger.ckpt_dir, "final_model"))
if __name__ == "__main__":
main()