
Commit 5d7612c

Improve image processing (tensorflow#45)
* improve image processing performance for Inception.
jmchen-g authored and mrry committed Apr 13, 2016
1 parent 84b58a6 commit 5d7612c
Showing 1 changed file with 57 additions and 26 deletions.
83 changes: 57 additions & 26 deletions inception/inception/image_processing.py
@@ -40,7 +40,6 @@
from __future__ import division
from __future__ import print_function


import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
@@ -52,6 +51,8 @@
tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
"""Number of preprocessing threads per tower. """
"""Please make this a multiple of 4.""")
tf.app.flags.DEFINE_integer('num_readers', 4,
"""Number of parallel readers during train.""")

# Images are preprocessed asynchronously using multiple threads specified by
# --num_preprocess_threads and the resulting processed images are stored in a
@@ -97,7 +98,8 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
with tf.device('/cpu:0'):
images, labels = batch_inputs(
dataset, batch_size, train=False,
num_preprocess_threads=num_preprocess_threads)
num_preprocess_threads=num_preprocess_threads,
num_readers=1)

return images, labels

@@ -130,7 +132,8 @@ def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
with tf.device('/cpu:0'):
images, labels = batch_inputs(
dataset, batch_size, train=True,
num_preprocess_threads=num_preprocess_threads)
num_preprocess_threads=num_preprocess_threads,
num_readers=FLAGS.num_readers)
return images, labels
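As a usage note, here is a rough sketch of how a training script might consume distorted_inputs(); the ImagenetData import path and batch size are assumptions based on the surrounding repo layout, and the readers and preprocessing threads wired up in batch_inputs() only start producing data once start_queue_runners() is called:

import tensorflow as tf

from inception import image_processing
from inception.imagenet_data import ImagenetData  # assumed dataset class from this repo

dataset = ImagenetData(subset='train')
# Input pipeline; internally pinned to /cpu:0 and driven by queue runners.
# Evaluation code would call image_processing.inputs(dataset) instead.
images, labels = image_processing.distorted_inputs(dataset, batch_size=32)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    image_batch, label_batch = sess.run([images, labels])
    # With the default --image_size of 299, shapes come out as roughly
    # (32, 299, 299, 3) and (32,).
    print(image_batch.shape, label_batch.shape)
  finally:
    coord.request_stop()
    coord.join(threads)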


@@ -401,7 +404,8 @@ def parse_example_proto(example_serialized):
return features['image/encoded'], label, bbox, features['image/class/text']


def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
num_readers=1):
"""Contruct batches of training or evaluation examples from the image dataset.
Args:
@@ -410,6 +414,7 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
batch_size: integer
train: boolean
num_preprocess_threads: integer, total number of preprocessing threads
num_readers: integer, number of parallel readers
Returns:
images: 4-D float Tensor of a batch of images
@@ -422,26 +427,28 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
data_files = dataset.data_files()
if data_files is None:
raise ValueError('No data files found for this dataset')
filename_queue = tf.train.string_input_producer(data_files, capacity=16)

# Create filename_queue
if train:
filename_queue = tf.train.string_input_producer(data_files,
shuffle=True,
capacity=16)
else:
filename_queue = tf.train.string_input_producer(data_files,
shuffle=False,
capacity=1)
if num_preprocess_threads is None:
num_preprocess_threads = FLAGS.num_preprocess_threads

if num_preprocess_threads % 4:
raise ValueError('Please make num_preprocess_threads a multiple '
'of 4 (%d %% 4 != 0).' % num_preprocess_threads)
# Create a subgraph with its own reader (but sharing the
# filename_queue) for each preprocessing thread.
images_and_labels = []
for thread_id in range(num_preprocess_threads):
reader = dataset.reader()
_, example_serialized = reader.read(filename_queue)

# Parse a serialized Example proto to extract the image and metadata.
image_buffer, label_index, bbox, _ = parse_example_proto(
example_serialized)
image = image_preprocessing(image_buffer, bbox, train, thread_id)
images_and_labels.append([image, label_index])
if num_readers is None:
num_readers = FLAGS.num_readers

if num_readers < 1:
raise ValueError('Please make num_readers at least 1')

# Approximate number of examples per shard.
examples_per_shard = 1024
@@ -451,19 +458,43 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
# The default input_queue_memory_factor is 16 implying a shuffling queue
# size: examples_per_shard * 16 * ~1MB ~= 17.6GB
min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor

# Create a queue that produces the examples in batches after shuffling.
if train:
images, label_index_batch = tf.train.shuffle_batch_join(
images_and_labels,
batch_size=batch_size,
examples_queue = tf.RandomShuffleQueue(
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
min_after_dequeue=min_queue_examples,
dtypes=[tf.string])
else:
examples_queue = tf.FIFOQueue(
capacity=examples_per_shard + 3 * batch_size,
dtypes=[tf.string])

# Create multiple readers to populate the queue of examples.
if num_readers > 1:
enqueue_ops = []
for _ in range(num_readers):
reader = dataset.reader()
_, value = reader.read(filename_queue)
enqueue_ops.append(examples_queue.enqueue([value]))

tf.train.queue_runner.add_queue_runner(
tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
example_serialized = examples_queue.dequeue()
else:
images, label_index_batch = tf.train.batch_join(
images_and_labels,
batch_size=batch_size,
capacity=min_queue_examples + 3 * batch_size)
reader = dataset.reader()
_, example_serialized = reader.read(filename_queue)

images_and_labels = []
for thread_id in range(num_preprocess_threads):
# Parse a serialized Example proto to extract the image and metadata.
image_buffer, label_index, bbox, _ = parse_example_proto(
example_serialized)
image = image_preprocessing(image_buffer, bbox, train, thread_id)
images_and_labels.append([image, label_index])

images, label_index_batch = tf.train.batch_join(
images_and_labels,
batch_size=batch_size,
capacity=2 * num_preprocess_threads * batch_size)

# Reshape images into these desired dimensions.
height = FLAGS.image_size
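To summarize the new input pipeline in one self-contained place, below is a hedged sketch of the reader-pool pattern this commit introduces: several readers share one filename queue and feed a single queue of serialized Example protos, which the preprocessing threads then drain via batch_join. The shard paths are placeholders, and tf.TFRecordReader plus dummy parse/preprocess tensors stand in for dataset.reader(), parse_example_proto() and image_preprocessing() so the snippet stays short:

import tensorflow as tf

# Placeholder shard paths; dataset.data_files() would return real TFRecord shards.
data_files = ['/tmp/train-00000-of-00002', '/tmp/train-00001-of-00002']
batch_size = 32
num_readers = 4
num_preprocess_threads = 4

# Shuffled filename queue, as in the training branch of batch_inputs().
filename_queue = tf.train.string_input_producer(
    data_files, shuffle=True, capacity=16)

# Shared queue of serialized Example protos. With the defaults below,
# 1024 * 16 examples at roughly 1MB each is on the order of the 17.6GB
# figure quoted in this file's comments.
examples_per_shard = 1024
input_queue_memory_factor = 16
min_queue_examples = examples_per_shard * input_queue_memory_factor
examples_queue = tf.RandomShuffleQueue(
    capacity=min_queue_examples + 3 * batch_size,
    min_after_dequeue=min_queue_examples,
    dtypes=[tf.string])

# A pool of readers fills the shared queue; a single QueueRunner owns all
# the enqueue ops so they run in background threads.
enqueue_ops = []
for _ in range(num_readers):
  reader = tf.TFRecordReader()  # stand-in for dataset.reader()
  _, value = reader.read(filename_queue)
  enqueue_ops.append(examples_queue.enqueue([value]))
tf.train.queue_runner.add_queue_runner(
    tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
example_serialized = examples_queue.dequeue()

# Each preprocessing thread gets its own copy of the parse/preprocess
# subgraph; in the real code a thread_id is passed to image_preprocessing()
# to vary the distortions, and the dummy tensors below stand in for the
# parsed image and label.
images_and_labels = []
for _ in range(num_preprocess_threads):
  image = tf.zeros([299, 299, 3], dtype=tf.float32)
  label = tf.size(example_serialized)  # consumes the dequeued string
  images_and_labels.append([image, label])

# batch_join adds one more queue runner that assembles the final batches.
images, labels = tf.train.batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=2 * num_preprocess_threads * batch_size)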
