
Commit 90cc9ba

Migrating away from Experiment class, as it is now deprecated. Also, refactoring into a separate model library and binaries.

PiperOrigin-RevId: 192004845
pkulzc committed Apr 13, 2018
1 parent 8deba73 commit 90cc9ba
Showing 6 changed files with 576 additions and 459 deletions.
@@ -12,30 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Creates and runs `Experiment` for object detection model.
This uses the TF.learn framework to define and run an object detection model
wrapped in an `Estimator`.
Note that this module is only compatible with SSD Meta architecture at the
moment.
"""
r"""Constructs model, inputs, and training environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

import tensorflow as tf

from google.protobuf import text_format
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.python.lib.io import file_io
from object_detection import eval_util
from object_detection import inputs
from object_detection import model_hparams
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
@@ -45,15 +35,6 @@
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils

tf.flags.DEFINE_string('model_dir', None, 'Path to output model directory '
'where event and checkpoint files will be written.')
tf.flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.')
tf.flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
tf.flags.DEFINE_integer('num_eval_steps', None, 'Number of eval steps.')
FLAGS = tf.flags.FLAGS


# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
'get_configs_from_pipeline_file':
@@ -406,33 +387,18 @@ def tpu_scaffold():
return model_fn


def build_experiment_fn(train_steps, eval_steps):
"""Returns a function that creates an `Experiment`."""

def build_experiment(run_config, hparams):
"""Builds an `Experiment` from configuration and hyperparameters.
Args:
run_config: A `RunConfig`.
hparams: A `HParams`.
Returns:
An `Experiment` object.
"""
return populate_experiment(run_config, hparams, FLAGS.pipeline_config_path,
train_steps, eval_steps)

return build_experiment


def populate_experiment(run_config,
hparams,
pipeline_config_path,
train_steps=None,
eval_steps=None,
model_fn_creator=create_model_fn,
**kwargs):
"""Populates an `Experiment` object.
def create_estimator_and_inputs(run_config,
hparams,
pipeline_config_path,
train_steps=None,
eval_steps=None,
model_fn_creator=create_model_fn,
use_tpu_estimator=False,
use_tpu=False,
num_shards=1,
params=None,
**kwargs):
"""Creates `Estimator`, input functions, and steps.
Args:
run_config: A `RunConfig`.
@@ -452,18 +418,33 @@ def populate_experiment(run_config,
* Returns:
`model_fn` for `Estimator`.
use_tpu_estimator: Whether a `TPUEstimator` should be returned. If False,
an `Estimator` will be returned.
use_tpu: Boolean, whether training and evaluation should run on TPU. Only
used if `use_tpu_estimator` is True.
num_shards: Number of shards (TPU cores). Only used if `use_tpu_estimator`
is True.
params: Parameter dictionary passed from the estimator. Only used if
`use_tpu_estimator` is True.
**kwargs: Additional keyword arguments for configuration override.
Returns:
An `Experiment` that defines all aspects of training, evaluation, and
export.
A dictionary with the following fields:
'estimator': An `Estimator` or `TPUEstimator`.
'train_input_fn': A training input function.
'eval_input_fn': An evaluation input function.
'predict_input_fn': A prediction input function.
'train_steps': Number of training steps. Either directly from input or from
configuration.
'eval_steps': Number of evaluation steps. Either directly from input or from
configuration.
"""
get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
'get_configs_from_pipeline_file']
create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
'create_pipeline_proto_from_configs']
merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
'merge_external_params_with_configs']
create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
'create_pipeline_proto_from_configs']
create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
@@ -481,16 +462,16 @@ def populate_experiment(run_config,
eval_config = configs['eval_config']
eval_input_config = configs['eval_input_config']

if train_steps is None and train_config.num_steps:
train_steps = train_config.num_steps
if train_steps is None:
train_steps = configs['train_config'].num_steps

if eval_steps is None and eval_config.num_examples:
eval_steps = eval_config.num_examples
if eval_steps is None:
eval_steps = configs['eval_config'].num_examples

detection_model_fn = functools.partial(
model_builder.build, model_config=model_config)

# Create the input functions for TRAIN/EVAL.
# Create the input functions for TRAIN/EVAL/PREDICT.
train_input_fn = create_train_input_fn(
train_config=train_config,
train_input_config=train_input_config,
@@ -499,51 +480,149 @@
eval_config=eval_config,
eval_input_config=eval_input_config,
model_config=model_config)
predict_input_fn = create_predict_input_fn(model_config=model_config)

model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
if use_tpu_estimator:
estimator = tpu_estimator.TPUEstimator(
model_fn=model_fn,
train_batch_size=train_config.batch_size,
# For each core, only batch size 1 is supported for eval.
eval_batch_size=num_shards * 1 if use_tpu else 1,
use_tpu=use_tpu,
config=run_config,
params=params if params else {})
else:
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

export_strategies = [
tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
serving_input_fn=create_predict_input_fn(
model_config=model_config))
]

estimator = tf.estimator.Estimator(
model_fn=model_fn_creator(detection_model_fn, configs, hparams),
config=run_config)

# Write the as-run pipeline config to disk.
if run_config.is_chief:
# Store the final pipeline config for traceability.
pipeline_config_final = create_pipeline_proto_from_configs(
configs)
if not file_io.file_exists(estimator.model_dir):
file_io.recursive_create_dir(estimator.model_dir)
pipeline_config_final_path = os.path.join(estimator.model_dir,
'pipeline.config')
config_text = text_format.MessageToString(pipeline_config_final)
with tf.gfile.Open(pipeline_config_final_path, 'wb') as f:
tf.logging.info('Writing as-run pipeline config file to %s',
pipeline_config_final_path)
f.write(config_text)
config_util.save_pipeline_config(pipeline_config_final, estimator.model_dir)

return tf.contrib.learn.Experiment(
return dict(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
train_steps=train_steps,
eval_steps=eval_steps,
export_strategies=export_strategies,
eval_delay_secs=120,)
eval_steps=eval_steps)
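
# --- Illustrative usage sketch (not part of this commit). The model
# directory, pipeline config path, and step counts are placeholder values,
# and this file is assumed to be importable as object_detection.model_lib.
#
# import tensorflow as tf
# from object_detection import model_hparams
# from object_detection import model_lib  # assumed module name for this file
#
# run_config = tf.estimator.RunConfig(model_dir='/tmp/od_model')
# train_and_eval_dict = model_lib.create_estimator_and_inputs(
#     run_config=run_config,
#     hparams=model_hparams.create_hparams(),
#     pipeline_config_path='/tmp/pipeline.config',
#     train_steps=10000,
#     eval_steps=100)
# estimator = train_and_eval_dict['estimator']
# train_input_fn = train_and_eval_dict['train_input_fn']
# eval_input_fn = train_and_eval_dict['eval_input_fn']
# predict_input_fn = train_and_eval_dict['predict_input_fn']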


def create_train_and_eval_specs(train_input_fn,
eval_input_fn,
predict_input_fn,
train_steps,
eval_steps,
eval_on_train_data=False,
final_exporter_name='Servo',
eval_spec_name='eval'):
"""Creates a `TrainSpec` and `EvalSpec`s.
Args:
train_input_fn: Function that produces features and labels on train data.
eval_input_fn: Function that produces features and labels on eval data.
predict_input_fn: Function that produces features for inference.
train_steps: Number of training steps.
eval_steps: Number of eval steps.
eval_on_train_data: Whether to evaluate model on training data. Default is
False.
final_exporter_name: String name given to `FinalExporter`.
eval_spec_name: String name given to main `EvalSpec`.
Returns:
Tuple of `TrainSpec` and list of `EvalSpecs`. The first `EvalSpec` is for
evaluation data. If `eval_on_train_data` is True, the second `EvalSpec` in
the list will correspond to training data.
"""

exporter = tf.estimator.FinalExporter(
name=final_exporter_name, serving_input_receiver_fn=predict_input_fn)

train_spec = tf.estimator.TrainSpec(
input_fn=train_input_fn, max_steps=train_steps)

eval_specs = [
tf.estimator.EvalSpec(
name=eval_spec_name,
input_fn=eval_input_fn,
steps=eval_steps,
exporters=exporter)
]

if eval_on_train_data:
eval_specs.append(
tf.estimator.EvalSpec(
name='eval_on_train', input_fn=train_input_fn, steps=eval_steps))

return train_spec, eval_specs
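
# --- Continuing the sketch above (not part of this commit): the returned
# specs are handed to tf.estimator.train_and_evaluate, the replacement for
# the deprecated Experiment. train_and_evaluate takes a single EvalSpec, so
# only the first entry of eval_specs is used here.
#
# train_spec, eval_specs = model_lib.create_train_and_eval_specs(
#     train_input_fn=train_input_fn,
#     eval_input_fn=eval_input_fn,
#     predict_input_fn=predict_input_fn,
#     train_steps=train_and_eval_dict['train_steps'],
#     eval_steps=train_and_eval_dict['eval_steps'],
#     eval_on_train_data=False)
#
# tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])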


def main(unused_argv):
tf.flags.mark_flag_as_required('model_dir')
tf.flags.mark_flag_as_required('pipeline_config_path')
config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir)
learn_runner.run(
experiment_fn=build_experiment_fn(FLAGS.num_train_steps,
FLAGS.num_eval_steps),
run_config=config,
hparams=model_hparams.create_hparams())

if __name__ == '__main__':
tf.app.run()


def populate_experiment(run_config,
hparams,
pipeline_config_path,
train_steps=None,
eval_steps=None,
model_fn_creator=create_model_fn,
**kwargs):
"""Populates an `Experiment` object.
EXPERIMENT CLASS IS DEPRECATED. Please switch to
tf.estimator.train_and_evaluate. As an example, see model_main.py.
Args:
run_config: A `RunConfig`.
hparams: A `HParams`.
pipeline_config_path: A path to a pipeline config file.
train_steps: Number of training steps. If None, the number of training steps
is set from the `TrainConfig` proto.
eval_steps: Number of evaluation steps per evaluation cycle. If None, the
number of evaluation steps is set from the `EvalConfig` proto.
model_fn_creator: A function that creates a `model_fn` for `Estimator`.
Follows the signature:
* Args:
* `detection_model_fn`: Function that returns `DetectionModel` instance.
* `configs`: Dictionary of pipeline config objects.
* `hparams`: `HParams` object.
* Returns:
`model_fn` for `Estimator`.
**kwargs: Additional keyword arguments for configuration override.
Returns:
An `Experiment` that defines all aspects of training, evaluation, and
export.
"""
tf.logging.warning('Experiment is being deprecated. Please use '
'tf.estimator.train_and_evaluate(). See model_main.py for '
'an example.')
train_and_eval_dict = create_estimator_and_inputs(
run_config,
hparams,
pipeline_config_path,
train_steps=train_steps,
eval_steps=eval_steps,
model_fn_creator=model_fn_creator,
**kwargs)
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']

export_strategies = [
tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
serving_input_fn=predict_input_fn)
]

return tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps,
export_strategies=export_strategies,
eval_delay_secs=120,)
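
# --- Hypothetical sketch of the train_and_evaluate entry point that the
# deprecation warning refers to. The real model_main.py is not part of this
# diff, so the flag names and structure below are illustrative only; step
# counts are left unset so they default to the TrainConfig/EvalConfig protos.
#
# flags = tf.flags
# flags.DEFINE_string('model_dir', None, 'Path to output model directory.')
# flags.DEFINE_string('pipeline_config_path', None,
#                     'Path to pipeline config file.')
# FLAGS = flags.FLAGS
#
#
# def main(unused_argv):
#   run_config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
#   train_and_eval_dict = create_estimator_and_inputs(
#       run_config,
#       model_hparams.create_hparams(),
#       FLAGS.pipeline_config_path)
#   train_spec, eval_specs = create_train_and_eval_specs(
#       train_and_eval_dict['train_input_fn'],
#       train_and_eval_dict['eval_input_fn'],
#       train_and_eval_dict['predict_input_fn'],
#       train_and_eval_dict['train_steps'],
#       train_and_eval_dict['eval_steps'])
#   tf.estimator.train_and_evaluate(train_and_eval_dict['estimator'],
#                                   train_spec, eval_specs[0])
#
#
# if __name__ == '__main__':
#   tf.app.run()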
