Skip to content

Commit

Permalink
Separate parse_and_preprocess into two different dataset.map calls, which also keeps tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
nealwu committed Nov 7, 2017
1 parent 807d6bd commit 6e52c27
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions official/resnet/cifar10_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,45 +108,38 @@ def parse_record(raw_record):
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8)

# The first byte represents the label, which we convert from uint8 to int32.
# The first byte represents the label, which we convert from uint8 to int32
# and then to one-hot.
label = tf.cast(record_vector[0], tf.int32)
label = tf.one_hot(label, _NUM_CLASSES)

# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
[_DEPTH, _HEIGHT, _WIDTH])
depth_major = tf.reshape(
record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])

# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

return image, tf.one_hot(label, _NUM_CLASSES)


def train_preprocess_fn(image):
"""Preprocess a single training image of layout [height, width, depth]."""
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)

# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])

# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)
return image, label

return image

def preprocess_image(image, is_training):
"""Preprocess a single image of layout [height, width, depth]."""
if is_training:
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)

def parse_and_preprocess(record, is_training):
"""Parse and preprocess records in the CIFAR-10 dataset."""
image, label = parse_record(record)
# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])

if is_training:
image = train_preprocess_fn(image)
# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)

# Subtract off the mean and divide by the standard deviation of the pixels.
image = tf.image.per_image_standardization(image)
return image, label
return image


def input_fn(is_training, data_dir, batch_size, num_epochs=1):
Expand All @@ -168,8 +161,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
# randomness, while smaller sizes have better performance.
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

dataset = dataset.map(parse_record)
dataset = dataset.map(
lambda record: parse_and_preprocess(record, is_training))
lambda image, label: (preprocess_image(image, is_training), label))

dataset = dataset.prefetch(2 * batch_size)

# We call repeat after shuffling, rather than before, to prevent separate
Expand Down

0 comments on commit 6e52c27

Please sign in to comment.