Project import generated by Copybara.
PiperOrigin-RevId: 177165761
joel-shor committed Nov 28, 2017
1 parent 220772b commit b6907e8
Showing 18 changed files with 983 additions and 74 deletions.
39 changes: 19 additions & 20 deletions research/gan/cifar/util_test.py
@@ -37,26 +37,25 @@ def test_get_image_grid(self):
        num_classes=3,
        num_images_per_class=1)

-  def test_get_inception_scores(self):
-    # Mock `inception_score` which is expensive.
-    with mock.patch.object(
-        util.tfgan.eval, 'inception_score') as mock_inception_score:
-      mock_inception_score.return_value = 1.0
-      util.get_inception_scores(
-          tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
-          batch_size=100,
-          num_inception_images=10)
-
-  def test_get_frechet_inception_distance(self):
-    # Mock `frechet_inception_distance` which is expensive.
-    with mock.patch.object(
-        util.tfgan.eval, 'frechet_inception_distance') as mock_fid:
-      mock_fid.return_value = 1.0
-      util.get_frechet_inception_distance(
-          tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
-          tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
-          batch_size=100,
-          num_inception_images=10)
+  # Mock `inception_score` which is expensive.
+  @mock.patch.object(util.tfgan.eval, 'inception_score', autospec=True)
+  def test_get_inception_scores(self, mock_inception_score):
+    mock_inception_score.return_value = 1.0
+    util.get_inception_scores(
+        tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
+        batch_size=100,
+        num_inception_images=10)
+
+  # Mock `frechet_inception_distance` which is expensive.
+  @mock.patch.object(util.tfgan.eval, 'frechet_inception_distance',
+                     autospec=True)
+  def test_get_frechet_inception_distance(self, mock_fid):
+    mock_fid.return_value = 1.0
+    util.get_frechet_inception_distance(
+        tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
+        tf.placeholder(tf.float32, shape=[None, 28, 28, 3]),
+        batch_size=100,
+        num_inception_images=10)


if __name__ == '__main__':
  tf.test.main()
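
The hunk above replaces `with mock.patch.object(...)` context managers with `@mock.patch.object` decorators: the decorator patches the target for the duration of the test and passes the created mock into the test method as an extra argument. A self-contained sketch of that pattern, written against the standard-library `unittest.mock` and using `time.time` as a stand-in target (nothing here is from this commit):

# Minimal illustration of decorator-based mocking; `time.time` is only a
# stand-in for an "expensive" function.
import time
import unittest
from unittest import mock


class DecoratorMockExampleTest(unittest.TestCase):

  @mock.patch.object(time, 'time', autospec=True)
  def test_uses_mocked_time(self, mock_time):
    mock_time.return_value = 1.0
    self.assertEqual(time.time(), 1.0)


if __name__ == '__main__':
  unittest.main()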
93 changes: 93 additions & 0 deletions research/gan/image_compression/data_provider.py
@@ -0,0 +1,93 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the compression image data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf

from slim.datasets import dataset_factory as datasets

slim = tf.contrib.slim


def provide_data(split_name, batch_size, dataset_dir,
                 dataset_name='imagenet', num_readers=1, num_threads=1,
                 patch_size=128):
  """Provides batches of image data for compression.

  Args:
    split_name: Either 'train' or 'validation'.
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the data can be found. If `None`, use
      default.
    dataset_name: Name of the dataset.
    num_readers: Number of dataset readers.
    num_threads: Number of prefetching threads.
    patch_size: Size of the patch to extract from the image.

  Returns:
    images: A `Tensor` of size [batch_size, patch_size, patch_size, channels].
  """
  randomize = split_name == 'train'
  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=randomize)
  [image] = provider.get(['image'])

  # Sample a patch of fixed size.
  patch = tf.image.resize_image_with_crop_or_pad(image, patch_size, patch_size)
  patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])

  # Preprocess the images. Make the range lie in a strictly smaller range than
  # [-1, 1], so that network outputs aren't forced to the extreme ranges.
  patch = (tf.to_float(patch) - 128.0) / 142.0
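  # For example, a uint8 value of 0 maps to (0 - 128) / 142 ≈ -0.901 and 255
  # maps to 127 / 142 ≈ 0.894, so batches stay strictly inside [-1, 1].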

  if randomize:
    image_batch = tf.train.shuffle_batch(
        [patch],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size,
        min_after_dequeue=batch_size)
  else:
    image_batch = tf.train.batch(
        [patch],
        batch_size=batch_size,
        num_threads=1,  # no threads so it's deterministic
        capacity=5 * batch_size)

  return image_batch


def float_image_to_uint8(image):
  """Convert float image in ~[-0.9, 0.9) to [0, 255] uint8.

  Args:
    image: An image tensor. Values should be in [-0.9, 0.9).

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  image = (image * 142.0) + 128.0
  return tf.cast(image, tf.uint8)
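
To make the round trip between these two helpers concrete, here is a minimal usage sketch (not part of the commit; the dataset path, batch size, and patch size are illustrative assumptions). `provide_data` emits float batches in roughly [-0.9, 0.9), and `float_image_to_uint8` maps them back to pixel values before writing to disk:

# Hypothetical usage sketch; dataset_dir and the sizes are placeholders.
import tensorflow as tf

import data_provider

images = data_provider.provide_data(
    'validation', batch_size=4, dataset_dir='/path/to/imagenet',
    patch_size=64)
uint8_images = data_provider.float_image_to_uint8(images)

with tf.Session() as sess:
  with tf.contrib.slim.queues.QueueRunners(sess):
    pixels = sess.run(uint8_images)  # A (4, 64, 64, 3) uint8 batch.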
60 changes: 60 additions & 0 deletions research/gan/image_compression/data_provider_test.py
@@ -0,0 +1,60 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

import tensorflow as tf

import data_provider


class DataProviderTest(tf.test.TestCase):

  def _test_data_provider_helper(self, split_name):
    dataset_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')

    batch_size = 3
    patch_size = 8
    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=8)
    self.assertListEqual([batch_size, patch_size, patch_size, 3],
                         images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out = sess.run(images)
        self.assertEqual((batch_size, patch_size, patch_size, 3),
                         images_out.shape)
        # Check range.
        self.assertTrue(np.all(np.abs(images_out) <= 1.0))

  def test_data_provider_train(self):
    self._test_data_provider_helper('train')

  def test_data_provider_validation(self):
    self._test_data_provider_helper('validation')


if __name__ == '__main__':
  tf.test.main()
101 changes: 101 additions & 0 deletions research/gan/image_compression/eval.py
@@ -0,0 +1,101 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates a TFGAN trained compression model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



import tensorflow as tf

import data_provider
import networks
import summaries

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')

flags.DEFINE_string('checkpoint_dir', '/tmp/compression/',
                    'Directory where the model was written to.')

flags.DEFINE_string('eval_dir', '/tmp/compression/',
                    'Directory where the results are saved to.')

flags.DEFINE_integer('max_number_of_evaluations', None,
                     'Number of times to run evaluation. If `None`, run '
                     'forever.')

flags.DEFINE_string('dataset_dir', None, 'Location of data.')

# Compression-specific flags.
flags.DEFINE_integer('batch_size', 32, 'The number of images in each batch.')

flags.DEFINE_integer('patch_size', 32, 'The size of the patches to train on.')

flags.DEFINE_integer('bits_per_patch', 1230,
                     'The number of bits to produce per patch.')

flags.DEFINE_integer('model_depth', 64,
                     'Number of filters for compression model')


def main(_, run_eval_loop=True):
  with tf.name_scope('inputs'):
    images = data_provider.provide_data(
        'validation', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
        patch_size=FLAGS.patch_size)

  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('generator'):
    reconstructions, _, prebinary = networks.compression_model(
        images,
        num_bits=FLAGS.bits_per_patch,
        depth=FLAGS.model_depth,
        is_training=False)
  summaries.add_reconstruction_summaries(images, reconstructions, prebinary)

  # Visualize losses.
  pixel_loss_per_example = tf.reduce_mean(
      tf.abs(images - reconstructions), axis=[1, 2, 3])
  pixel_loss = tf.reduce_mean(pixel_loss_per_example)
  tf.summary.histogram('pixel_l1_loss_hist', pixel_loss_per_example)
  tf.summary.scalar('pixel_l1_loss', pixel_loss)

  # Create ops to write images to disk.
  uint8_images = data_provider.float_image_to_uint8(images)
  uint8_reconstructions = data_provider.float_image_to_uint8(reconstructions)
  uint8_reshaped = summaries.stack_images(uint8_images, uint8_reconstructions)
  image_write_ops = tf.write_file(
      '%s/%s' % (FLAGS.eval_dir, 'compression.png'),
      tf.image.encode_png(uint8_reshaped[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop: return
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      master=FLAGS.master,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)


if __name__ == '__main__':
  tf.app.run()
32 changes: 32 additions & 0 deletions research/gan/image_compression/eval_test.py
@@ -0,0 +1,32 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.image_compression.eval."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import eval # pylint:disable=redefined-builtin


class EvalTest(tf.test.TestCase):

  def test_build_graph(self):
    eval.main(None, run_eval_loop=False)


if __name__ == '__main__':
  tf.test.main()