Skip to content

Commit

Permalink
Add code
Browse files Browse the repository at this point in the history
  • Loading branch information
simonschuang committed May 21, 2019
0 parents commit 8503a1d
Show file tree
Hide file tree
Showing 9 changed files with 1,743 additions and 0 deletions.
523 changes: 523 additions & 0 deletions README.md

Large diffs are not rendered by default.

Empty file added __init__.py
Empty file.
113 changes: 113 additions & 0 deletions cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CIFAR-10 data set.
See http://www.cs.toronto.edu/~kriz/cifar.html.
"""
import os

import tensorflow as tf

HEIGHT = 32
WIDTH = 32
DEPTH = 3


class Cifar10DataSet(object):
"""Cifar10 data set.
Described by http://www.cs.toronto.edu/~kriz/cifar.html.
"""

def __init__(self, data_dir, subset='train', use_distortion=True):
self.data_dir = data_dir
self.subset = subset
self.use_distortion = use_distortion

def get_filenames(self):
if self.subset in ['train', 'validation', 'eval']:
return [os.path.join(self.data_dir, self.subset + '.tfrecords')]
else:
raise ValueError('Invalid data subset "%s"' % self.subset)

def parser(self, serialized_example):
"""Parses a single tf.Example into image and label tensors."""
# Dimensions of the images in the CIFAR-10 dataset.
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
# input format.
features = tf.parse_single_example(
serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features['image'], tf.uint8)
image.set_shape([DEPTH * HEIGHT * WIDTH])

# Reshape from [depth * height * width] to [depth, height, width].
image = tf.cast(
tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
tf.float32)
label = tf.cast(features['label'], tf.int32)

# Custom preprocessing.
image = self.preprocess(image)

return image, label

def make_batch(self, batch_size):
"""Read the images and labels from 'filenames'."""
filenames = self.get_filenames()
# Repeat infinitely.
dataset = tf.data.TFRecordDataset(filenames).repeat()

# Parse records.
dataset = dataset.map(
self.parser, num_parallel_calls=batch_size)

# Potentially shuffle records.
if self.subset == 'train':
min_queue_examples = int(
Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
# Ensure that the capacity is sufficiently large to provide good random
# shuffling.
dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)

# Batch it up.
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()

return image_batch, label_batch

def preprocess(self, image):
"""Preprocess a single image in [height, width, depth] layout."""
if self.subset == 'train' and self.use_distortion:
# Pad 4 pixels on each dimension of feature map, done in mini-batch
image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
image = tf.image.random_flip_left_right(image)
return image

@staticmethod
def num_examples_per_epoch(subset='train'):
if subset == 'train':
return 45000
elif subset == 'validation':
return 5000
elif subset == 'eval':
return 10000
else:
raise ValueError('Invalid data subset "%s"' % subset)
Loading

0 comments on commit 8503a1d

Please sign in to comment.