Skip to content

Commit

Permalink
adding tf driver
Browse files Browse the repository at this point in the history
  • Loading branch information
crizCraig committed Dec 17, 2016
1 parent 9f687cb commit 3f91375
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ _Thanks to [Rafal jozefowicz](https://github.com/rafaljozefowicz) for contribut
```
cd drivers/deepdrive_tf
wget -O model.ckpt-20048 https://goo.gl/zanx88
wget -O model.ckpt-20048.meta https://goo.gl/LNqHoj
```


Expand Down
27 changes: 13 additions & 14 deletions driver_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,37 @@ def __init__(self):
def load_net(self):
raise NotImplementedError('Show us how to load your net')

def set_net_input(self, image):
def set_input(self, image):
raise NotImplementedError('Show us how to set your net\'s inputs')

def get_next_action_n(self, net_out, info):
def get_next_action(self, net_out, info):
raise NotImplementedError('Show us how to get the next action given the net\'s outputs')

def get_net_out(self):
raise NotImplementedError('Show us how to get output from your net')

def setup(self):
self.load_net()

def process_step(self, observation_n, reward_n, done_n, info):
def step(self, observation_n, reward_n, done_n, info):
if observation_n[0] is None:
return self.get_noop()
image = observation_n[0]['vision']
if image is not None:
begin = time.time()
self.set_net_input(image)
self.set_input(image)
end = time.time()
logger.debug('time to set net input %s', end - begin)
net_out = self.react()

begin = time.time()
next_action_n = self.get_next_action_n(net_out, info)
net_out = self.get_net_out()
end = time.time()
logger.debug('time to get next action %s', end - begin)
logger.debug('time to get net out %s', end - begin)

begin = time.time()
next_action_n = self.get_next_action(net_out, info)
end = time.time()
logger.debug('time to get next action %s', end - begin)

errored = [i for i, info_i in enumerate(info['n']) if 'error' in info_i]
if errored:
Expand All @@ -64,10 +70,3 @@ def get_noop(self):
z_axis_event = JoystickAxisZEvent(0)
noop = [[x_axis_event, z_axis_event]]
return noop

def react(self):
begin = time.time()
net_out = self.net.forward()
end = time.time()
logger.debug('inference time %s', end - begin)
return net_out
15 changes: 11 additions & 4 deletions drivers/deepdrive/deep_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import caffe
import os
from driver_base import DriverBase
from universe.spaces.joystick_event import JoystickAxisXEvent, JoystickAxisZEvent
Expand All @@ -14,6 +13,7 @@ def __init__(self):
self.input_layer_name = 'images'

def load_net(self):
import caffe # Don't require caffe unless this driver is used
caffe.set_mode_gpu()
model_def = os.path.join(DIR_PATH, 'deep_drive_model.prototxt')
model_weights = os.path.join(DIR_PATH, 'caffe_deep_drive_train_iter_35352.caffemodel')
Expand All @@ -25,7 +25,7 @@ def load_net(self):
transformer.set_channel_swap('data', (2, 1, 0)) # swap channels from RGB to BGR
self.image_transformer = transformer

def get_next_action_n(self, net_out, info):
def get_next_action(self, net_out, info):
spin, direction, speed, speed_change, steer, throttle = net_out['gtanet_fctop'][0]
steer = -float(steer)
steer_dead_zone = 0.2
Expand Down Expand Up @@ -60,7 +60,14 @@ def get_next_action_n(self, net_out, info):

return next_action_n

def set_net_input(self, image):
def set_input(self, image):
# print(image)
transformed_image = self.image_transformer.preprocess('data', image)
self.net.blobs[self.input_layer_name].data[...] = transformed_image
self.net.blobs[self.input_layer_name].data[...] = transformed_image

def get_net_out(self):
    """Run one forward pass of the loaded net and return its outputs.

    The forward pass is issued on ``self.net`` (the same object whose input
    blob ``set_input`` fills), not on the driver itself, which has no
    ``forward`` method.

    Returns:
        Whatever ``self.net.forward()`` returns; ``get_next_action`` indexes
        it by output name (``net_out['gtanet_fctop']``).
    """
    # Function-scope import: `time` is not among the imports visible at the
    # top of this module, so bring it in here to keep the method self-contained.
    import time

    begin = time.time()
    net_out = self.net.forward()
    end = time.time()
    logger.debug('inference time %s', end - begin)
    return net_out
File renamed without changes.
81 changes: 81 additions & 0 deletions drivers/deepdrive_tf/deep_driver_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import time

from driver_base import DriverBase
from universe.spaces.joystick_event import JoystickAxisXEvent, JoystickAxisZEvent
import logging
import numpy as np
from scipy.misc import imresize


logger = logging.getLogger()

import tensorflow as tf
import os
from drivers.deepdrive_tf.gtanet import GTANetModel

DIR_PATH = os.path.dirname(os.path.realpath(__file__))


class DeepDriverTF(DriverBase):
    """Driver that produces joystick actions from a TensorFlow GTANetModel.

    DriverBase.step() calls, in order, set_input(), get_net_out() and
    get_next_action(); load_net() is invoked once from DriverBase.setup().
    """

    def __init__(self):
        super(DeepDriverTF, self).__init__()
        self.sess = None  # tf.Session, created in load_net()
        self.net = None  # GTANetModel instance, built in load_net()
        self.image_var = None  # placeholder the current frame is fed through
        self.net_out_var = None  # placeholder shaped like the net output; not read in this class
        self.image_shape = (227, 227, 3)  # size every frame is resized to before inference
        self.image = None  # most recent preprocessed frame (set by set_input)
        self.num_targets = 6  # number of values the net emits per frame

    def load_net(self):
        """Create the session, restore the checkpoint and build the inference graph."""
        self.sess = tf.Session()
        # NOTE(review): import_meta_graph()/restore() load the checkpoint into the
        # graph described by the .meta file, while GTANetModel below creates its
        # layers through tf.get_variable on a fresh placeholder. Confirm that the
        # variables self.net reads are the restored ones and not a second,
        # uninitialised copy (the checkpoint's variable scopes are not visible here).
        saver = tf.train.import_meta_graph(os.path.join(DIR_PATH, 'model.ckpt-20048.meta'))
        saver.restore(self.sess, os.path.join(DIR_PATH, 'model.ckpt-20048'))
        self.image_var = tf.placeholder(tf.float32, (None,) + self.image_shape)
        self.net_out_var = tf.placeholder(tf.float32, (None, self.num_targets))
        self.net = GTANetModel(self.image_var, num_targets=self.num_targets, is_training=False)

    def get_next_action(self, net_out, info):
        """Turn one row of network output into a joystick action.

        Args:
            net_out: result of get_net_out(); row 0 holds the six targets in the
                order spin, direction, speed, speed_change, steer, throttle.
            info: environment info dict; ``info['n'][0]['speed']`` is used as
                the car's current speed when present.

        Returns:
            ``[[x_axis_event, z_axis_event]]`` - one action list for the single
            environment being driven.
        """
        spin, direction, speed, speed_change, steer, throttle = net_out[0]
        steer = -float(steer)
        steer_dead_zone = 0.2

        # Push the steering value past the joystick dead zone in its own direction.
        if steer > 0:
            steer += steer_dead_zone
        elif steer < 0:
            steer -= steer_dead_zone

        logger.debug('steer %f', steer)
        x_axis_event = JoystickAxisXEvent(steer)

        if 'n' in info and 'speed' in info['n'][0]:
            current_speed = info['n'][0]['speed']
            desired_speed = speed / 0.05  # Denormalize per deep_drive.h in deepdrive-caffe
            if desired_speed < current_speed:
                logger.debug('braking')
                # Ease off relative to the previous throttle; fall back to 0.0 on
                # the first step in case no earlier throttle has been recorded.
                previous_throttle = getattr(self, 'throttle', 0.0)
                throttle = previous_throttle - (current_speed - desired_speed) * 0.085  # Magic number
                throttle = max(throttle, 0.0)
            else:
                throttle += 13. / 50.  # Joystick dead zone

            z_axis_event = JoystickAxisZEvent(float(throttle))
            logger.debug('throttle %s', throttle)
        else:
            z_axis_event = JoystickAxisZEvent(0)
            logger.warning('cannot determine speed of car, coasting')

        next_action_n = [[x_axis_event, z_axis_event]]

        # Remember the last commands; the braking branch reads self.throttle.
        self.throttle = float(throttle)
        self.steer = steer
        return next_action_n

    def set_input(self, image):
        """Resize the frame to ``self.image_shape`` and store it as float32."""
        self.image = imresize(image, self.image_shape).astype(np.float32, copy=False)

    def get_net_out(self):
        """Run the stored frame through the net and return the raw output array."""
        begin = time.time()
        # Add the batch dimension from image_shape instead of repeating the literals.
        batch = self.image.reshape((1,) + self.image_shape)
        net_out = self.sess.run(self.net.p, feed_dict={self.image_var: batch})
        end = time.time()
        logger.debug('inference time %s', end - begin)
        return net_out
28 changes: 28 additions & 0 deletions drivers/deepdrive_tf/gtanet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import tensorflow as tf
from drivers.deepdrive_tf.layers import conv2d, max_pool_2x2, linear, lrn


class GTANetModel(object):
    """Five convolution stages followed by three fully connected layers.

    Attributes:
        x: the input tensor the graph was built on.
        p: output of the last fully connected layer ("fc8"), ``num_targets`` wide.
        global_step: non-trainable int32 variable named "global_step".
    """

    def __init__(self, x, num_targets=6, is_training=True):
        self.x = x

        # Stage 1: convolution, ReLU, local response normalisation, pooling.
        net = conv2d(x, "conv1", 96, 11, 4, 1)
        net = tf.nn.relu(net)
        net = lrn(net)
        net = max_pool_2x2(net)

        # Stage 2: same layout, grouped convolution (2 groups).
        net = conv2d(net, "conv2", 256, 5, 1, 2)
        net = tf.nn.relu(net)
        net = lrn(net)
        net = max_pool_2x2(net)

        # Stages 3-5: convolutions with ReLU only, pooled once at the end.
        net = tf.nn.relu(conv2d(net, "conv3", 384, 3, 1, 1))
        net = tf.nn.relu(conv2d(net, "conv4", 384, 3, 1, 2))
        net = tf.nn.relu(conv2d(net, "conv5", 256, 3, 1, 2))
        net = max_pool_2x2(net)

        # Fully connected head; dropout is only inserted while training.
        net = tf.nn.relu(linear(net, "fc6", 4096))
        if is_training:
            net = tf.nn.dropout(net, 0.5)

        net = tf.nn.relu(linear(net, "fc7", 4096))
        if is_training:
            net = tf.nn.dropout(net, 0.95)

        # Final layer has no non-linearity: it emits the raw targets.
        self.p = linear(net, "fc8", num_targets)

        self.global_step = tf.get_variable(
            "global_step",
            [],
            tf.int32,
            initializer=tf.zeros_initializer,
            trainable=False)
39 changes: 39 additions & 0 deletions drivers/deepdrive_tf/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import tensorflow as tf

def conv(input, kernel, biases, k_h, k_w, c_o, s_h, s_w, padding="VALID", group=1):
    """Grouped 2-D convolution plus bias.

    From https://github.com/ethereon/caffe-tensorflow

    With ``group == 1`` this is a plain ``tf.nn.conv2d``. Otherwise the input
    and the kernel are each split into ``group`` pieces, every pair is
    convolved separately and the results are concatenated again.
    ``tf.split`` / ``tf.concat`` are called with the axis as first argument.

    ``k_h`` and ``k_w`` are accepted but not used by the body.
    """
    in_channels = input.get_shape()[-1]
    assert in_channels % group == 0
    assert c_o % group == 0

    strides = [1, s_h, s_w, 1]

    if group == 1:
        output = tf.nn.conv2d(input, kernel, strides, padding=padding)
    else:
        input_slices = tf.split(3, group, input)
        kernel_slices = tf.split(3, group, kernel)
        partial_outputs = []
        for input_slice, kernel_slice in zip(input_slices, kernel_slices):
            partial_outputs.append(
                tf.nn.conv2d(input_slice, kernel_slice, strides, padding=padding))
        output = tf.concat(3, partial_outputs)

    biased = tf.nn.bias_add(output, biases)
    # Reshape with the statically known trailing dims; only the batch dim stays free.
    trailing_dims = output.get_shape().as_list()[1:]
    return tf.reshape(biased, [-1] + trailing_dims)

def conv2d(x, name, num_features, kernel_size, stride, group):
    """Create the "<name>_W" / "<name>_b" variables and apply a SAME-padded conv.

    The kernel is square (``kernel_size`` x ``kernel_size``), the same stride
    is used on both spatial axes, and each group sees
    ``input_channels // group`` input channels.
    """
    in_channels = x.get_shape()[3]
    weight_shape = [kernel_size, kernel_size, in_channels // group, num_features]
    weights = tf.get_variable(name + "_W", weight_shape)
    bias = tf.get_variable(name + "_b", [num_features])
    return conv(
        x, weights, bias,
        kernel_size, kernel_size,
        num_features,
        stride, stride,
        padding="SAME",
        group=group)

def linear(x, name, size):
    """Flatten ``x`` per example and apply a fully connected layer of width ``size``.

    Variables are named "<name>_W" (random-normal init, mean 0.0, stddev 0.005)
    and "<name>_b" (zeros).
    """
    # All dimensions after the first are folded into one feature axis.
    feature_dims = [int(dim) for dim in x.get_shape()[1:]]
    input_size = np.prod(feature_dims)
    flat = tf.reshape(x, [-1, input_size])

    weights = tf.get_variable(
        name + "_W",
        [input_size, size],
        initializer=tf.random_normal_initializer(0.0, 0.005))
    bias = tf.get_variable(name + "_b", [size], initializer=tf.zeros_initializer)
    return tf.matmul(flat, weights) + bias

def max_pool_2x2(x):
    """Max-pool ``x`` with VALID padding.

    Despite the name, the pooling window is 3x3; the "2" matches the stride,
    which is 2 along both spatial axes.
    """
    window = [1, 3, 3, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='VALID')

def lrn(x):
    """Apply local response normalisation to ``x`` with this model's fixed settings."""
    normalised = tf.nn.local_response_normalization(
        x,
        depth_radius=2,
        alpha=2e-05,
        beta=0.75,
        bias=1.0)
    return normalised
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# if not os.getenv("PYPROFILE_FREQUENCY"):
# pyprofile.profile.print_frequency = 5
from drivers.deepdrive_tf.deep_driver_tf import DeepDriverTF

logger = logging.getLogger()
extra_logger = logging.getLogger('universe')
Expand Down Expand Up @@ -82,9 +83,8 @@ def main():

if args.driver == 'DeepDriver':
driver = DeepDriver()
elif args.driver == 'pt':
from drivers.pt.pt_driver import PTDriverBase
driver = PTDriverBase()
elif args.driver == 'DeepDriverTF':
driver = DeepDriverTF()
else:
raise Exception('That driver is not available')

Expand Down Expand Up @@ -112,7 +112,7 @@ def main():
# duration of the reset.
env.render()

action_n = driver.process_step(observation_n, reward_n, done_n, info)
action_n = driver.step(observation_n, reward_n, done_n, info)

if args.custom_camera:
# Sending this every step is probably overkill
Expand Down

0 comments on commit 3f91375

Please sign in to comment.