initial upload of helper modules for DQN algorithm
1 parent 67597e5 · commit 3a16336
Showing 12 changed files with 2,339 additions and 0 deletions.
@@ -0,0 +1,45 @@
import tensorflow as tf


class batch_norm:
    def __init__(
            self, inputs, size, is_training, sess, parForTarget=None,
            decay=0.9, epsilon=1e-4, slow=False, tau=0.01, linear=False):
        """ Initialization of batch_norm class """

        self.slow = slow
        self.sess = sess
        # Learnable scale/offset and non-trainable population statistics.
        self.scale = tf.Variable(tf.random_uniform([size], 0.9, 1.1), trainable=True, name='scale')
        #self.scale = tf.Variable(tf.constant(1.0, shape=[size]), name='scale', trainable=True)
        self.beta = tf.Variable(tf.random_uniform([size], -0.03, 0.03), trainable=True, name='beta')
        #self.beta = tf.Variable(tf.constant(0.0, shape=[size]), name='beta', trainable=True)
        self.pop_mean = tf.Variable(tf.random_uniform([size], -0.03, 0.03), trainable=False, name='mean')
        #self.pop_mean = tf.Variable(tf.constant(0.0, shape=[size]), trainable=False, name='mean')
        self.pop_var = tf.Variable(tf.random_uniform([size], 0.9, 1.1), trainable=False, name='variance')
        #self.pop_var = tf.Variable(tf.constant(1.0, shape=[size]), trainable=False, name='variance')

        # Batch statistics: moments over the batch axis for fully connected
        # inputs (linear=True), or over batch/height/width for conv feature maps.
        if linear:
            self.batch_mean, self.batch_var = tf.nn.moments(inputs, [0])
        else:
            self.batch_mean, self.batch_var = tf.nn.moments(inputs, [0, 1, 2])

        # Exponential moving averages of the population statistics,
        # updated by running self.train alongside the training step.
        self.train_mean = tf.assign(self.pop_mean, self.pop_mean * decay + self.batch_mean * (1 - decay))
        self.train_var = tf.assign(self.pop_var, self.pop_var * decay + self.batch_var * (1 - decay))
        self.train = tf.group(self.train_mean, self.train_var)

        def training():
            # Normalize with the statistics of the current batch.
            return tf.nn.batch_normalization(inputs,
                self.batch_mean, self.batch_var, self.beta, self.scale, epsilon)

        def testing():
            # Normalize with the tracked population statistics.
            return tf.nn.batch_normalization(inputs,
                self.pop_mean, self.pop_var, self.beta, self.scale, epsilon)

        # Optionally mirror another network's batch-norm parameters, either
        # with a soft (Polyak-averaged) update or a hard copy.
        if parForTarget is not None:
            self.parForTarget = parForTarget
            if self.slow:
                self.updateScale = self.scale.assign(self.scale * (1 - tau) + self.parForTarget.scale * tau)
                self.updateBeta = self.beta.assign(self.beta * (1 - tau) + self.parForTarget.beta * tau)
            else:
                self.updateScale = self.scale.assign(self.parForTarget.scale)
                self.updateBeta = self.beta.assign(self.parForTarget.beta)
            self.updateTarget = tf.group(self.updateScale, self.updateBeta)

        self.bnorm = tf.cond(is_training, training, testing)
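
A minimal usage sketch (not part of the commit) of how this helper would typically be wired into a TF1 graph. The layer sizes, placeholder names, and the fc1 weights are hypothetical; the pattern assumes the caller runs bn1.train alongside each training step and would run updateTarget when syncing a target network.

# Hypothetical wiring of the batch_norm helper into a small TF1 network.
# All names (x, is_training, W1, fc1) are illustrative, not from this commit.
import tensorflow as tf

sess = tf.Session()
x = tf.placeholder(tf.float32, [None, 128])
is_training = tf.placeholder(tf.bool)

W1 = tf.Variable(tf.random_uniform([128, 64], -0.05, 0.05))
fc1 = tf.matmul(x, W1)

# linear=True because fc1 is a fully connected (2-D) activation.
bn1 = batch_norm(fc1, 64, is_training, sess, linear=True)
h1 = tf.nn.relu(bn1.bnorm)

sess.run(tf.global_variables_initializer())
# Training step: evaluate the layer and update the population statistics.
#   sess.run([h1, bn1.train], feed_dict={x: batch, is_training: True})
# Evaluation: the tracked population statistics are used instead.
#   sess.run(h1, feed_dict={x: batch, is_training: False})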
@@ -0,0 +1,330 @@
""" | ||
Copyright (c) 2014, Nathan Sprague | ||
All rights reserved. | ||
Original code: https://goo.gl/dp2qRV | ||
This class stores all of the samples for training. It is able to | ||
construct randomly selected batches of phi's from the stored history. | ||
""" | ||
|
||
import numpy as np | ||
import time | ||
|
||
class DataSet(object): | ||
"""A replay memory consisting of circular buffers for observed images, | ||
actions, and rewards. | ||
""" | ||
def __init__(self, width, height, rng, max_steps=1000, phi_length=4, num_actions=1): | ||
"""Construct a DataSet. | ||
Arguments: | ||
width, height - image size | ||
max_steps - the number of time steps to store | ||
phi_length - number of images to concatenate into a state | ||
rng - initialized numpy random number generator, used to | ||
choose random minibatches | ||
""" | ||
# TODO: Specify capacity in number of state transitions, not | ||
# number of saved time steps. | ||
|
||
# Store arguments. | ||
self.width = width | ||
self.height = height | ||
self.max_steps = max_steps | ||
self.phi_length = phi_length | ||
self.num_actions = num_actions | ||
self.rng = rng | ||
|
||
# Allocate the circular buffers and indices. | ||
self.imgs = np.zeros((max_steps, height, width), dtype='uint8') | ||
self.actions = np.zeros((max_steps, num_actions)) | ||
self.rewards = np.zeros(max_steps, dtype='float') | ||
self.terminal = np.zeros(max_steps, dtype='uint8') | ||
|
||
self.bottom = 0 | ||
self.top = 0 | ||
self.size = 0 | ||
|
||
    def add_sample(self, img, action, reward, terminal):
        """Add a time step record.

        Arguments:
            img -- observed image
            action -- action chosen by the agent
            reward -- reward received after taking the action
            terminal -- boolean indicating whether the episode ended
                        after this time step
        """
        self.imgs[self.top] = img
        self.actions[self.top] = action
        self.rewards[self.top] = reward
        self.terminal[self.top] = terminal

        if self.size == self.max_steps:
            self.bottom = (self.bottom + 1) % self.max_steps
        else:
            self.size += 1
        self.top = (self.top + 1) % self.max_steps

    def __len__(self):
        """Return an approximate count of stored state transitions."""
        # TODO: Properly account for indices which can't be used, as in
        # random_batch's check.
        return max(0, self.size - self.phi_length)

    def last_phi(self):
        """Return the most recent phi (sequence of image frames)."""
        indexes = np.arange(self.top - self.phi_length, self.top)
        return self.imgs.take(indexes, axis=0, mode='wrap')

    def phi(self, img):
        """Return a phi (sequence of image frames), using the last phi_length -
        1, plus img.
        """
        indexes = np.arange(self.top - self.phi_length + 1, self.top)

        phi = np.empty((self.phi_length, self.height, self.width), dtype='float')
        phi[0:self.phi_length - 1] = self.imgs.take(indexes,
                                                    axis=0,
                                                    mode='wrap')
        phi[-1] = img
        return phi

    def random_batch_classifier(self, batch_size):
        """Return corresponding states, actions, rewards, terminal status, and
        next_states for batch_size randomly chosen state transitions.
        """
        # Allocate the response.
        states = [np.zeros((self.height,
                            self.width,
                            self.phi_length),
                           dtype=np.float64) for i in range(batch_size)]
        actions = [[] for i in range(batch_size)]
        rewards = [0. for i in range(batch_size)]
        terminal = [0 for i in range(batch_size)]
        next_states = [np.zeros((self.height,
                                 self.width,
                                 self.phi_length),
                                dtype=np.float64) for i in range(batch_size)]

        count = 0
        while count < batch_size:
            # Randomly choose a time step from the replay memory.
            index = self.rng.randint(self.bottom,
                                     self.bottom + self.size - self.phi_length)

            initial_indices = np.arange(index, index + self.phi_length)
            transition_indices = initial_indices + 1
            end_index = index + self.phi_length - 1

            # Check that the initial state corresponds entirely to a
            # single episode, meaning none but the last frame may be
            # terminal. If the last frame of the initial state is
            # terminal, then the last frame of the transitioned state
            # will actually be the first frame of a new episode, which
            # the Q learner recognizes and handles correctly during
            # training by zeroing the discounted future reward estimate.
            if np.any(self.terminal.take(initial_indices[0:-1], mode='wrap')):
                continue

            # Don't train with terminal states.
            if np.any(self.terminal.take(np.arange(index, index + 10), mode='wrap')):
                continue

            # Add the state transition to the response.
            temp = self.imgs.take(initial_indices, axis=0, mode='wrap')
            temp_ = [(temp[i] / 255.0) for i in range(self.phi_length)]
            temp_.reverse()
            states[count] = np.stack(tuple(temp_), axis=2)
            actions[count] = self.actions.take(end_index, axis=0, mode='wrap')
            rewards[count] = self.rewards.take(end_index, mode='wrap')
            terminal[count] = self.terminal.take(end_index, mode='wrap')
            temp = self.imgs.take(transition_indices, axis=0, mode='wrap')
            temp_ = [(temp[i] / 255.0) for i in range(self.phi_length)]
            temp_.reverse()
            next_states[count] = np.stack(tuple(temp_), axis=2)

            count += 1

        return states, actions, rewards, next_states, terminal

    def random_batch(self, batch_size):
        """Return corresponding states, actions, rewards, terminal status, and
        next_states for batch_size randomly chosen state transitions.
        """
        # Allocate the response.
        states = [np.zeros((self.height,
                            self.width,
                            self.phi_length),
                           dtype=np.float64) for i in range(batch_size)]
        actions = [[] for i in range(batch_size)]
        rewards = [0. for i in range(batch_size)]
        terminal = [0 for i in range(batch_size)]
        next_states = [np.zeros((self.height,
                                 self.width,
                                 self.phi_length),
                                dtype=np.float64) for i in range(batch_size)]

        count = 0
        while count < batch_size:
            # Randomly choose a time step from the replay memory.
            index = self.rng.randint(self.bottom,
                                     self.bottom + self.size - self.phi_length)

            initial_indices = np.arange(index, index + self.phi_length)
            transition_indices = initial_indices + 1
            end_index = index + self.phi_length - 1

            # Check that the initial state corresponds entirely to a
            # single episode, meaning none but the last frame may be
            # terminal. If the last frame of the initial state is
            # terminal, then the last frame of the transitioned state
            # will actually be the first frame of a new episode, which
            # the Q learner recognizes and handles correctly during
            # training by zeroing the discounted future reward estimate.
            if np.any(self.terminal.take(initial_indices[0:-1], mode='wrap')):
                continue

            # Add the state transition to the response.
            temp = self.imgs.take(initial_indices, axis=0, mode='wrap')
            temp_ = [(temp[i] / 255.0) for i in range(self.phi_length)]
            temp_.reverse()
            states[count] = np.stack(tuple(temp_), axis=2)
            actions[count] = self.actions.take(end_index, axis=0, mode='wrap')
            rewards[count] = self.rewards.take(end_index, mode='wrap')
            terminal[count] = self.terminal.take(end_index, mode='wrap')
            temp = self.imgs.take(transition_indices, axis=0, mode='wrap')
            temp_ = [(temp[i] / 255.0) for i in range(self.phi_length)]
            temp_.reverse()
            next_states[count] = np.stack(tuple(temp_), axis=2)

            count += 1

        return states, actions, rewards, next_states, terminal


# TESTING CODE BELOW THIS POINT...

def simple_tests():
    np.random.seed(222)
    dataset = DataSet(width=2, height=3,
                      rng=np.random.RandomState(42),
                      max_steps=6, phi_length=4)
    for i in range(10):
        img = np.random.randint(0, 256, size=(3, 2))
        action = np.random.randint(16)
        reward = np.random.random()
        terminal = 0
        if np.random.random() < .05:
            terminal = 1
        print 'img', img
        dataset.add_sample(img, action, reward, terminal)
        print "I", dataset.imgs
        print "A", dataset.actions
        print "R", dataset.rewards
        print "T", dataset.terminal
        print "SIZE", dataset.size
    print "LAST PHI", dataset.last_phi()
    print 'BATCH', dataset.random_batch(2)


def speed_tests():

    dataset = DataSet(width=80, height=80,
                      rng=np.random.RandomState(42),
                      max_steps=20000, phi_length=4)

    img = np.random.randint(0, 256, size=(80, 80))
    action = np.random.randint(16)
    reward = np.random.random()
    start = time.time()
    for i in range(100000):
        terminal = 0
        if np.random.random() < .05:
            terminal = 1
        dataset.add_sample(img, action, reward, terminal)
    print "samples per second: ", 100000 / (time.time() - start)

    start = time.time()
    for i in range(200):
        a = dataset.random_batch(32)
    print "batches per second: ", 200 / (time.time() - start)

    print dataset.last_phi()

def trivial_tests():

    dataset = DataSet(width=2, height=1,
                      rng=np.random.RandomState(42),
                      max_steps=3, phi_length=2)

    img1 = np.array([[1, 1]], dtype='uint8')
    img2 = np.array([[2, 2]], dtype='uint8')
    img3 = np.array([[3, 3]], dtype='uint8')

    dataset.add_sample(img1, 1, 1, False)
    dataset.add_sample(img2, 2, 2, False)
    dataset.add_sample(img3, 2, 2, True)
    print "last", dataset.last_phi()
    print "random", dataset.random_batch(1)


def max_size_tests():
    dataset1 = DataSet(width=3, height=4,
                       rng=np.random.RandomState(42),
                       max_steps=10, phi_length=4)
    dataset2 = DataSet(width=3, height=4,
                       rng=np.random.RandomState(42),
                       max_steps=1000, phi_length=4)
    for i in range(100):
        img = np.random.randint(0, 256, size=(4, 3))
        action = np.random.randint(16)
        reward = np.random.random()
        terminal = 0
        if np.random.random() < .05:
            terminal = 1
        dataset1.add_sample(img, action, reward, terminal)
        dataset2.add_sample(img, action, reward, terminal)
        np.testing.assert_array_almost_equal(dataset1.last_phi(),
                                             dataset2.last_phi())
    print "passed"

def test_memory_usage_ok():
    import memory_profiler
    dataset = DataSet(width=80, height=80,
                      rng=np.random.RandomState(42),
                      max_steps=100000, phi_length=4)
    last = time.time()

    for i in xrange(1000000000):
        if (i % 100000) == 0:
            print i
        dataset.add_sample(np.random.random((80, 80)), 1, 1, False)
        if i > 200000:
            states, actions, rewards, next_states, terminals = \
                dataset.random_batch(32)
        if (i % 10007) == 0:
            print time.time() - last
            mem_usage = memory_profiler.memory_usage(-1)
            print len(dataset), mem_usage
            last = time.time()


def main():
    speed_tests()
    test_memory_usage_ok()
    max_size_tests()
    simple_tests()


if __name__ == "__main__":
    main()
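
For reference, a short sketch (not part of the commit) of how a DQN training loop would typically exercise this replay memory. The 84x84 frame size, the random stand-in frames, and the fixed action count are assumptions for illustration only; in practice the frames would come from the emulator and the actions from an epsilon-greedy policy over the Q-network's outputs.

# Illustrative use of DataSet as a DQN replay memory; all inputs are fake.
import numpy as np

rng = np.random.RandomState(0)
memory = DataSet(width=84, height=84, rng=rng, max_steps=10000, phi_length=4)

for step in range(500):
    frame = np.random.randint(0, 256, size=(84, 84)).astype('uint8')  # stand-in for a preprocessed screen
    action = rng.randint(16)                                          # stand-in for an epsilon-greedy choice
    reward = rng.random_sample()
    terminal = rng.random_sample() < 0.05
    memory.add_sample(frame, action, reward, terminal)

# Once enough transitions are stored, draw a minibatch for a Q-learning update.
if len(memory) > 32:
    states, actions, rewards, next_states, terminals = memory.random_batch(32)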