TensorFlow boilerplate code using the tf.data API and the tf.train.MonitoredTrainingSession API to build flexible and efficient input pipelines with simplified training in a distributed setting.
The modular structure lets you swap the network/model or dataset with a single argument, which makes it easy to compare various models, datasets and parameter settings.
The current version requires in particular the following libraries / versions.
To run a simple RCNN on Fashion MNIST, use the following command (these arguments are the defaults):
python3 tfbp.py --dataset data.fashionmnist --model model.rcnn
which produces the following output:
$ python3 tfbp.py
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tf-boilerplate-log/model.ckpt.
INFO:tensorflow:accuracy = 0.1015625, step = 0
INFO:tensorflow:accuracy = 0.8359375, step = 50 (0.896 sec)
INFO:tensorflow:accuracy = 0.83203125, step = 100 (0.344 sec)
INFO:tensorflow:accuracy = 0.859375, step = 150 (0.425 sec)
INFO:tensorflow:accuracy = 0.8828125, step = 200 (0.362 sec)
INFO:tensorflow:accuracy = 0.84375, step = 250 (0.379 sec)
INFO:tensorflow:accuracy = 0.87109375, step = 300 (0.351 sec)
INFO:tensorflow:accuracy = 0.86328125, step = 350 (0.320 sec)
INFO:tensorflow:accuracy = 0.87109375, step = 400 (0.342 sec)
INFO:tensorflow:accuracy = 0.86328125, step = 450 (0.384 sec)
INFO:tensorflow:accuracy = 0.9140625, step = 500 (0.361 sec)
INFO:tensorflow:accuracy = 0.90625, step = 550 (0.449 sec)
INFO:tensorflow:accuracy = 0.89453125, step = 600 (0.476 sec)
INFO:tensorflow:accuracy = 0.91015625, step = 650 (0.320 sec)
INFO:tensorflow:accuracy = 0.89453125, step = 700 (0.383 sec)
INFO:tensorflow:Saving checkpoints for 705 into /tmp/tf-boilerplate-log/model.ckpt.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tf-boilerplate-log/model.ckpt-705
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:accuracy = 0.8886719
INFO:tensorflow:accuracy = 0.84765625 (0.018 sec)
INFO:tensorflow:accuracy = 0.8769531 (0.017 sec)
INFO:tensorflow:accuracy = 0.8769531 (0.017 sec)
Final Mean Accuracy: 0.8743681
Run python3 tfbp.py --help
to see a complete list of command line arguments.
Currently we provide tf.data wrappers for MNIST, Fashion MNIST and CIFAR10. Feel free to contribute others as well!
The CNN model is provided for educational purposes only.
Here is a short introduction to the TensorFlow APIs used.
For more information, see the references section.
The code has a modular structure: all models and datasets are dynamically imported as modules, given the --dataset and --model arguments.
Then, tfbp.py runs a basic training loop using the training dataset to evaluate the lossfn, and minimizes the loss using the AdamOptimizer.
At the end, the model is evaluated using the testing dataset.
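The dynamic import described above can be sketched with the standard library. This is a minimal sketch, not the actual code in tfbp.py: the helper names parse_args and load_module are hypothetical, and the defaults are taken from the command shown earlier. Standard-library modules are used as stand-ins so the snippet runs outside the repository.

```python
import argparse
import importlib


def parse_args(argv):
    """Parse the two module arguments; defaults follow the command above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="data.fashionmnist")
    parser.add_argument("--model", default="model.rcnn")
    return parser.parse_args(argv)


def load_module(dotted_name):
    """Import a module by its dotted name, e.g. 'data.fashionmnist'."""
    return importlib.import_module(dotted_name)


# Stand-in modules from the standard library (assumption for illustration);
# inside the repository these would be e.g. data.fashionmnist and model.rcnn.
args = parse_args(["--dataset", "json.decoder", "--model", "json.encoder"])
dataset_module = load_module(args.dataset)
model_module = load_module(args.model)
print(dataset_module.__name__, model_module.__name__)
```

Because the module is resolved from a string at run time, adding a new dataset or model only requires a new module file, with no changes to the training script.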
To build flexible and efficient input pipelines we make use of the tf.data API.
We introduce a simple DataSampler class, which has the abstract methods training(), testing() and validation().
These methods must be implemented for each new dataset, and will be used during the training loop.
Any type of tf.data.Dataset can be returned, but parameters such as batch_size and epochs will be set at a later stage by tfbp.py.
For an overview of the tf.data API, see the Importing Data Guide.
See the MNIST example.
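The interface described above could look roughly like the following. This is a sketch under assumptions: it uses Python's abc module, which may differ from how the repository declares the abstract methods, and ToySampler is a hypothetical subclass whose plain lists stand in for the tf.data.Dataset objects a real sampler would return.

```python
import abc


class DataSampler(abc.ABC):
    """Interface sketch: one method per data split."""

    @abc.abstractmethod
    def training(self):
        """Return the training split."""

    @abc.abstractmethod
    def testing(self):
        """Return the testing split."""

    @abc.abstractmethod
    def validation(self):
        """Return the validation split."""


class ToySampler(DataSampler):
    """Hypothetical sampler; lists stand in for tf.data.Dataset objects."""

    def training(self):
        return [(0.0, 0), (1.0, 1), (2.0, 0), (3.0, 1)]

    def testing(self):
        return [(4.0, 0), (5.0, 1)]

    def validation(self):
        return [(6.0, 0)]


sampler = ToySampler()
print(len(sampler.training()), len(sampler.testing()), len(sampler.validation()))
```

With this layout, a subclass that forgets one of the three methods cannot be instantiated, so a missing split is caught before training starts.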
The tf.data API works well with the tf.train API for distributed execution, especially tf.train.MonitoredTrainingSession.
The class MonitoredSession provides a tf.Session-like object that handles initialization, recovery and hooks.
For distributed settings, use tf.train.MonitoredTrainingSession; otherwise, tf.train.SingularMonitoredSession is recommended.
For now, we use the class SingularMonitoredSession, as it provides all the goodies we need for the tf.data API.
If needed, the SingularMonitoredSession can be replaced with MonitoredSession.
Here is a basic example:
import tensorflow as tf  # TensorFlow 1.x API
# define a dataset
dataset = tf.data.Dataset(...).batch(32).repeat(5)
data = dataset.make_one_shot_iterator().get_next()
# define model, loss and optimizer
loss = network(data)
train_op = tf.train.AdamOptimizer().minimize(loss)
# SingularMonitoredSession example
# checkpoints and summaries are saved periodically
saver_hook = tf.train.CheckpointSaverHook(...)
summary_hook = tf.train.SummarySaverHook(...)
with tf.train.SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    # run one training step per batch until the input is exhausted
    while not sess.should_stop():
        sess.run(train_op)
Parameters like batch_size and the number of epochs are implicitly set via the Dataset.
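As a rough illustration of what this means in practice, the number of training steps follows from the dataset size, the batch size and the repeat count. The sketch below assumes the 60,000 training images of Fashion MNIST and that batching is applied before repeating, as in the basic example, so the last batch of each epoch may be smaller. The helper name steps_for is hypothetical. The batch size of 256 and 3 epochs in the second call are a guess that is consistent with the 705 steps in the log above, not confirmed defaults of tfbp.py.

```python
import math


def steps_for(num_examples, batch_size, epochs):
    """Steps produced by dataset.batch(batch_size).repeat(epochs)."""
    # Each epoch yields ceil(num_examples / batch_size) batches, because the
    # final batch keeps the leftover examples instead of dropping them.
    return math.ceil(num_examples / batch_size) * epochs


print(steps_for(60000, 32, 5))   # basic example: batch(32).repeat(5)
print(steps_for(60000, 256, 3))  # assumed settings matching the log
```

This prints 9375 and 705. The training loop itself never counts steps; it simply stops once the dataset iterator is exhausted.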
Various hooks can be used to evaluate / process tensors during training, see Training -> Training Hooks.
For example:
- LoggingTensorHook to log different tensors (e.g. current step, time or metrics)
- CheckpointSaverHook to save the model parameters
- SummarySaverHook to save summaries
- OneTimeSummarySaverHook to save summaries exactly once. (This can come in handy for saving the parameters of your run inside your log, so they can be checked after training directly in TensorBoard.)
When logging the current step and accuracy, the command line output looks like this (from the example above):
INFO:tensorflow:step = 70, accuracy = 0.90625 (0.006 sec)
For an overview see Importing Data Guide - Using high-level APIs.
Here is a (probably incomplete) list of resources. Please contribute!
- tf.data: Fast, flexible, and easy-to-use input pipelines (TensorFlow Dev Summit 2018) (youtu.be)
- Google TF Datasets Intro
- TF Importing Data
- TF tf.data.Dataset docs
- Distributed TensorFlow (TensorFlow Dev Summit 2018) (youtu.be)
- Distributed TensorFlow training (Google I/O '18) (youtu.be)
- TF tf.contrib.distribute (Docs)
- TF tf.contrib.distribute.DistributionStrategy (Docs)
- Distributed TensorFlow (Docs)
- Training Performance: A user’s guide to converge faster (TensorFlow Dev Summit 2018) (youtu.be)
- ResNet Model (TF Models Repo)
- TF Performance Benchmarks (Docs)
- TF Benchmarks (Repo)
- TF Testing Benchmarks (Website)
- @tfboyd/tf-tools
- @tfboyd/benchmark_harness
Please read CONTRIBUTING.md for details on our code of conduct, and the process for submitting pull requests to us.
Feel free to open a PR or issue; we are happy to help!
We use SemVer for versioning. For the versions available, see the tags on this repository.
This project is licensed under the MIT License. See the LICENSE.md file for details.
The dataset-handlers are taken from the tensorflow/models repository.