From dd2d989bb316f59a36e3e47973814255866ba6c4 Mon Sep 17 00:00:00 2001 From: William Chargin Date: Fri, 15 Feb 2019 11:16:39 -0800 Subject: [PATCH] Log Keras training and validation summaries to separate runs. Previously, metrics like accuracy would be logged as `epoch_acc` during training and `epoch_val_acc` during validation. As of this change, training metrics are logged to a `train/` subdirectory under the top-level logdir, while validation metrics are logged to `validation/`. This has the advantage that training and validation metrics can be shown in the same plot: ![Screenshot of new behavior](https://user-images.githubusercontent.com/4317806/52606214-9e3ddc80-2e26-11e9-9a02-2a5228edc8f6.png) Tested: Running a simple MNIST model generates a TensorBoard instance with summaries as described above, and with a graph under the "train" run. RELNOTES: Keras training and validation curves are shown on the same plot. PiperOrigin-RevId: 234179031 --- tensorflow/python/keras/callbacks.py | 78 +++++++-- tensorflow/python/keras/callbacks_test.py | 193 ++++++++++++++-------- 2 files changed, 186 insertions(+), 85 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 70119324ea3496..011124822509f1 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1162,6 +1162,10 @@ def __init__(self, self._total_batches_seen = 0 self._total_val_batches_seen = 0 + self._writers = [] # file writers to be closed + self._train_writer = None # set in `_initialize_writers` + self._validation_writer = None # set in `_initialize_writers` + def _validate_kwargs(self, kwargs): """Handle arguments were supported in V1.""" if kwargs.get('write_grads', False): @@ -1185,16 +1189,44 @@ def set_model(self, model): """Sets Keras model and writes graph if specified.""" self.model = model with context.eager_mode(): - self.writer = summary_ops_v2.create_file_writer(self.log_dir) + self._initialize_writers() if
self.write_graph: if model.run_eagerly: logging.warning('TensorBoard Callback will ignore `write_graph=True`' 'when `Model.run_eagerly=True`.`') else: - with self.writer.as_default(): + with self._train_writer.as_default(): with summary_ops_v2.always_record_summaries(): summary_ops_v2.graph(K.get_graph()) + def _close_writers(self): + """Close all remaining open file writers owned by this callback. + + If there are no such file writers, this is a no-op. + """ + with context.eager_mode(): + for writer in self._writers: + writer.close() + del self._writers[:] + + def _initialize_writers(self): + """Create all file writers needed and validation writers. + + This updates `self._train_writer` and `self._validation_writer`, and + populates the `self._writers` list to be cleaned up by + `_close_writers`. + """ + self._close_writers() + + def create_writer(subdir): + path = os.path.join(self.log_dir, subdir) + return summary_ops_v2.create_file_writer(path) + + self._train_writer = create_writer('train') + self._writers.append(self._train_writer) + self._validation_writer = create_writer('validation') + self._writers.append(self._validation_writer) + def on_batch_end(self, batch, logs=None): """Writes scalar summaries for metrics on every training batch.""" # Don't output batch_size and batch number as TensorBoard summaries @@ -1215,8 +1247,7 @@ def on_epoch_end(self, epoch, logs=None): self._log_weights(epoch) def on_train_end(self, logs=None): - with context.eager_mode(): - self.writer.close() + self._close_writers() def _log_metrics(self, logs, prefix, step): """Writes metrics out as custom scalar summaries. @@ -1228,20 +1259,37 @@ def _log_metrics(self, logs, prefix, step): """ if logs is None: logs = {} - # Scrub non-metric items and assign batch or epoch prefix. 
- metric_logs = {(prefix + k): v - for k, v in logs.items() - if k not in ['batch', 'size', 'num_steps']} - with context.eager_mode(), \ - self.writer.as_default(), \ - summary_ops_v2.always_record_summaries(): - for name, value in metric_logs.items(): - summary_ops_v2.scalar(name, value, step=step) + + # Group metrics by their associated file writer. Values are lists of + # metrics, as (name, scalar_value) pairs. + logs_by_writer = { + self._train_writer: [], + self._validation_writer: [], + } + validation_prefix = 'val_' + for (name, value) in logs.items(): + if name in ('batch', 'size', 'num_steps'): + # Scrub non-metric items. + continue + if name.startswith(validation_prefix): + name = name[len(validation_prefix):] + writer = self._validation_writer + else: + writer = self._train_writer + name = prefix + name # assign batch or epoch prefix + logs_by_writer[writer].append((name, value)) + + with context.eager_mode(): + with summary_ops_v2.always_record_summaries(): + for writer in logs_by_writer: + with writer.as_default(): + for (name, value) in logs_by_writer[writer]: + summary_ops_v2.scalar(name, value, step=step) def _log_weights(self, epoch): """Logs the weights of the Model to TensorBoard.""" with context.eager_mode(), \ - self.writer.as_default(), \ + self._train_writer.as_default(), \ summary_ops_v2.always_record_summaries(): for layer in self.model.layers: for weight in layer.weights: @@ -1251,7 +1299,7 @@ def _log_weights(self, epoch): summary_ops_v2.histogram(weight_name, weight, step=epoch) if self.write_images: self._log_weight_as_image(weight, weight_name, epoch) - self.writer.flush() + self._train_writer.flush() def _log_weight_as_image(self, weight, weight_name, epoch): """Logs a weight as a TensorBoard image.""" diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 4863e5ceac787c..1bd24aa19d0cf2 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ 
b/tensorflow/python/keras/callbacks_test.py @@ -32,6 +32,7 @@ from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import random_seed from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils @@ -966,57 +967,80 @@ def test_RemoteMonitorWithJsonPayload(self): epochs=1) -class _MockSummaryFile(object): - """Mocks a TensorBoard summary file, recording the tag names it sees.""" - - def __init__(self): - self.scalar_names = set() - self.hist_names = set() - self.image_names = set() - - -def _make_mock_scalar_summary(summary_file): - - def _mock_scalar_summary(name, *args, **kwargs): # pylint: disable=unused-argument - summary_file.scalar_names.update({name}) +# A summary that was emitted during a test. Fields: +# logdir: str. The logdir of the FileWriter to which the summary was +# written. +# tag: str. The name of the summary. +_ObservedSummary = collections.namedtuple('_ObservedSummary', ('logdir', 'tag')) - return _mock_scalar_summary +class _MockSummaryFile(object): + """Record summary tag names and the files to which they're written. -def _make_mock_hist_summary(summary_file): - - def _mock_hist_summary(name, *args, **kwargs): # pylint: disable=unused-argument - summary_file.hist_names.update({name}) - - return _mock_hist_summary - - -def _make_mock_image_summary(summary_file): - - def _mock_image_summary(name, *args, **kwargs): # pylint: disable=unused-argument - summary_file.image_names.update({name}) + Fields `scalars`, `images`, and `histograms` are sets containing + `_ObservedSummary` values. 
+ """ - return _mock_image_summary + def __init__(self): + self.scalars = set() + self.images = set() + self.histograms = set() @tf_contextlib.contextmanager -def _mock_summary_api(summary_file): +def _mock_summary_api(): + summary_file = _MockSummaryFile() + + # Keep track of the logdir associated with each created resource. + # (There doesn't seem to be an easy way to get this information after + # the fact.) + resource_logdirs = {} + real_create_file_writer = summary_ops_v2.create_file_writer + + def mock_create_file_writer(logdir, *args, **kwargs): + writer = real_create_file_writer(logdir, *args, **kwargs) + resource = writer._resource + assert resource is not None + assert resource not in resource_logdirs, (resource, resource_logdirs) + resource_logdirs[resource] = logdir + return writer + + def make_mock_summary(summary_set): + + def mock_summary(tag, *args, **kwargs): + del args # unused + del kwargs # unused + resource = context.context().summary_writer_resource + logdir = resource_logdirs[resource] + summary_set.add(_ObservedSummary(logdir=logdir, tag=tag)) + + return mock_summary + with test.mock.patch.object(summary_ops_v2, - 'scalar', - _make_mock_scalar_summary(summary_file)), \ + 'create_file_writer', + mock_create_file_writer), \ + test.mock.patch.object(summary_ops_v2, + 'scalar', + make_mock_summary(summary_file.scalars)), \ test.mock.patch.object(summary_ops_v2, 'histogram', - _make_mock_hist_summary(summary_file)), \ + make_mock_summary(summary_file.histograms)), \ test.mock.patch.object(summary_ops_v2, 'image', - _make_mock_image_summary(summary_file)): - yield + make_mock_summary(summary_file.images)): + yield summary_file @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes(always_skip_v1=True) class TestTensorBoardV2(keras_parameterized.TestCase): + def setUp(self): + super(TestTensorBoardV2, self).setUp() + self.logdir = os.path.join(self.get_temp_dir(), 'tb') + self.train_dir = os.path.join(self.logdir, 
'train') + self.validation_dir = os.path.join(self.logdir, 'validation') + def _get_model(self): layers = [ keras.layers.Conv2D(8, (3, 3)), @@ -1028,13 +1052,11 @@ def _get_model(self): return model def test_TensorBoard_basic(self): - summary_file = _MockSummaryFile() model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) - temp_dir = self.get_temp_dir() + '/tb' - tb_cbk = keras.callbacks.TensorBoard(temp_dir) + tb_cbk = keras.callbacks.TensorBoard(self.logdir) - with _mock_summary_api(summary_file): # pylint: disable=not-context-manager + with _mock_summary_api() as summary_file: model.fit( x, y, @@ -1043,17 +1065,18 @@ def test_TensorBoard_basic(self): validation_data=(x, y), callbacks=[tb_cbk]) - self.assertEqual(summary_file.scalar_names, - {'epoch_loss', 'epoch_val_loss'}) + self.assertEqual( + summary_file.scalars, { + _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), + _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + }) def test_TensorBoard_batch_metrics(self): - summary_file = _MockSummaryFile() model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) - temp_dir = self.get_temp_dir() + '/tb' - tb_cbk = keras.callbacks.TensorBoard(temp_dir, update_freq=1) + tb_cbk = keras.callbacks.TensorBoard(self.logdir, update_freq=1) - with _mock_summary_api(summary_file): # pylint: disable=not-context-manager + with _mock_summary_api() as summary_file: model.fit( x, y, @@ -1062,17 +1085,22 @@ def test_TensorBoard_batch_metrics(self): validation_data=(x, y), callbacks=[tb_cbk]) - self.assertEqual(summary_file.scalar_names, - {'batch_loss', 'epoch_loss', 'epoch_val_loss'}) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=self.train_dir, tag='batch_loss'), + _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), + _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + }, + ) def test_TensorBoard_weight_histograms(self): - summary_file = _MockSummaryFile() model = 
self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) temp_dir = self.get_temp_dir() + '/tb' tb_cbk = keras.callbacks.TensorBoard(temp_dir, histogram_freq=1) - with _mock_summary_api(summary_file): # pylint: disable=not-context-manager + with _mock_summary_api() as summary_file: model.fit( x, y, @@ -1081,24 +1109,29 @@ def test_TensorBoard_weight_histograms(self): validation_data=(x, y), callbacks=[tb_cbk]) - self.assertEqual(summary_file.scalar_names, - {'epoch_loss', 'epoch_val_loss'}) - - # Strip Layer names as Layers are created multiple times in test. - hist_names = { - name[name.rfind('/') + 1:] for name in summary_file.hist_names - } - self.assertEqual(hist_names, {'bias_0', 'kernel_0'}) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), + _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + }, + ) + self.assertEqual( + self._strip_layer_names(summary_file.histograms), + { + _ObservedSummary(logdir=self.train_dir, tag='bias_0'), + _ObservedSummary(logdir=self.train_dir, tag='kernel_0'), + }, + ) def test_TensorBoard_weight_images(self): - summary_file = _MockSummaryFile() model = self._get_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) temp_dir = self.get_temp_dir() + '/tb' tb_cbk = keras.callbacks.TensorBoard( temp_dir, histogram_freq=1, write_images=True) - with _mock_summary_api(summary_file): # pylint: disable=not-context-manager + with _mock_summary_api() as summary_file: model.fit( x, y, @@ -1107,19 +1140,39 @@ def test_TensorBoard_weight_images(self): validation_data=(x, y), callbacks=[tb_cbk]) - self.assertEqual(summary_file.scalar_names, - {'epoch_loss', 'epoch_val_loss'}) - - # Strip Layer names as Layers are created multiple times in test. 
- hist_names = { - name[name.rfind('/') + 1:] for name in summary_file.hist_names - } - self.assertEqual(hist_names, {'bias_0', 'kernel_0'}) - - image_names = { - name[name.rfind('/') + 1:] for name in summary_file.image_names - } - self.assertEqual(image_names, {'bias_0', 'kernel_0'}) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), + _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), + }, + ) + self.assertEqual( + self._strip_layer_names(summary_file.histograms), + { + _ObservedSummary(logdir=self.train_dir, tag='bias_0'), + _ObservedSummary(logdir=self.train_dir, tag='kernel_0'), + }, + ) + self.assertEqual( + self._strip_layer_names(summary_file.images), + { + _ObservedSummary(logdir=self.train_dir, tag='bias_0'), + _ObservedSummary(logdir=self.train_dir, tag='kernel_0'), + }, + ) + + def _strip_layer_names(self, summaries): + """Deduplicate summary names modulo layer suffix. + + Args: + summaries: A `set` of `_ObservedSummary` values. + + Returns: + A new `set` of `_ObservedSummary` values with layer suffixes + removed. + """ + return {s._replace(tag=s.tag[s.tag.rfind('/') + 1:]) for s in summaries} def test_TensorBoard_invalid_argument(self): with self.assertRaisesRegexp(ValueError, 'Unrecognized arguments'):