lpg_ftw_agent.py
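
"""A tella ContinualRLAgent that trains an LPG-FTW factored policy.

The agent builds an MLPLPGFTW policy, optimizes it with the NPGFTW (natural
policy gradient) learner, and keeps one MLPBaseline value baseline per task.
Training and evaluation behaviour is driven by tella's block/task callbacks.
"""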
from copy import deepcopy
import typing
import logging
import numpy as np
import gym
import tella
from categ_mlp_lpg_ftw import MLPLPGFTW
from npg_cg_ftw import NPGFTW
from mlp_baselines import MLPBaseline

logger = logging.getLogger("LPG_FTW Agent")
DEVICE = "cuda:0"

# Constants copied from experiments.habitat_ste_m15.py
BASELINE_TRAINING_EPOCH = 20
NORMALIZED_STEP_SIZE = 0.00005
HVP_SAMPLEFRAC = 0.02
BATCH_SIZE = 128
N = 50
GAMMA = 0.995
GAE_LAMBDA = None # 0.97
POLICY_HIDDEN_SIZE = 128
BASELINE_HIDDEN_SIZE = 128
K = 1
MAX_K = 4
BASELINE_LR = 1e-5
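
# Hyperparameter notes (assumed semantics, inferred from how the values are
# used below and from the LPG-FTW / NPG interfaces; not verified against docs):
#   K / MAX_K   initial and maximum number of shared latent policy components
#   N           rollout budget passed to each NPGFTW.train_step call
#   GAE_LAMBDA  None disables GAE; 0.97 is the commented-out alternative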
class LpgFtwAgent(tella.ContinualRLAgent):
    def __init__(
        self,
        rng_seed: int,
        observation_space: gym.Space,
        action_space: gym.Space,
        num_envs: int,
        config_file: typing.Optional[str] = None,
    ) -> None:
        rng_seed = rng_seed % 2**32  # keep the seed within numpy's 32-bit range
        super(LpgFtwAgent, self).__init__(
            rng_seed, observation_space, action_space, num_envs, config_file
        )
        baselines = {}
        policy = MLPLPGFTW(
            observation_space,
            action_space,
            hidden_size=POLICY_HIDDEN_SIZE,
            k=K, max_k=MAX_K,
            seed=rng_seed,
            use_gpu=(DEVICE != "cpu"),
        )
        self.agent_train = NPGFTW(
            policy,
            baselines,
            num_envs=num_envs,
            normalized_step_size=NORMALIZED_STEP_SIZE,
            seed=rng_seed,
            use_gpu=(DEVICE != "cpu"),
            hvp_sample_frac=HVP_SAMPLEFRAC,
            batch_size=BATCH_SIZE,
        )
        self.agent = self.agent_train
        self.training = None  # True during learning blocks, False during evaluation blocks

    def block_start(self, is_learning_allowed: bool) -> None:
        super().block_start(is_learning_allowed)
        if is_learning_allowed:
            logger.info("About to start a new learning block")
            self.training = True
            self.agent = self.agent_train
        else:
            logger.info("About to start a new evaluation block")
            self.training = False
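
    # A value-function baseline is created lazily the first time a task is
    # seen; set_task tells the LPG-FTW learner which task's parameters to use.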
    def task_start(self, task_name: typing.Optional[str]) -> None:
        logger.info(
            f"\tAbout to start interacting with a new task. task_name={task_name}"
        )
        if task_name not in self.agent.baselines:
            self.agent.baselines[task_name] = MLPBaseline(
                self.observation_space,
                reg_coef=1e-3,
                batch_size=BATCH_SIZE,
                epochs=BASELINE_TRAINING_EPOCH,
                learn_rate=BASELINE_LR,
                use_gpu=(DEVICE != "cpu"),
            )
        self.agent.set_task(task_name)
        self.agent.rollout_buffer.clear_log()
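
    # Action selection: finished environments (observation is None) receive a
    # None action. Otherwise the sampled action is used while training and the
    # policy's 'evaluation' action (act_infos["evaluation"]) during evaluation.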
    def choose_actions(
        self, observations: typing.List[typing.Optional[tella.Observation]]
    ) -> typing.List[typing.Optional[tella.Action]]:
        # Note: the original LPG-FTW code does not wrap evaluation-time action
        # selection in torch.no_grad(), so it is not used here either.
        obs_new = [
            obs if obs is not None else self.observation_space.sample()
            for obs in observations
        ]
        obs_new = np.stack(obs_new)
        obs_new = obs_new.reshape(obs_new.shape[0], -1)
        acts, act_infos = self.agent.policy.get_action(obs_new)
        actions = []
        for a, ai, obs in zip(acts, act_infos["evaluation"], observations):
            if obs is None:
                actions.append(None)
            elif self.training:
                actions.append(a)
            else:
                actions.append(ai)
        return actions
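
    # receive_transitions is a no-op during evaluation blocks; during learning
    # it flattens each observation and hands the batch of per-environment
    # transitions to NPGFTW.train_step.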
    def receive_transitions(
        self, transitions: typing.List[typing.Optional[tella.Transition]]
    ) -> None:
        assert len(transitions) == self.num_envs
        if not self.is_learning_allowed:
            return
        transitions = [self.flat_observation(t) for t in transitions]
        self.agent.train_step(
            transitions,
            N=N, gamma=GAMMA, gae_lambda=GAE_LAMBDA
        )
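
    # Flatten the observation in each transition to a 1-D vector so it matches
    # the input expected by the MLP policy and baseline.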
    def flat_observation(self, transition: typing.Optional[tella.Transition]):
        if transition is None:
            return transition
        s, a, r, d, ns = transition
        s = s.reshape(-1)
        return s, a, r, d, ns

    def task_variant_start(
        self,
        task_name: typing.Optional[str],
        variant_name: typing.Optional[str],
    ) -> None:
        logger.info(
            f"\tAbout to start interacting with a new task variant. "
            f"task_name={task_name} variant_name={variant_name}"
        )

    def task_end(
        self,
        task_name: typing.Optional[str],
    ) -> None:
        logger.info(f"\tDone interacting with task. task_name={task_name}")

    def task_variant_end(
        self,
        task_name: typing.Optional[str],
        variant_name: typing.Optional[str],
    ) -> None:
        logger.info(
            f"\tDone interacting with task variant. "
            f"task_name={task_name} variant_name={variant_name}"
        )

    def block_end(self, is_learning_allowed: bool) -> None:
        if is_learning_allowed:
            logger.info("Done with learning block")
        else:
            logger.info("Done with evaluation block")
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tella.rl_cli(LpgFtwAgent)