Skip to content

Commit

Permalink
readme source url and import optimization (tensorflow#7084)
Browse files Browse the repository at this point in the history
* restored missing function

* missing import

* missing imports

* updated tutorial link

* recovered _print_download_progress func

* change default train with float16 instead of float32 accuracy

* test disable func call

* redundant function call, currently data is pulled automatically in

* optimized imports

* optimized imports
  • Loading branch information
GeorgeK-zn authored and tfboyd committed Jun 22, 2019
1 parent 47a5902 commit adc2717
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tutorials/image/cifar10/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Code in this directory demonstrates how to use TensorFlow to train and evaluate

Detailed instructions on how to get started available at:

http://tensorflow.org/tutorials/deep_cnn/
https://www.tensorflow.org/tutorials/images/deep_cnn
6 changes: 2 additions & 4 deletions tutorials/image/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
tf.app.flags.DEFINE_boolean('use_fp16', True,
"""Train the model using fp16.""")

# Global constants describing the CIFAR-10 data set.
Expand Down Expand Up @@ -146,16 +146,14 @@ def distorted_inputs():

def inputs(eval_data):
"""Construct input for CIFAR evaluation using the Reader ops.
Args:
eval_data: bool, indicating if one should use the train or eval data set.
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
images, labels = cifar10_input.inputs(eval_data=eval_data,
batch_size=FLAGS.batch_size)
images, labels = cifar10_input.inputs(eval_data=eval_data, batch_size=FLAGS.batch_size)
if FLAGS.use_fp16:
images = tf.cast(images, tf.float16)
labels = tf.cast(labels, tf.float16)
Expand Down
6 changes: 3 additions & 3 deletions tutorials/image/cifar10/cifar10_multi_gpu_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import re
import time
from datetime import datetime

import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from six.moves import xrange # pylint: disable=redefined-builtin

import cifar10

FLAGS = tf.app.flags.FLAGS
Expand Down Expand Up @@ -266,7 +267,6 @@ def train():


def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
Expand Down

0 comments on commit adc2717

Please sign in to comment.