
Commit

internal change
PiperOrigin-RevId: 204367849
cyfra committed Jul 12, 2018
1 parent 9af70fe commit 9d60281
Showing 9 changed files with 432 additions and 22 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -5,6 +5,10 @@ paper (https://arxiv.org/abs/1711.10337) and in "The GAN Landscape: Losses, Arch

If you want to see the version used only in the first paper - please see the *v1* branch of this repository.

## Pre-trained models

The pre-trained models are available on TensorFlow Hub. Please see [this colab](https://colab.research.google.com/github/google/compare_gan/blob/master/compare_gan/src/tfhub_models.ipynb)
for an example of how to use them.

### Best hyperparameters

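For reference, below is a minimal sketch of sampling from one of these TF Hub generators in TF1-style graph mode; the module handle, the `generator` signature name, and the latent size of 128 are illustrative assumptions rather than values from this commit, and the linked colab lists the actual handles.

```python
# Minimal sketch, assuming a compare_gan generator module on TensorFlow Hub.
# The handle, the "generator" signature and z_dim=128 are illustrative
# assumptions; see the colab above for the real module handles.
import tensorflow as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  # Hypothetical handle; substitute one listed in the colab.
  gan = hub.Module("https://tfhub.dev/google/compare_gan/<model_name>/1")
  z = tf.random_normal([8, 128])          # batch of 8 latent vectors
  images = gan(z, signature="generator")  # assumed generator signature
  with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    samples = sess.run(images)            # e.g. a [8, height, width, 3] batch
```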
2 changes: 2 additions & 0 deletions compare_gan/bin/compare_gan_generate_tasks
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
2 changes: 2 additions & 0 deletions compare_gan/bin/compare_gan_run_one_task
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division

13 changes: 1 addition & 12 deletions compare_gan/src/eval_gan_lib.py
@@ -398,16 +398,6 @@ def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
checkpoint_path, result_dict["fid_score"])

# For comparison with other papers, for CIFAR dataset compute the FID score
# also on the whole training set.
if dataset == "cifar10":
logging.info("Computing FID score on 50k train set.")
train_images = GetRealImages(dataset, "train", 50000)
result_dict["fid50k_score"] = ComputeTFGanFIDScore(
fake_images, train_images, inception_graph)
logging.info("Frechet Inception Distance on 50k for checkpoint %s is %.3f",
checkpoint_path, result_dict["fid50k_score"])

if ShouldRunMultiscaleSSIM(options):
result_dict["ms_ssim"] = ComputeMultiscaleSSIMScore(fake_images)
logging.info("MS-SSIM score computed: %.3f", result_dict["ms_ssim"])
@@ -420,7 +410,7 @@ def RunTaskEval(options, task_workdir, inception_graph, out_file="scores.csv"):
# If the output file doesn't exist, create it.
csv_header = [
"checkpoint_path", "model", "dataset", "tf_seed", "inception_score",
"fid_score", "fid50k_score", "ms_ssim_score", "train_accuracy",
"fid_score", "ms_ssim_score", "train_accuracy",
"test_accuracy", "fake_accuracy", "train_d_loss", "test_d_loss",
"sample_id"
]
@@ -464,7 +454,6 @@ def RunTaskEval(options, task_workdir, inception_graph, out_file="scores.csv"):
checkpoint_path, options["gan_type"], options["dataset"], tf_seed,
"%.3f" % result_dict.get("inception_score", default_value),
"%.3f" % result_dict.get("fid_score", default_value),
"%.3f" % result_dict.get("fid50k_score", default_value),
"%.3f" % result_dict.get("ms_ssim", default_value),
"%.3f" % result_dict.get("train_accuracy", default_value),
"%.3f" % result_dict.get("test_accuracy", default_value),
71 changes: 70 additions & 1 deletion compare_gan/src/gan_lib_test.py
@@ -19,12 +19,23 @@
from __future__ import division
from __future__ import print_function

import os
import random
import string

from absl import flags
from absl.testing import parameterized

from compare_gan.src import gan_lib
from compare_gan.src.gans import consts

import numpy as np
import tensorflow as tf

FLAGS = flags.FLAGS

class GanLibTest(tf.test.TestCase):

class GanLibTest(parameterized.TestCase, tf.test.TestCase):

def testLoadingTriangles(self):
with tf.Graph().as_default():
@@ -60,6 +71,64 @@ def testLoadingMnist(self):
self.assertEqual(image.shape, (28, 28, 1))
self.assertEqual(label.shape, ())

def trainSingleStep(self, tf_seed):
"""Train a GAN for a single training step and return the checkpoint."""
parameters = {
"tf_seed": tf_seed,
"learning_rate": 0.0002,
"z_dim": 64,
"batch_size": 2,
"training_steps": 1,
"disc_iters": 1,
"save_checkpoint_steps": 5000,
"discriminator_normalization": consts.NO_NORMALIZATION,
"dataset": "fake",
"gan_type": "GAN",
"penalty_type": consts.NO_PENALTY,
"architecture": consts.RESNET_CIFAR,
"lambda": 0.1,
}
random.seed(None)
exp_name = ''.join(random.choice(string.ascii_uppercase) for _ in range(16))
task_workdir = os.path.join(FLAGS.test_tmpdir, exp_name)
gan_lib.run_with_options(parameters, task_workdir)
ckpt_fn = os.path.join(task_workdir, "checkpoint/{}.model-0".format(
parameters["gan_type"]))
tf.logging.info("ckpt_fn: %s", ckpt_fn)
self.assertTrue(tf.gfile.Exists(ckpt_fn + ".index"))
return tf.train.load_checkpoint(ckpt_fn)

def testSameTFRandomSeed(self):
# Setting the same tf_seed should give the same initial values at each run.
# In practice, training runs can still diverge due to non-deterministic
# behavior in certain operations at the hardware level (e.g. cuDNN
# optimizations).
ckpt_1 = self.trainSingleStep(tf_seed=42)
ckpt_2 = self.trainSingleStep(tf_seed=42)

for name in ckpt_1.get_variable_to_shape_map():
self.assertTrue(ckpt_2.has_tensor(name))
t1 = ckpt_1.get_tensor(name)
t2 = ckpt_2.get_tensor(name)
np.testing.assert_almost_equal(t1, t2)

@parameterized.named_parameters([
('Given', 1, 2),
('OneNone', 1, None),
('BothNone', None, None),
])
def testDifferentTFRandomSeed(self, seed_1, seed_2):
ckpt_1 = self.trainSingleStep(tf_seed=seed_1)
ckpt_2 = self.trainSingleStep(tf_seed=seed_2)

diff_counter = 0
for name in ckpt_1.get_variable_to_shape_map():
self.assertTrue(ckpt_2.has_tensor(name))
t1 = ckpt_1.get_tensor(name)
t2 = ckpt_2.get_tensor(name)
if np.abs(t1 - t2).sum() > 0:
diff_counter += 1
self.assertGreater(diff_counter, 0)


if __name__ == "__main__":
tf.test.main()
1 change: 0 additions & 1 deletion compare_gan/src/gans/resnet_architecture.py
@@ -24,7 +24,6 @@

from math import log

# Dependency imports

from compare_gan.src.gans import consts
from compare_gan.src.gans import ops
11 changes: 7 additions & 4 deletions compare_gan/src/generate_tasks_lib.py
@@ -442,9 +442,10 @@ def BestModelResnet19():
# Line 3: FID score: 102.74
{
"dataset": "lsun-bedroom",
"gan_type": consts.LSGAN_WITH_PENALTY,
"training_steps": 200000,
"penalty_type": consts.NO_PENALTY,
"learning_rate": 0.000322,
"learning_rate": 0.0000322,
"beta1": 0.5850,
"beta2": 0.9904,
"disc_iters": 1,
@@ -457,7 +458,7 @@
"dataset": "lsun-bedroom",
"training_steps": 200000,
"penalty_type": consts.NO_PENALTY,
"learning_rate": 0.000193,
"learning_rate": 0.0000193,
"beta1": 0.1947,
"beta2": 0.8819,
"disc_iters": 1,
@@ -497,6 +498,7 @@ def BestModelResnet19():
# Line 7: FID score: 41.6
{
"dataset": "lsun-bedroom",
"gan_type": consts.LSGAN_WITH_PENALTY,
"training_steps": 200000,
"penalty_type": consts.NO_PENALTY,
"learning_rate": 0.0002,
@@ -580,11 +582,12 @@ def BestModelResnet19():
model.update({
"architecture": consts.RESNET5_ARCH,
"batch_size": 64,
"gan_type": consts.GAN_WITH_PENALTY,
"optimizer": "adam",
"save_checkpoint_steps": 20000,
"z_dim": 128,
})
if "gan_type" not in model:
model.update({"gan_type": consts.GAN_WITH_PENALTY})

return best_models

@@ -640,7 +643,7 @@ def BestModelResnetCifar():
"beta1": 0.5,
"beta2": 0.999,
"disc_iters": 5,
"discriminator_normalization": consts.NO_NORMALIZATION,
"discriminator_normalization": consts.SPECTRAL_NORM,
"tf_seed": 2,
"lambda": 1,
},

