-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify_demo.py
82 lines (69 loc) · 3.03 KB
/
classify_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
import sys
from util import get_compressed_images
try:
import cPickle as pickle
except ImportError:
import pickle
class ClassifyDemo(object):
    """Train a network to classify the actions taken in recorded human demos.

    On construction the demo replay memory ``D`` is restored from files in
    ``folder``. ``run`` then repeatedly samples random (state, action)
    minibatches from ``D`` and performs gradient steps on ``net``.
    """

    def __init__(
            self, net, D, name, train_max_steps, batch_size,
            eval_freq, folder):
        """Initialize Classifying Human Demo Training.

        Args:
            net: network object; this class calls its ``train``,
                ``evaluate_batch``, ``add_summary``, ``save`` and
                ``save_max_value`` methods.
            D: replay-memory object; its attributes are overwritten from the
                pickled data and its ``random_batch`` method is sampled.
            name: game name, used to build the data file names.
            train_max_steps: number of gradient steps performed by ``run``.
            batch_size: minibatch size used during training.
            eval_freq: evaluate and log a summary every ``eval_freq`` steps.
            folder: directory containing the demo data files.

        The demo data is loaded immediately (see ``_load_memory``).
        """
        self.net = net
        self.D = D
        self.name = name
        self.train_max_steps = train_max_steps
        self.batch_size = batch_size
        self.eval_freq = eval_freq
        self.folder = folder
        self._load_memory()

    def _load_memory(self):
        """Restore ``self.D`` from the demo files in ``self.folder``.

        Reads ``{folder}/{name}-dqn-all.pkl`` for the memory's fields and
        ``{folder}/{name}-dqn-images-all.h5.gz`` for the images.
        """
        print("Loading data")
        pkl_path = '{}/{}-dqn-all.pkl'.format(self.folder, self.name)

        # The pong/breakout data were pickled with Python 2, which has
        # compatibility issues in Python 3; decoding with latin1 is the
        # documented workaround for Python-2 pickles.
        if self.name in ('pong', 'breakout'):
            load_kwargs = {'encoding': 'latin1'}
        else:
            load_kwargs = {}

        # SECURITY: pickle.load can execute arbitrary code -- only load demo
        # files that come from a trusted source.
        # The context manager guarantees the file handle is closed.
        with open(pkl_path, 'rb') as pkl_file:
            data = pickle.load(pkl_file, **load_kwargs)

        # Each field is stored in the pickle under the key 'D.<field>'.
        for field in ('width', 'height', 'max_steps', 'phi_length',
                      'num_actions', 'actions', 'rewards', 'terminal',
                      'bottom', 'top', 'size'):
            setattr(self.D, field, data['D.' + field])

        self.D.imgs = get_compressed_images(
            '{}/{}-dqn-images-all.h5'.format(self.folder, self.name) + '.gz')
        print("Data loaded!")

    def run(self):
        """Train ``net`` on random demo minibatches, then save the results.

        Every ``eval_freq`` steps, the freshly sampled batch is evaluated
        before the update, and its summary is logged. The largest
        ``max_value`` returned by ``net.train`` is tracked and passed to
        ``net.save``. It is then used as the starting point for
        ``save_max_value``.
        """
        max_val = -(sys.maxsize)
        for i in range(self.train_max_steps):
            s_j_batch, a_batch, _, _, _ = self.D.random_batch(self.batch_size)

            if (i % self.eval_freq) == 0:
                # Only the first two results (accuracy, summary) are used here.
                result = self.net.evaluate_batch(s_j_batch, a_batch)
                acc = result[0]
                summary_str = result[1]
                self.net.add_summary(summary_str, i)
                print("step {}, training accuracy {}, max output val {}".format(i, acc, max_val))

            # Perform a gradient step; only the last returned value
            # (logged as the "max output val") is kept.
            _, _, _, _, _, max_value = self.net.train(s_j_batch, a_batch)
            if max_value > max_val:
                max_val = max_value

        print("max output val {}".format(max_val))
        self.net.save(model_max_output_val=max_val)
        self.save_max_value(max_val=max_val)

    def save_max_value(self, max_val=-(sys.maxsize)):
        """Refine the max output value over 100 evaluation batches and save it.

        Each batch requests ``D.size * 10 // 100`` samples (10% of the
        memory). ``max_val`` is raised to the largest ``max_value`` returned
        by ``net.evaluate_batch``. The result is passed to
        ``net.save_max_value``.

        Args:
            max_val: starting value. ``run`` passes the max seen during
                training.
        """
        # NOTE(review): batch is 0 when D.size < 10 -- confirm random_batch
        # and evaluate_batch handle an empty batch.
        batch = self.D.size * 10 // 100
        for i in range(100):
            s_j_batch, a_batch, _, _, _ = self.D.random_batch(batch)
            _, _, _, max_value = self.net.evaluate_batch(s_j_batch, a_batch)
            if i % 10 == 0:
                print("step {}, max output val {}".format(i, max_val))
            if max_value > max_val:
                print("Max value from {} to {}".format(max_val, max_value))
                max_val = max_value
        print("max output val {}".format(max_val))
        self.net.save_max_value(max_val)