Changing tf.contrib.data to tf.data for release of tf 1.4
k-w-w committed Nov 6, 2017
1 parent 4cfa0d3 commit 1f6b3d7
Showing 4 changed files with 16 additions and 18 deletions.
official/mnist/mnist.py (8 changes: 6 additions & 2 deletions)
@@ -53,7 +53,7 @@
 
 
 def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the contrib.data input pipeline."""
+  """A simple input_fn using the tf.data input pipeline."""
 
   def example_parser(serialized_example):
     """Parses a single tf.Example into image and label tensors."""
@@ -71,8 +71,12 @@ def example_parser(serialized_example):
     label = tf.cast(features['label'], tf.int32)
     return image, tf.one_hot(label, 10)
 
-  dataset = tf.contrib.data.TFRecordDataset([filename])
+  dataset = tf.data.TFRecordDataset([filename])
+
+  # Parse each example in the dataset
+  dataset = dataset.map(example_parser)
+
   # Apply dataset transformations
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance. Because MNIST is
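For context, here is a minimal, self-contained sketch of what the migrated MNIST input pipeline looks like on the TF 1.4 tf.data API. It is not a verbatim copy of mnist.py: the feature keys, normalization, and shuffle buffer size below are assumptions used only for illustration.

import tensorflow as tf

def example_parser(serialized_example):
  """Parses one serialized tf.Example into a flat image tensor and one-hot label."""
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),  # assumed key
          'label': tf.FixedLenFeature([], tf.int64),       # assumed key
      })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.reshape(image, [28 * 28])                 # MNIST images are 28x28
  image = tf.cast(image, tf.float32) / 255.0           # assumed normalization
  label = tf.cast(features['label'], tf.int32)
  return image, tf.one_hot(label, 10)

def input_fn(is_training, filename, batch_size=1, num_epochs=1):
  """A simple input_fn built on tf.data (TF 1.4+)."""
  dataset = tf.data.TFRecordDataset([filename])
  dataset = dataset.map(example_parser)
  if is_training:
    dataset = dataset.shuffle(buffer_size=50000)       # assumed buffer size
  dataset = dataset.repeat(num_epochs).batch(batch_size)
  images, labels = dataset.make_one_shot_iterator().get_next()
  return images, labels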
official/resnet/cifar10_main.py (14 changes: 5 additions & 9 deletions)
@@ -77,7 +77,7 @@
 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
-  return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
+  return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
 
 
 def get_filenames(is_training, data_dir):
@@ -138,7 +138,7 @@ def train_preprocess_fn(image, label):
 
 
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
-  """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
+  """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
 
   Args:
     is_training: A boolean denoting whether the input is for training.
@@ -148,23 +148,19 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     A tuple of images and labels.
   """
   dataset = record_dataset(get_filenames(is_training, data_dir))
-  dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * batch_size)
+  dataset = dataset.map(dataset_parser)
 
   # For training, preprocess the image and shuffle.
   if is_training:
-    dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(train_preprocess_fn)
 
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
   # Subtract off the mean and divide by the variance of the pixels.
   dataset = dataset.map(
-      lambda image, label: (tf.image.per_image_standardization(image), label),
-      num_threads=1,
-      output_buffer_size=2 * batch_size)
+      lambda image, label: (tf.image.per_image_standardization(image), label))
 
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
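The old contrib map() calls passed num_threads and output_buffer_size, which tf.data's Dataset.map does not accept; this commit simply drops them. If the parallelism and read-ahead were still wanted, the rough tf.data equivalents in TF 1.4 would be num_parallel_calls plus an explicit prefetch. The helper below is a hypothetical sketch of that substitution, not part of this commit; its name and default value are made up.

import tensorflow as tf

def parallel_map_with_prefetch(dataset, parser_fn, batch_size, num_parallel_calls=4):
  """Sketch: approximate the old num_threads/output_buffer_size behavior in tf.data."""
  # num_parallel_calls parallelizes the map, roughly like num_threads did.
  dataset = dataset.map(parser_fn, num_parallel_calls=num_parallel_calls)
  # prefetch keeps a buffer of ready elements, roughly like output_buffer_size did.
  return dataset.prefetch(2 * batch_size)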
official/resnet/imagenet_main.py (7 changes: 3 additions & 4 deletions)
@@ -134,17 +134,16 @@ def dataset_parser(value, is_training):
 
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input function which provides batches for train or eval."""
-  dataset = tf.contrib.data.Dataset.from_tensor_slices(
+  dataset = tf.data.Dataset.from_tensor_slices(
       filenames(is_training, data_dir))
 
   if is_training:
     dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
 
-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
+  dataset = dataset.flat_map(tf.data.TFRecordDataset)
 
   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_threads=5,
-                        output_buffer_size=batch_size)
+                        num_parallel_calls=5)
 
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
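The ImageNet pipeline shows the filename-sharded pattern that carries over to tf.data with only the namespace change: build a dataset of file paths, shuffle at the file level, flat_map each path through TFRecordDataset, and parse records in parallel. A condensed sketch of that pattern on the TF 1.4 API follows; the function name, default buffer size, and the parser are placeholders rather than names from the commit.

import tensorflow as tf

def sharded_tfrecord_dataset(filenames, parser_fn, is_training,
                             file_shuffle_buffer=1024, num_parallel_calls=5):
  """Sketch of the filename -> TFRecordDataset -> parse pattern (TF 1.4 tf.data).

  `filenames` is a Python list of TFRecord paths; `parser_fn` maps one
  serialized record to a (features, label) pair. Both are assumptions.
  """
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if is_training:
    # Shuffle at the file level first so successive epochs read the shards
    # in different orders.
    dataset = dataset.shuffle(buffer_size=file_shuffle_buffer)
  # Read each file's records in sequence, concatenating the per-file datasets
  # into a single stream of serialized examples.
  dataset = dataset.flat_map(tf.data.TFRecordDataset)
  # Parse records in parallel; num_parallel_calls replaces the old num_threads.
  return dataset.map(parser_fn, num_parallel_calls=num_parallel_calls)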
official/wide_deep/wide_deep.py (5 changes: 2 additions & 3 deletions)
@@ -178,8 +178,8 @@ def parse_csv(value):
     return features, tf.equal(labels, '>50K')
 
   # Extract lines from input files using the Dataset API.
-  dataset = tf.contrib.data.TextLineDataset(data_file)
-  dataset = dataset.map(parse_csv, num_threads=5)
+  dataset = tf.data.TextLineDataset(data_file)
+  dataset = dataset.map(parse_csv, num_parallel_calls=5)
 
   if shuffle:
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
@@ -193,7 +193,6 @@ def parse_csv(value):
   features, labels = iterator.get_next()
   return features, labels
 
-
 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
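For the census CSV input, the second hunk shows the end of input_fn pulling tensors out of the dataset with a one-shot iterator. Below is a compact sketch of how that TextLineDataset pipeline is typically driven end to end on TF 1.4; the column names, record defaults, buffer size, and batch size are placeholders and do not come from the commit.

import tensorflow as tf

_CSV_COLUMNS = ['age', 'education', 'label']   # placeholder column names
_CSV_DEFAULTS = [[0], [''], ['<=50K']]         # placeholder record defaults

def parse_csv(value):
  """Decodes one CSV line into a feature dict and a boolean label."""
  columns = tf.decode_csv(value, record_defaults=_CSV_DEFAULTS)
  features = dict(zip(_CSV_COLUMNS, columns))
  labels = features.pop('label')
  return features, tf.equal(labels, '>50K')

def input_fn(data_file, batch_size=32, num_epochs=1, shuffle=True):
  """Sketch of a wide_deep-style text pipeline on tf.data (TF 1.4+)."""
  dataset = tf.data.TextLineDataset(data_file)
  dataset = dataset.map(parse_csv, num_parallel_calls=5)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=10000)  # assumed buffer size
  dataset = dataset.repeat(num_epochs).batch(batch_size)
  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels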
