Log Keras training and validation summaries to separate runs.
Previously, metrics like accuracy would be logged as `epoch_acc` during
training and `val_epoch_acc` during validation. As of this change,
training metrics are logged to a `train/` subdirectory under the
top-level logdir, while validation metrics are logged to `validation/`.
This has the advantage that training and validation metrics can be shown
in the same plot:

![Screenshot of new behavior](https://user-images.githubusercontent.com/4317806/52606214-9e3ddc80-2e26-11e9-9a02-2a5228edc8f6.png)
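
To illustrate the new behavior end to end, here is a minimal sketch (the model, dataset, and `/tmp/mnist_logs` path are hypothetical, chosen only to show the callback wiring):

```python
import tensorflow as tf

# Hypothetical MNIST setup; only the TensorBoard callback matters here.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

# After this change, event files land in per-run subdirectories:
#   /tmp/mnist_logs/train/       <- training metrics (and the graph)
#   /tmp/mnist_logs/validation/  <- validation metrics
model.fit(x_train, y_train,
          validation_data=(x_test, y_test),
          epochs=5,
          callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/mnist_logs')])
```

Running `tensorboard --logdir /tmp/mnist_logs` then picks up both runs, and because the runs share tags, TensorBoard overlays their curves on one plot.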

Tested:
Running a simple MNIST model generates a TensorBoard instance with
summaries as described above, and with a graph under the `train` run.

RELNOTES: Keras training and validation curves are shown on the same plot.
PiperOrigin-RevId: 234179031
wchargin authored and tensorflower-gardener committed Feb 15, 2019
1 parent dfd0925 commit dd2d989
Showing 2 changed files with 186 additions and 85 deletions.
78 changes: 63 additions & 15 deletions tensorflow/python/keras/callbacks.py
@@ -1162,6 +1162,10 @@ def __init__(self,
     self._total_batches_seen = 0
     self._total_val_batches_seen = 0
 
+    self._writers = []  # file writers to be closed
+    self._train_writer = None  # set in `_initialize_writers`
+    self._validation_writer = None  # set in `_initialize_writers`
+
   def _validate_kwargs(self, kwargs):
     """Handle arguments that were supported in V1."""
     if kwargs.get('write_grads', False):
@@ -1185,16 +1189,44 @@ def set_model(self, model):
     """Sets Keras model and writes graph if specified."""
     self.model = model
     with context.eager_mode():
-      self.writer = summary_ops_v2.create_file_writer(self.log_dir)
+      self._initialize_writers()
       if self.write_graph:
         if model.run_eagerly:
           logging.warning('TensorBoard Callback will ignore `write_graph=True` '
                           'when `Model.run_eagerly=True`.')
         else:
-          with self.writer.as_default():
+          with self._train_writer.as_default():
             with summary_ops_v2.always_record_summaries():
               summary_ops_v2.graph(K.get_graph())
 
+  def _close_writers(self):
+    """Close all remaining open file writers owned by this callback.
+
+    If there are no such file writers, this is a no-op.
+    """
+    with context.eager_mode():
+      for writer in self._writers:
+        writer.close()
+      del self._writers[:]
+
+  def _initialize_writers(self):
+    """Create the training and validation file writers.
+
+    This updates `self._train_writer` and `self._validation_writer`, and
+    populates the `self._writers` list to be cleaned up by
+    `_close_writers`.
+    """
+    self._close_writers()
+
+    def create_writer(subdir):
+      path = os.path.join(self.log_dir, subdir)
+      return summary_ops_v2.create_file_writer(path)
+
+    self._train_writer = create_writer('train')
+    self._writers.append(self._train_writer)
+    self._validation_writer = create_writer('validation')
+    self._writers.append(self._validation_writer)
+
   def on_batch_end(self, batch, logs=None):
     """Writes scalar summaries for metrics on every training batch."""
     # Don't output batch_size and batch number as TensorBoard summaries
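
The two writers above are plain summary file writers rooted at sibling subdirectories of `log_dir`. A standalone sketch of the same layout, using the public TF2 `tf.summary` API in place of the internal `summary_ops_v2` module (paths and values are examples):

```python
import os
import tensorflow as tf

log_dir = '/tmp/logs'  # example top-level logdir

# One writer per run; TensorBoard treats each subdirectory as a run.
train_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'train'))
validation_writer = tf.summary.create_file_writer(
    os.path.join(log_dir, 'validation'))

# Both runs write the same tag, so TensorBoard plots them together.
with train_writer.as_default():
  tf.summary.scalar('epoch_acc', 0.91, step=0)
with validation_writer.as_default():
  tf.summary.scalar('epoch_acc', 0.87, step=0)
```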
@@ -1215,8 +1247,7 @@ def on_epoch_end(self, epoch, logs=None):
       self._log_weights(epoch)
 
   def on_train_end(self, logs=None):
-    with context.eager_mode():
-      self.writer.close()
+    self._close_writers()
 
   def _log_metrics(self, logs, prefix, step):
     """Writes metrics out as custom scalar summaries.
@@ -1228,20 +1259,37 @@ def _log_metrics(self, logs, prefix, step):
     """
     if logs is None:
       logs = {}
-    # Scrub non-metric items and assign batch or epoch prefix.
-    metric_logs = {(prefix + k): v
-                   for k, v in logs.items()
-                   if k not in ['batch', 'size', 'num_steps']}
-    with context.eager_mode(), \
-        self.writer.as_default(), \
-        summary_ops_v2.always_record_summaries():
-      for name, value in metric_logs.items():
-        summary_ops_v2.scalar(name, value, step=step)
+
+    # Group metrics by their associated file writer. Values are lists of
+    # metrics, as (name, scalar_value) pairs.
+    logs_by_writer = {
+        self._train_writer: [],
+        self._validation_writer: [],
+    }
+    validation_prefix = 'val_'
+    for (name, value) in logs.items():
+      if name in ('batch', 'size', 'num_steps'):
+        # Scrub non-metric items.
+        continue
+      if name.startswith(validation_prefix):
+        name = name[len(validation_prefix):]
+        writer = self._validation_writer
+      else:
+        writer = self._train_writer
+      name = prefix + name  # assign batch or epoch prefix
+      logs_by_writer[writer].append((name, value))
+
+    with context.eager_mode():
+      with summary_ops_v2.always_record_summaries():
+        for writer in logs_by_writer:
+          with writer.as_default():
+            for (name, value) in logs_by_writer[writer]:
+              summary_ops_v2.scalar(name, value, step=step)
 
   def _log_weights(self, epoch):
     """Logs the weights of the Model to TensorBoard."""
     with context.eager_mode(), \
-        self.writer.as_default(), \
+        self._train_writer.as_default(), \
         summary_ops_v2.always_record_summaries():
       for layer in self.model.layers:
         for weight in layer.weights:
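
The rewritten `_log_metrics` splits a single Keras `logs` dict by run before writing, stripping the `val_` prefix so that both runs share a tag. A pure-Python restatement of that grouping (the function name `group_logs_by_run` is ours, for illustration):

```python
def group_logs_by_run(logs, prefix='epoch_'):
  """Split a Keras `logs` dict into per-run lists of (name, value) pairs."""
  logs_by_run = {'train': [], 'validation': []}
  for name, value in logs.items():
    if name in ('batch', 'size', 'num_steps'):
      continue  # scrub non-metric bookkeeping entries
    if name.startswith('val_'):
      # Strip 'val_' so train and validation share the same tag.
      logs_by_run['validation'].append((prefix + name[len('val_'):], value))
    else:
      logs_by_run['train'].append((prefix + name, value))
  return logs_by_run

print(group_logs_by_run({'acc': 0.91, 'val_acc': 0.87, 'size': 32}))
# {'train': [('epoch_acc', 0.91)], 'validation': [('epoch_acc', 0.87)]}
```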
@@ -1251,7 +1299,7 @@ def _log_weights(self, epoch):
           summary_ops_v2.histogram(weight_name, weight, step=epoch)
           if self.write_images:
             self._log_weight_as_image(weight, weight_name, epoch)
-      self.writer.flush()
+      self._train_writer.flush()
 
   def _log_weight_as_image(self, weight, weight_name, epoch):
     """Logs a weight as a TensorBoard image."""
