forked from PWhiddy/PokemonRedExperiments
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrollout.py
61 lines (51 loc) · 2 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pathlib import Path
import json
from itertools import islice
import numpy as np
import mediapy as media
class Rollout:
    """Recording of one agent episode: frames, actions and rewards.

    Frames and actions are appended in lockstep through
    ``add_state_action_pair``; rewards are appended separately through
    ``add_reward``. ``save_to_file`` writes the frames as an ``.mp4`` video
    and the actions, rewards and agent name as a sibling ``.json`` file.
    ``from_saved_path`` rebuilds a rollout from such a pair of files.
    """

    def __init__(self, name, agent_name, basepath='rollouts'):
        """Create an empty rollout stored under ``basepath / name``.

        Args:
            name: Base file name (without suffix) for the saved files.
            agent_name: Identifier of the agent, stored in the JSON metadata.
            basepath: Directory holding the rollout files. It is created,
                together with any missing parents, if it does not exist.
        """
        self.path = Path(basepath) / Path(name)
        self.actions = []
        self.frames = []
        self.rewards = []
        self.agent_name = agent_name
        # parents=True so a nested basepath such as 'runs/exp1/rollouts'
        # works even when the intermediate directories are missing.
        Path(basepath).mkdir(parents=True, exist_ok=True)

    def add_reward(self, reward):
        """Append ``reward`` to the recorded rewards."""
        self.rewards.append(reward)

    def add_state_action_pair(self, frame, action):
        """Append one ``frame`` and the ``action`` recorded with it.

        NOTE: on an instance returned by ``from_saved_path`` the frames are
        held in a NumPy array, which has no ``append``; this method is only
        usable on rollouts that are being recorded.
        """
        self.frames.append(frame)
        self.actions.append(action)

    def get_state_action_pair(self, index):
        """Return the ``(frame, action)`` tuple stored at ``index``."""
        return (self.frames[index], self.actions[index])

    def get_state_action_pairs(self):
        """Return an iterator over all ``(frame, action)`` tuples."""
        return zip(self.frames, self.actions)

    def save_to_file(self):
        """Write the rollout to ``<path>.mp4`` and ``<path>.json``.

        The frames are stacked into a single array and encoded as a 30 fps
        video; actions, rewards and the agent name go into the JSON file.
        """
        # save frames as video
        out_frames = np.array(self.frames)
        media.write_video(self.path.with_suffix('.mp4'), out_frames, fps=30)
        # save actions and metadata as json
        with self.path.with_suffix('.json').open('w') as file:
            json.dump(
                {
                    'actions': self.actions,
                    'rewards': self.rewards,
                    'agent': self.agent_name,
                },
                file,
            )

    @classmethod
    def from_saved_path(cls, path, basepath, limit=None):
        """Load a rollout previously written by ``save_to_file``.

        Args:
            path: Location of the ``.json`` metadata file (``str`` or
                ``Path``). The video is read from the same location with
                the suffix replaced by ``.mp4``.
            basepath: Directory passed to the constructor of the new
                instance (created if missing, as in ``__init__``).
            limit: If not ``None``, keep only the first ``limit`` actions,
                rewards and frames.

        Returns:
            A new instance whose ``frames`` attribute is a NumPy array of
            the decoded video frames (not a list).
        """
        # Accept plain strings as well as Path objects.
        path = Path(path)
        with path.open('r') as f:
            data = json.load(f)
        new_instance = cls(path.stem, data['agent'], basepath=basepath)
        new_instance.actions = data['actions']
        new_instance.rewards = data['rewards']
        if limit is not None:
            new_instance.actions = new_instance.actions[:limit]
            new_instance.rewards = new_instance.rewards[:limit]
        with media.VideoReader(path.with_suffix('.mp4')) as reader:
            # islice with limit=None consumes every frame of the video.
            new_instance.frames = np.array(tuple(islice(reader, limit)))
        return new_instance
def load_rollouts(dir_path):
    """Lazily yield a ``Rollout`` for each ``.json`` file in ``dir_path``.

    Only direct children of ``dir_path`` are considered. Each metadata file
    is loaded with ``Rollout.from_saved_path``, using ``dir_path`` as the
    base path of the resulting rollout.
    """
    metadata_files = (
        entry for entry in Path(dir_path).glob('*') if entry.suffix == '.json'
    )
    for metadata_file in metadata_files:
        yield Rollout.from_saved_path(metadata_file, dir_path)