Skip to content

Commit

Permalink
Add profiler callback for Keras models (tensorflow#6528)
Browse files Browse the repository at this point in the history
* Add profiler callback for Keras models

* Update build stats to identify time callback by type

* Add warning message when both TensorBoard and profiler callbacks are used
  • Loading branch information
haoyuz authored Apr 5, 2019
1 parent 7467ccd commit 3f94db4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 44 deletions.
8 changes: 2 additions & 6 deletions official/resnet/keras/keras_cifar_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def run(flags_obj):
optimizer=optimizer,
metrics=['categorical_accuracy'])

time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train'])

train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
Expand All @@ -180,10 +180,6 @@ def run(flags_obj):
num_eval_steps = None
validation_data = None

callbacks = [time_callback, lr_callback]
if flags_obj.enable_tensorboard:
callbacks.append(tensorboard_callback)

history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
Expand All @@ -197,7 +193,7 @@ def run(flags_obj):
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=2)
stats = keras_common.build_stats(history, eval_output, time_callback)
stats = keras_common.build_stats(history, eval_output, callbacks)
return stats


Expand Down
113 changes: 88 additions & 25 deletions official/resnet/keras/keras_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import profiler
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)

Expand Down Expand Up @@ -78,6 +79,29 @@ def on_batch_begin(self, batch, logs=None):
'change learning rate to %s.', self.epochs, batch, lr)


class ProfilerCallback(tf.keras.callbacks.Callback):
"""Save profiles in specified step range to log directory."""

def __init__(self, log_dir, start_step, stop_step):
super(ProfilerCallback, self).__init__()
self.log_dir = log_dir
self.start_step = start_step
self.stop_step = stop_step

def on_batch_begin(self, batch, logs=None):
if batch == self.start_step:
profiler.start()
tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)

def on_batch_end(self, batch, logs=None):
if batch == self.stop_step:
results = profiler.stop()
profiler.save(self.log_dir, results)
tf.compat.v1.logging.info(
'Profiler saved profiles for steps between %s and %s to %s',
self.start_step, self.stop_step, self.log_dir)


def get_config_proto_v1():
"""Return config proto according to flag settings, or None to use default."""
config = None
Expand Down Expand Up @@ -143,27 +167,59 @@ def get_optimizer():
def get_callbacks(learning_rate_schedule_fn, num_images):
"""Returns common callbacks."""
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)

tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir)

lr_callback = LearningRateBatchScheduler(
learning_rate_schedule_fn,
batch_size=FLAGS.batch_size,
num_images=num_images)

return time_callback, tensorboard_callback, lr_callback


def build_stats(history, eval_output, time_callback):
callbacks = [time_callback, lr_callback]

if FLAGS.enable_tensorboard:
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir)
callbacks.append(tensorboard_callback)

if FLAGS.profile_steps:
profiler_callback = get_profiler_callback()
callbacks.append(profiler_callback)

return callbacks


def get_profiler_callback():
"""Validate profile_steps flag value and return profiler callback."""
profile_steps_error_message = (
'profile_steps must be a comma separated pair of positive integers, '
'specifying the first and last steps to be profiled.'
)
try:
profile_steps = [int(i) for i in FLAGS.profile_steps.split(',')]
except ValueError:
raise ValueError(profile_steps_error_message)
if len(profile_steps) != 2:
raise ValueError(profile_steps_error_message)
start_step, stop_step = profile_steps
if start_step < 0 or start_step > stop_step:
raise ValueError(profile_steps_error_message)
if FLAGS.enable_tensorboard:
tf.compat.v1.logging.warn(
'Both TensorBoard and profiler callbacks are used. Note that the '
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
'do not overlap.')

return ProfilerCallback(FLAGS.model_dir, start_step, stop_step)


def build_stats(history, eval_output, callbacks):
"""Normalizes and returns dictionary of stats.
Args:
history: Results of the training step. Supports both categorical_accuracy
and sparse_categorical_accuracy.
eval_output: Output of the eval step. Assumes first value is eval_loss and
second value is accuracy_top_1.
time_callback: Time tracking callback likely used during keras.fit.
callbacks: a list of callbacks which might include a time history callback
used during keras.fit.
Returns:
Dictionary of normalized results.
Expand All @@ -183,16 +239,20 @@ def build_stats(history, eval_output, time_callback):
elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = train_hist['sparse_categorical_accuracy'][-1].item()

if time_callback:
timestamp_log = time_callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))

if not callbacks:
return stats

# Look for the time history callback which was used during keras.fit
for callback in callbacks:
if isinstance(callback, keras_utils.TimeHistory):
timestamp_log = callback.timestamp_log
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = callback.train_finish_time
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
return stats


Expand All @@ -215,11 +275,14 @@ def define_keras_flags():
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.')
flags.DEFINE_boolean(
name='enable_e2e_xprof', default=False,
help='Save end-to-end profiling data to model dir using Xprof. Profiling '
'has an overhead on both computation and memory usage, and can generate '
'gigantic files when profiling a lot of steps.')
flags.DEFINE_string(
name='profile_steps', default=None,
help='Save profiling data to model dir at given range of steps. The '
'value must be a comma separated pair of positive integers, specifying '
'the first and last step to profile. For example, "--profile_steps=2,4" '
'triggers the profiler to process 3 steps, starting from the 2nd step. '
'Note that profiler has a non-trivial performance overhead, and the '
'output file can be gigantic if profiling many steps.')


def get_synth_input_fn(height, width, num_channels, num_classes,
Expand Down
15 changes: 2 additions & 13 deletions official/resnet/keras/keras_imagenet_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order

from tensorflow.python.eager import profiler
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model
Expand Down Expand Up @@ -177,7 +176,7 @@ def run(flags_obj):
optimizer=optimizer,
metrics=['sparse_categorical_accuracy'])

time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
callbacks = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])

train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
Expand All @@ -199,12 +198,6 @@ def run(flags_obj):
num_eval_steps = None
validation_data = None

callbacks = [time_callback, lr_callback]
if flags_obj.enable_tensorboard:
callbacks.append(tensorboard_callback)
if flags_obj.enable_e2e_xprof:
profiler.start()

history = model.fit(train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
Expand All @@ -214,16 +207,12 @@ def run(flags_obj):
validation_freq=flags_obj.epochs_between_evals,
verbose=2)

if flags_obj.enable_e2e_xprof:
results = profiler.stop()
profiler.save(flags_obj.model_dir, results)

eval_output = None
if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps,
verbose=2)
stats = keras_common.build_stats(history, eval_output, time_callback)
stats = keras_common.build_stats(history, eval_output, callbacks)
return stats


Expand Down

0 comments on commit 3f94db4

Please sign in to comment.