Merge pull request tensorflow#2690 from tensorflow/tf-data
Changing tf.contrib.data to tf.data for release of tf 1.4
k-w-w authored Nov 7, 2017 · 2 parents 4cfa0d3 + ae5adb5 · commit f88def2
Showing 5 changed files with 49 additions and 47 deletions.
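
Before the per-file diffs, a minimal sketch of the API change this commit applies. This is not code from the commit: it assumes the TF 1.3 tf.contrib.data and TF 1.4 tf.data interfaces, and parse_fn and the filename 'train.tfrecords' are hypothetical placeholders standing in for the repo's example_parser / parse_record and real data files. The map(..., num_threads=..., output_buffer_size=...) arguments are replaced by num_parallel_calls plus an explicit prefetch() transformation, and dataset classes move from tf.contrib.data to tf.data.

import tensorflow as tf

def parse_fn(serialized_example):
  """Hypothetical parser standing in for example_parser / parse_record."""
  features = tf.parse_single_example(
      serialized_example,
      features={'label': tf.FixedLenFeature([], tf.int64)})
  return tf.cast(features['label'], tf.int32)

# Before (tf.contrib.data, TF <= 1.3): threading and buffering were map() arguments.
#   dataset = tf.contrib.data.TFRecordDataset(['train.tfrecords'])
#   dataset = dataset.map(parse_fn, num_threads=1, output_buffer_size=32)

# After (tf.data, TF 1.4): map() takes num_parallel_calls, and buffering becomes
# an explicit prefetch() transformation on the dataset.
dataset = tf.data.TFRecordDataset(['train.tfrecords'])
dataset = dataset.map(parse_fn, num_parallel_calls=1).prefetch(32)
dataset = dataset.batch(32)
labels = dataset.make_one_shot_iterator().get_next()

The same substitution appears in each file below: num_threads and output_buffer_size disappear in favour of num_parallel_calls and prefetch, and tf.contrib.data.*Dataset becomes tf.data.*Dataset.
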
8 changes: 4 additions & 4 deletions official/mnist/mnist.py
@@ -53,7 +53,7 @@


def input_fn(is_training, filename, batch_size=1, num_epochs=1):
"""A simple input_fn using the contrib.data input pipeline."""
"""A simple input_fn using the tf.data input pipeline."""

def example_parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
@@ -71,8 +71,9 @@ def example_parser(serialized_example):
label = tf.cast(features['label'], tf.int32)
return image, tf.one_hot(label, 10)

dataset = tf.contrib.data.TFRecordDataset([filename])
dataset = tf.data.TFRecordDataset([filename])

# Apply dataset transformations
if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. Because MNIST is
@@ -84,8 +85,7 @@ def example_parser(serialized_example):
dataset = dataset.repeat(num_epochs)

# Map example_parser over dataset, and batch results by up to batch_size
dataset = dataset.map(
example_parser, num_threads=1, output_buffer_size=batch_size)
dataset = dataset.map(example_parser).prefetch(batch_size)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
69 changes: 35 additions & 34 deletions official/resnet/cifar10_main.py
@@ -71,13 +71,11 @@
'validation': 10000,
}

_SHUFFLE_BUFFER = 20000


def record_dataset(filenames):
"""Returns an input pipeline Dataset from `filenames`."""
record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)


def get_filenames(is_training, data_dir):
@@ -97,74 +95,77 @@ def get_filenames(is_training, data_dir):
return [os.path.join(data_dir, 'test_batch.bin')]


def dataset_parser(value):
"""Parse a CIFAR-10 record from value."""
def parse_record(raw_record):
"""Parse CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number
# of bytes for each.
label_bytes = 1
image_bytes = _HEIGHT * _WIDTH * _DEPTH
record_bytes = label_bytes + image_bytes

# Convert from a string to a vector of uint8 that is record_bytes long.
raw_record = tf.decode_raw(value, tf.uint8)
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8)

# The first byte represents the label, which we convert from uint8 to int32.
label = tf.cast(raw_record[0], tf.int32)
# The first byte represents the label, which we convert from uint8 to int32
# and then to one-hot.
label = tf.cast(record_vector[0], tf.int32)
label = tf.one_hot(label, _NUM_CLASSES)

# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
[_DEPTH, _HEIGHT, _WIDTH])
depth_major = tf.reshape(
record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])

# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

return image, tf.one_hot(label, _NUM_CLASSES)
return image, label


def train_preprocess_fn(image, label):
"""Preprocess a single training image of layout [height, width, depth]."""
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
def preprocess_image(image, is_training):
"""Preprocess a single image of layout [height, width, depth]."""
if is_training:
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(
image, _HEIGHT + 8, _WIDTH + 8)

# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])

# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)
# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)

return image, label
# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)
return image


def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args:
is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
Returns:
A tuple of images and labels.
"""
dataset = record_dataset(get_filenames(is_training, data_dir))
dataset = dataset.map(dataset_parser, num_threads=1,
output_buffer_size=2 * batch_size)

# For training, preprocess the image and shuffle.
if is_training:
dataset = dataset.map(train_preprocess_fn, num_threads=1,
output_buffer_size=2 * batch_size)

# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
# randomness, while smaller sizes have better performance. Because CIFAR-10
# is a relatively small dataset, we choose to shuffle the full epoch.
dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])

# Subtract off the mean and divide by the variance of the pixels.
dataset = dataset.map(parse_record)
dataset = dataset.map(
lambda image, label: (tf.image.per_image_standardization(image), label),
num_threads=1,
output_buffer_size=2 * batch_size)
lambda image, label: (preprocess_image(image, is_training), label))

dataset = dataset.prefetch(2 * batch_size)

# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
2 changes: 1 addition & 1 deletion official/resnet/cifar10_test.py
@@ -44,7 +44,7 @@ def test_dataset_input_fn(self):
data_file.close()

fake_dataset = cifar10_main.record_dataset(filename)
fake_dataset = fake_dataset.map(cifar10_main.dataset_parser)
fake_dataset = fake_dataset.map(cifar10_main.parse_record)
image, label = fake_dataset.make_one_shot_iterator().get_next()

self.assertEqual(label.get_shape().as_list(), [10])
12 changes: 6 additions & 6 deletions official/resnet/imagenet_main.py
@@ -134,23 +134,23 @@ def dataset_parser(value, is_training):

def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input function which provides batches for train or eval."""
dataset = tf.contrib.data.Dataset.from_tensor_slices(
dataset = tf.data.Dataset.from_tensor_slices(
filenames(is_training, data_dir))

if is_training:
dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)

dataset = dataset.map(lambda value: dataset_parser(value, is_training),
num_threads=5,
output_buffer_size=batch_size)
dataset = dataset.flat_map(tf.data.TFRecordDataset)

if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

dataset = dataset.map(lambda value: dataset_parser(value, is_training),
num_parallel_calls=5)
dataset = dataset.prefetch(batch_size)

# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
5 changes: 3 additions & 2 deletions official/wide_deep/wide_deep.py
@@ -178,12 +178,13 @@ def parse_csv(value):
return features, tf.equal(labels, '>50K')

# Extract lines from input files using the Dataset API.
dataset = tf.contrib.data.TextLineDataset(data_file)
dataset = dataset.map(parse_csv, num_threads=5)
dataset = tf.data.TextLineDataset(data_file)

if shuffle:
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

dataset = dataset.map(parse_csv, num_parallel_calls=5)

# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
