Use FLAGS in main functions only + Updates to shuffling (tensorflow#2601)
nealwu authored Oct 27, 2017
1 parent edcd29f commit 4702de2
Showing 7 changed files with 111 additions and 87 deletions.
18 changes: 9 additions & 9 deletions official/mnist/convert_to_records.py
@@ -50,11 +50,11 @@ def _bytes_feature(value):
   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
 
-def convert_to(data_set, name):
+def convert_to(dataset, name, directory):
   """Converts a dataset to TFRecords."""
-  images = data_set.images
-  labels = data_set.labels
-  num_examples = data_set.num_examples
+  images = dataset.images
+  labels = dataset.labels
+  num_examples = dataset.num_examples
 
   if images.shape[0] != num_examples:
     raise ValueError('Images size %d does not match label size %d.' %
@@ -63,7 +63,7 @@ def convert_to(data_set, name):
   cols = images.shape[2]
   depth = images.shape[3]
 
-  filename = os.path.join(FLAGS.directory, name + '.tfrecords')
+  filename = os.path.join(directory, name + '.tfrecords')
   print('Writing', filename)
   writer = tf.python_io.TFRecordWriter(filename)
   for index in range(num_examples):
@@ -80,15 +80,15 @@ def convert_to(data_set, name):
 
 def main(unused_argv):
   # Get the data.
-  data_sets = mnist.read_data_sets(FLAGS.directory,
+  datasets = mnist.read_data_sets(FLAGS.directory,
                                    dtype=tf.uint8,
                                    reshape=False,
                                    validation_size=FLAGS.validation_size)
 
   # Convert to Examples and write the result to TFRecords.
-  convert_to(data_sets.train, 'train')
-  convert_to(data_sets.validation, 'validation')
-  convert_to(data_sets.test, 'test')
+  convert_to(datasets.train, 'train', FLAGS.directory)
+  convert_to(datasets.validation, 'validation', FLAGS.directory)
+  convert_to(datasets.test, 'test', FLAGS.directory)
 
 
 if __name__ == '__main__':
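The change above is the commit's pattern in miniature: `convert_to` no longer reads `FLAGS.directory` itself; `main` resolves the flag once and forwards a plain argument. A minimal sketch of that convention, with a toy body and hypothetical defaults rather than the module's real parser:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--directory', type=str, default='/tmp/mnist_data')


def convert_to(dataset, name, directory):
  # No FLAGS access here: the helper depends only on its arguments, so tests
  # and other callers can invoke it without parsing any flags first.
  print('would write the %s split to %s' % (name, directory))


def main(flags):
  # Flags are read in exactly one place, then passed along as plain values.
  for split in ('train', 'validation', 'test'):
    convert_to(None, split, flags.directory)


if __name__ == '__main__':
  main(parser.parse_args())
```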
53 changes: 26 additions & 27 deletions official/mnist/mnist.py
@@ -52,7 +52,7 @@
 }
 
 
-def input_fn(mode, batch_size=1):
+def input_fn(is_training, filename, batch_size=1, num_epochs=1):
   """A simple input_fn using the contrib.data input pipeline."""
 
   def example_parser(serialized_example):
@@ -71,21 +71,15 @@ def example_parser(serialized_example):
     label = tf.cast(features['label'], tf.int32)
     return image, tf.one_hot(label, 10)
 
-  if mode == tf.estimator.ModeKeys.TRAIN:
-    tfrecords_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
-  else:
-    assert mode == tf.estimator.ModeKeys.EVAL, 'invalid mode'
-    tfrecords_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
-
-  assert tf.gfile.Exists(tfrecords_file), (
-      'Run convert_to_records.py first to convert the MNIST data to TFRecord '
-      'file format.')
-
-  dataset = tf.contrib.data.TFRecordDataset([tfrecords_file])
+  dataset = tf.contrib.data.TFRecordDataset([filename])
 
-  # For training, repeat the dataset forever
-  if mode == tf.estimator.ModeKeys.TRAIN:
-    dataset = dataset.repeat()
+  if is_training:
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes have better performance. Because MNIST is
+    # a small dataset, we can easily shuffle the full epoch.
+    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
+
+  dataset = dataset.repeat(num_epochs)
 
   # Map example_parser over dataset, and batch results by up to batch_size
   dataset = dataset.map(
@@ -96,13 +90,12 @@ def example_parser(serialized_example):
   return images, labels
 
 
-def mnist_model(inputs, mode):
+def mnist_model(inputs, mode, data_format):
   """Takes the MNIST inputs and mode and outputs a tensor of logits."""
   # Input Layer
   # Reshape X to 4-D tensor: [batch_size, width, height, channels]
   # MNIST images are 28x28 pixels, and have one color channel
   inputs = tf.reshape(inputs, [-1, 28, 28, 1])
-  data_format = FLAGS.data_format
 
   if data_format is None:
     # When running on GPU, transpose the data from channels_last (NHWC) to
@@ -177,9 +170,9 @@ def mnist_model(inputs, mode):
   return logits
 
 
-def mnist_model_fn(features, labels, mode):
+def mnist_model_fn(features, labels, mode, params):
   """Model function for MNIST."""
-  logits = mnist_model(features, mode)
+  logits = mnist_model(features, mode, params['data_format'])
 
   predictions = {
       'classes': tf.argmax(input=logits, axis=1),
@@ -215,30 +208,36 @@ def mnist_model_fn(features, labels, mode):
 
 
 def main(unused_argv):
+  # Make sure that training and testing data have been converted.
+  train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
+  test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
+  assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
+      'Run convert_to_records.py first to convert the MNIST data to TFRecord '
+      'file format.')
+
   # Create the Estimator
   mnist_classifier = tf.estimator.Estimator(
-      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir)
+      model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
+      params={'data_format': FLAGS.data_format})
 
-  # Train the model
+  # Set up training hook that logs the training accuracy every 100 steps.
   tensors_to_log = {
       'train_accuracy': 'train_accuracy'
   }
 
   logging_hook = tf.train.LoggingTensorHook(
       tensors=tensors_to_log, every_n_iter=100)
 
-  batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
-
+  # Train the model
   mnist_classifier.train(
-      input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size),
-      steps=FLAGS.train_epochs * batches_per_epoch,
+      input_fn=lambda: input_fn(
+          True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
       hooks=[logging_hook])
 
   # Evaluate the model and print results
   eval_results = mnist_classifier.evaluate(
-      input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))
+      input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
   print()
-  print('Evaluation results:\n %s' % eval_results)
+  print('Evaluation results:\n\t%s' % eval_results)
 
 
 if __name__ == '__main__':
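Two details of the new `input_fn` are worth calling out: `shuffle` runs before `repeat`, so epoch boundaries stay intact, and the buffer spans the whole training set, so each epoch is a full uniform permutation. A hedged sketch of the same pipeline shape, using the contrib-era APIs from this diff and an illustrative filename:

```python
import tensorflow as tf

_NUM_TRAIN_IMAGES = 60000  # MNIST training-set size, as in mnist.py


def sketch_input_fn(filename, batch_size, num_epochs, is_training):
  dataset = tf.contrib.data.TFRecordDataset([filename])
  if is_training:
    # A buffer as large as the dataset gives a full shuffle each epoch;
    # a smaller buffer would save memory but only shuffle within a window.
    dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_IMAGES)
  # Shuffling before repeat keeps epochs separate; repeating first would let
  # examples from adjacent epochs interleave inside the shuffle buffer.
  dataset = dataset.repeat(num_epochs)
  return dataset.batch(batch_size)
```

`cifar10_main.py` below makes the opposite trade: its fixed `_SHUFFLE_BUFFER` of 20000 covers under half of CIFAR-10's 50000 training images, accepting a windowed shuffle in exchange for a smaller memory footprint.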
4 changes: 2 additions & 2 deletions official/mnist/mnist_test.py
@@ -34,7 +34,8 @@ def input_fn(self):
   def mnist_model_fn_helper(self, mode):
     features, labels = self.input_fn()
     image_count = features.shape[0]
-    spec = mnist.mnist_model_fn(features, labels, mode)
+    spec = mnist.mnist_model_fn(
+        features, labels, mode, {'data_format': 'channels_last'})
 
     predictions = spec.predictions
     self.assertAllEqual(predictions['probabilities'].shape, (image_count, 10))
@@ -65,5 +66,4 @@ def test_mnist_model_fn_predict_mode(self):
 
 
 if __name__ == '__main__':
-  mnist.FLAGS = mnist.parser.parse_args()
   tf.test.main()
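The deleted `mnist.FLAGS = mnist.parser.parse_args()` line is the payoff of the refactor: `mnist_model_fn` now receives its configuration through `params`, so a test can call it with no global flag state. A sketch of that direct call, with toy tensors and an assumed import path:

```python
import tensorflow as tf
import mnist  # import path assumed, as referenced in mnist_test.py

features = tf.random_uniform([8, 28 * 28])              # 8 fake flat images
labels = tf.one_hot(tf.zeros([8], dtype=tf.int32), 10)  # 8 fake labels
spec = mnist.mnist_model_fn(
    features, labels, tf.estimator.ModeKeys.PREDICT,
    {'data_format': 'channels_last'})                   # explicit, no FLAGS
```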
44 changes: 25 additions & 19 deletions official/resnet/cifar10_main.py
@@ -71,16 +71,18 @@
   'validation': 10000,
 }
 
+_SHUFFLE_BUFFER = 20000
+
 
 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
   return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
 
 
-def get_filenames(is_training):
+def get_filenames(is_training, data_dir):
   """Returns a list of filenames."""
-  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+  data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
 
   assert os.path.exists(data_dir), (
       'Run cifar10_download_and_extract.py first to download and extract the '
@@ -135,7 +137,7 @@ def train_preprocess_fn(image, label):
   return image, label
 
 
-def input_fn(is_training, num_epochs=1):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
 
   Args:
@@ -145,42 +147,41 @@ def input_fn(is_training, num_epochs=1):
   Returns:
     A tuple of images and labels.
   """
-  dataset = record_dataset(get_filenames(is_training))
+  dataset = record_dataset(get_filenames(is_training, data_dir))
   dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * FLAGS.batch_size)
+                        output_buffer_size=2 * batch_size)
 
   # For training, preprocess the image and shuffle.
   if is_training:
     dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * FLAGS.batch_size)
+                          output_buffer_size=2 * batch_size)
 
-    # Ensure that the capacity is sufficiently large to provide good random
-    # shuffling.
-    buffer_size = int(0.4 * _NUM_IMAGES['train'])
-    dataset = dataset.shuffle(buffer_size=buffer_size)
+    # When choosing shuffle buffer sizes, larger sizes result in better
+    # randomness, while smaller sizes have better performance.
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
   # Subtract off the mean and divide by the variance of the pixels.
   dataset = dataset.map(
       lambda image, label: (tf.image.per_image_standardization(image), label),
      num_threads=1,
-      output_buffer_size=2 * FLAGS.batch_size)
+      output_buffer_size=2 * batch_size)
 
   dataset = dataset.repeat(num_epochs)
 
   # Batch results by up to batch_size, and then fetch the tuple from the
   # iterator.
-  iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
+  iterator = dataset.batch(batch_size).make_one_shot_iterator()
   images, labels = iterator.get_next()
 
   return images, labels
 
 
-def cifar10_model_fn(features, labels, mode):
+def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
   tf.summary.image('images', features, max_outputs=6)
 
   network = resnet_model.cifar10_resnet_v2_generator(
-      FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format)
+      params['resnet_size'], _NUM_CLASSES, params['data_format'])
 
   inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
   logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
@@ -208,8 +209,8 @@ def cifar10_model_fn(features, labels, mode):
   if mode == tf.estimator.ModeKeys.TRAIN:
     # Scale the learning rate linearly with the batch size. When the batch size
     # is 128, the learning rate should be 0.1.
-    initial_learning_rate = 0.1 * FLAGS.batch_size / 128
-    batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
+    initial_learning_rate = 0.1 * params['batch_size'] / 128
+    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
     global_step = tf.train.get_or_create_global_step()
 
     # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
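The linear-scaling comment above makes the arithmetic easy to check. A small standalone sketch of the rule as it now flows through `params` (the 0.1-at-128 reference point comes from the diff; the other batch sizes are worked examples only):

```python
def initial_learning_rate(batch_size, base_lr=0.1, base_batch_size=128):
  # Linear scaling: learning rate 0.1 at the reference batch size of 128.
  return base_lr * batch_size / base_batch_size


assert initial_learning_rate(128) == 0.1
assert initial_learning_rate(256) == 0.2   # double the batch, double the rate
assert initial_learning_rate(64) == 0.05   # half the batch, half the rate
```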
@@ -256,7 +257,12 @@ def main(unused_argv):
   # Set up a RunConfig to only save checkpoints once per training cycle.
   run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
   cifar_classifier = tf.estimator.Estimator(
-      model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config)
+      model_fn=cifar10_model_fn, model_dir=FLAGS.model_dir, config=run_config,
+      params={
+          'resnet_size': FLAGS.resnet_size,
+          'data_format': FLAGS.data_format,
+          'batch_size': FLAGS.batch_size,
+      })
 
   for _ in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
     tensors_to_log = {
@@ -270,12 +276,12 @@ def main(unused_argv):
 
     cifar_classifier.train(
         input_fn=lambda: input_fn(
-            is_training=True, num_epochs=FLAGS.epochs_per_eval),
+            True, FLAGS.data_dir, FLAGS.batch_size, FLAGS.epochs_per_eval),
         hooks=[logging_hook])
 
     # Evaluate the model and print results
     eval_results = cifar_classifier.evaluate(
-        input_fn=lambda: input_fn(is_training=False))
+        input_fn=lambda: input_fn(False, FLAGS.data_dir, FLAGS.batch_size))
     print(eval_results)
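The `params` dict is what makes the FLAGS-only-in-main rule workable with `tf.estimator`: the constructor stores the dict and passes it to the `model_fn` on every call. A self-contained sketch of that contract with a toy model and hypothetical parameter names (TF 1.x estimator API):

```python
import tensorflow as tf


def toy_model_fn(features, labels, mode, params):
  # `params` is the same dict handed to the Estimator constructor, so the
  # model_fn never needs to reach out to global FLAGS.
  logits = tf.layers.dense(features, params['num_classes'])
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.GradientDescentOptimizer(
      params['learning_rate']).minimize(
          loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


estimator = tf.estimator.Estimator(
    model_fn=toy_model_fn,
    params={'num_classes': 10, 'learning_rate': 0.1})
```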
19 changes: 12 additions & 7 deletions official/resnet/cifar10_test.py
@@ -26,6 +26,8 @@
 
 tf.logging.set_verbosity(tf.logging.ERROR)
 
+_BATCH_SIZE = 128
+
 
 class BaseTest(tf.test.TestCase):
 
@@ -58,20 +60,25 @@ def test_dataset_input_fn(self):
     self.assertAllEqual(pixel, np.array([0, 1, 2]))
 
   def input_fn(self):
-    features = tf.random_uniform([FLAGS.batch_size, 32, 32, 3])
+    features = tf.random_uniform([_BATCH_SIZE, 32, 32, 3])
     labels = tf.random_uniform(
-        [FLAGS.batch_size], maxval=9, dtype=tf.int32)
+        [_BATCH_SIZE], maxval=9, dtype=tf.int32)
     return features, tf.one_hot(labels, 10)
 
   def cifar10_model_fn_helper(self, mode):
     features, labels = self.input_fn()
-    spec = cifar10_main.cifar10_model_fn(features, labels, mode)
+    spec = cifar10_main.cifar10_model_fn(
+        features, labels, mode, {
+            'resnet_size': 32,
+            'data_format': 'channels_last',
+            'batch_size': _BATCH_SIZE,
+        })
 
     predictions = spec.predictions
     self.assertAllEqual(predictions['probabilities'].shape,
-                        (FLAGS.batch_size, 10))
+                        (_BATCH_SIZE, 10))
     self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-    self.assertAllEqual(predictions['classes'].shape, (FLAGS.batch_size,))
+    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
     self.assertEqual(predictions['classes'].dtype, tf.int64)
 
     if mode != tf.estimator.ModeKeys.PREDICT:
@@ -97,6 +104,4 @@ def test_cifar10_model_fn_predict_mode(self):
 
 
 if __name__ == '__main__':
-  cifar10_main.FLAGS = cifar10_main.parser.parse_args()
-  FLAGS = cifar10_main.FLAGS
   tf.test.main()
[Diffs for the remaining 2 changed files did not load.]
