Skip to content

Commit

Permalink
Some improvements and bug fixes to Controller:
Browse files Browse the repository at this point in the history
1. Fix a bug that checkpoint will be saved after every training loop.
2. Only create the training and eval summaries writers if the corresponding `train_fn` and `eval_fn` are passed.
3. Flush the summary writers after training and eval finish.
4. Add a Controller test.

Also make sure there is no evaluation happening in Resnet CTL example if `skip_eval=True`.

PiperOrigin-RevId: 301489305
  • Loading branch information
rxsang authored and tensorflower-gardener committed Mar 18, 2020
1 parent 101f1f0 commit febaae9
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 36 deletions.
69 changes: 37 additions & 32 deletions official/staging/training/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ def __init__(
eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`.
eval_steps: Number of steps to run evaluation.
eval_interval: Step interval for evaluation. If None, will skip
evaluation. Note that evaluation only happens outside the training loop,
which the loop iteration is specify by `steps_per_loop` parameter.
eval_interval: Step interval for evaluation. If None, will skip evaluation
in the middle of training. Note that evaluation only happens outside the
training loop, which the loop iteration is specify by `steps_per_loop`
parameter.
Raises:
ValueError: If both `train_fn` and `eval_fn` are None.
Expand Down Expand Up @@ -111,35 +112,41 @@ def __init__(
self.train_fn = train_fn
self.eval_fn = eval_fn
self.global_step = global_step

self.train_steps = train_steps

self.steps_per_loop = steps_per_loop

self.summary_dir = summary_dir or checkpoint_manager.directory
self.checkpoint_manager = checkpoint_manager

self.summary_interval = summary_interval
summary_writer = tf.summary.create_file_writer(
self.summary_dir) if self.summary_interval else None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self.summary_manager = utils.SummaryManager(
summary_writer,
tf.summary.scalar,
global_step=self.global_step,
summary_interval=self.summary_interval)
if self.train_fn is not None:
self.train_steps = train_steps
self.steps_per_loop = steps_per_loop
self.summary_dir = summary_dir or checkpoint_manager.directory

self.summary_interval = summary_interval
summary_writer = tf.summary.create_file_writer(
self.summary_dir) if self.summary_interval else None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self.summary_manager = utils.SummaryManager(
summary_writer,
tf.summary.scalar,
global_step=self.global_step,
summary_interval=self.summary_interval)

if self.eval_fn is not None:
eval_summary_dir = eval_summary_dir or self.summary_dir
eval_summary_writer = tf.summary.create_file_writer(
eval_summary_dir) if eval_summary_dir else None
self.eval_summary_manager = utils.SummaryManager(
eval_summary_writer, tf.summary.scalar, global_step=self.global_step)

self.eval_steps = eval_steps
self.eval_interval = eval_interval

# Create and initialize the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy())

if self.global_step:
tf.summary.experimental.set_step(self.global_step)

self.eval_summary_dir = eval_summary_dir or self.summary_dir
eval_summary_writer = tf.summary.create_file_writer(self.eval_summary_dir)
self.eval_summary_manager = utils.SummaryManager(
eval_summary_writer, tf.summary.scalar, global_step=self.global_step)

self.eval_steps = eval_steps
self.eval_interval = eval_interval

# Restore Model if needed.
if self.checkpoint_manager is not None:
model_restored = self._restore_model()
Expand All @@ -150,10 +157,6 @@ def __init__(
checkpoint_number=self.global_step)
logging.info("Saved checkpoins in %s", ckpt_path)

# Create and initialize the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy())

def _restore_model(self, checkpoint_path=None):
"""Restore or initialize the model.
Expand Down Expand Up @@ -186,11 +189,12 @@ def _evaluate_once(self, current_step):
self._log_info(info)

self.eval_summary_manager.write_summaries(eval_outputs)
self.eval_summary_manager.flush()

def _maybe_save_checkpoints(self, current_step, force_trigger=False):
if self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=force_trigger)
checkpoint_number=current_step, check_interval=not force_trigger)
if ckpt_path is not None:
logging.info("Saved checkpoins in %s", ckpt_path)

Expand Down Expand Up @@ -265,6 +269,7 @@ def train(self, evaluate=True):
self._maybe_evaluate(current_step)

self.summary_manager.write_summaries(train_outputs, always_write=True)
self.summary_manager.flush()
self._maybe_save_checkpoints(current_step, force_trigger=True)
if evaluate:
self._maybe_evaluate(current_step, force_trigger=True)
Expand Down
262 changes: 262 additions & 0 deletions official/staging/training/controller_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.staging.training.controller."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.staging.training import controller
from official.staging.training import standard_runnable


def all_strategy_combinations():
"""Gets combinations of distribution strategies."""
return combinations.combine(
strategy=[
strategy_combinations.one_device_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
mode="eager",
)


def create_model():
x = tf.keras.layers.Input(shape=(3,), name="input")
y = tf.keras.layers.Dense(4, name="dense")(x)
model = tf.keras.Model(x, y)
return model


def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file."""
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
tf.compat.v1.logging.error(event)
yield event.summary


def check_eventfile_for_keyword(keyword, summary_dir):
"""Checks event files for the keyword."""
return any(summaries_with_matching_keyword(keyword, summary_dir))


def dataset_fn(ctx):
del ctx
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10, drop_remainder=True)
return dataset


class TestRunnable(standard_runnable.StandardTrainable,
standard_runnable.StandardEvaluable):
"""Implements the training and evaluation APIs for the test model."""

def __init__(self):
standard_runnable.StandardTrainable.__init__(self)
standard_runnable.StandardEvaluable.__init__(self)
self.strategy = tf.distribute.get_strategy()
self.model = create_model()
self.optimizer = tf.keras.optimizers.RMSprop()
self.global_step = self.optimizer.iterations
self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)

def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)

def train_step(self, iterator):

def _replicated_step(inputs):
"""Replicated training step."""
inputs, targets = inputs
with tf.GradientTape() as tape:
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
grads = tape.gradient(loss, self.model.variables)
self.optimizer.apply_gradients(zip(grads, self.model.variables))
self.train_loss.update_state(loss)

self.strategy.run(_replicated_step, args=(next(iterator),))

def train_loop_end(self):
return {
"loss": self.train_loss.result(),
}

def build_eval_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)

def eval_begin(self):
self.eval_loss.reset_states()

def eval_step(self, iterator):

def _replicated_step(inputs):
"""Replicated evaluation step."""
inputs, targets = inputs
outputs = self.model(inputs)
loss = tf.math.reduce_sum(outputs - targets)
self.eval_loss.update_state(loss)

self.strategy.run(_replicated_step, args=(next(iterator),))

def eval_end(self):
return {
"eval_loss": self.eval_loss.result(),
}


class ControllerTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super(ControllerTest, self).setUp()
self.model_dir = self.get_temp_dir()

@combinations.generate(all_strategy_combinations())
def test_train_and_evaluate(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()

checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.train(evaluate=True)

# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))

@combinations.generate(all_strategy_combinations())
def test_train_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()

checkpoint = tf.train.Checkpoint(
model=test_runnable.model, optimizer=test_runnable.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step,
checkpoint_interval=10)
test_controller = controller.Controller(
strategy=strategy,
train_fn=test_runnable.train,
global_step=test_runnable.global_step,
train_steps=10,
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(evaluate=False)

# Checkpoints are saved.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

@combinations.generate(all_strategy_combinations())
def test_evaluate_only(self, strategy):
with strategy.scope():
test_runnable = TestRunnable()

checkpoint = tf.train.Checkpoint(model=test_runnable.model)
checkpoint.save(os.path.join(self.model_dir, "ckpt"))

checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runnable.global_step)
test_controller = controller.Controller(
strategy=strategy,
eval_fn=test_runnable.evaluate,
global_step=test_runnable.global_step,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
eval_steps=2,
eval_interval=5)
test_controller.evaluate()

# Only eval summaries are written
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))


if __name__ == "__main__":
tf.test.main()
5 changes: 5 additions & 0 deletions official/staging/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def summary_writer(self):
"""Returns the underlying summary writer."""
return self._summary_writer

def flush(self):
"""Flush the underlying summary writer."""
if self._enabled:
tf.summary.flush(self._summary_writer)

def write_summaries(self, items, always_write=True):
"""Write a bulk of summaries.
Expand Down
Loading

0 comments on commit febaae9

Please sign in to comment.