forked from eloialonso/diamond
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.py
62 lines (52 loc) · 2.08 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from dataclasses import dataclass
from pathlib import Path
from typing import Union
import torch
import torch.nn as nn
from envs import TorchEnv, WorldModelEnv
from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
from models.rew_end_model import RewEndModel, RewEndModelConfig
from utils import extract_state_dict
@dataclass
class AgentConfig:
    """Configuration bundle used to build an ``Agent``.

    Holds one sub-config per component plus a single ``num_actions`` value.
    After dataclass initialisation, ``num_actions`` is written onto
    ``denoiser.inner_model``, ``rew_end_model`` and ``actor_critic`` so the
    three sub-configs carry the same value as this object.
    """

    # Sub-config for the denoiser; its ``inner_model`` receives ``num_actions``.
    denoiser: DenoiserConfig
    # Sub-config for the reward/end model.
    rew_end_model: RewEndModelConfig
    # Sub-config for the actor-critic.
    actor_critic: ActorCriticConfig
    # Value copied onto each sub-config in ``__post_init__``.
    num_actions: int

    def __post_init__(self) -> None:
        """Copy ``num_actions`` onto every sub-config that needs it."""
        receivers = (
            self.denoiser.inner_model,
            self.rew_end_model,
            self.actor_critic,
        )
        # Same assignment order as the fields are declared above.
        for sub_cfg in receivers:
            sub_cfg.num_actions = self.num_actions
class Agent(nn.Module):
    """Module grouping a ``Denoiser``, a ``RewEndModel`` and an ``ActorCritic``.

    Each sub-module is constructed from the field of the same name on the
    ``AgentConfig`` passed to the constructor.
    """

    def __init__(self, cfg: AgentConfig) -> None:
        """Build the three sub-modules from ``cfg``."""
        super().__init__()
        # Registration order is kept: denoiser, rew_end_model, actor_critic.
        self.denoiser = Denoiser(cfg.denoiser)
        self.rew_end_model = RewEndModel(cfg.rew_end_model)
        self.actor_critic = ActorCritic(cfg.actor_critic)

    @property
    def device(self):
        """Return the ``device`` attribute exposed by the denoiser."""
        denoiser = self.denoiser
        return denoiser.device

    def setup_training(
        self,
        sigma_distribution_cfg: SigmaDistributionConfig,
        actor_critic_loss_cfg: ActorCriticLossConfig,
        rl_env: Union[TorchEnv, WorldModelEnv],
    ) -> None:
        """Forward training setup to the denoiser and the actor-critic.

        Args:
            sigma_distribution_cfg: Passed to ``denoiser.setup_training``.
            actor_critic_loss_cfg: Passed, after ``rl_env``, to
                ``actor_critic.setup_training``.
            rl_env: Environment handed to ``actor_critic.setup_training``.
        """
        # The reward/end model has no setup call here.
        self.denoiser.setup_training(sigma_distribution_cfg)
        self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg)

    def load(
        self,
        path_to_ckpt: Path,
        load_denoiser: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
    ) -> None:
        """Load sub-module weights from a checkpoint file.

        The checkpoint is read with ``torch.load`` (mapped onto
        ``self.device``), a state dict is extracted per component with
        ``extract_state_dict``, and each component whose flag is truthy
        receives its state dict.

        Args:
            path_to_ckpt: Location of the checkpoint file.
            load_denoiser: Whether to load the denoiser weights.
            load_rew_end_model: Whether to load the reward/end model weights.
            load_actor_critic: Whether to load the actor-critic weights.
        """
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the PyTorch version / ``weights_only`` default -- only load trusted
        # checkpoints.
        checkpoint = torch.load(Path(path_to_ckpt), map_location=self.device)

        # Extract every component's state dict up front, even for components
        # that end up not being loaded.
        component_names = ("denoiser", "rew_end_model", "actor_critic")
        state_dicts = {}
        for name in component_names:
            state_dicts[name] = extract_state_dict(checkpoint, name)

        requested = (
            ("denoiser", load_denoiser),
            ("rew_end_model", load_rew_end_model),
            ("actor_critic", load_actor_critic),
        )
        for name, wanted in requested:
            if wanted:
                getattr(self, name).load_state_dict(state_dicts[name])