[Orbit] The global_step variable is required and cannot be None.
We now pass all `orbit.Controller` arguments as keyword arguments and enforce, via a bare `*` in the signature, that they are keyword-only. Arguments are also reordered, with the now-required `global_step` first.

PiperOrigin-RevId: 346908658
saberkun authored and tensorflower-gardener committed Dec 11, 2020
1 parent 8de71c4 commit 5533122
Showing 3 changed files with 16 additions and 19 deletions.
2 changes: 1 addition & 1 deletion official/core/train_lib.py
@@ -100,7 +100,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
     checkpoint_manager = None
 
   controller = orbit.Controller(
-      distribution_strategy,
+      strategy=distribution_strategy,
       trainer=trainer if 'train' in mode else None,
       evaluator=trainer,
       global_step=trainer.global_step,
@@ -161,9 +161,9 @@ def run(flags_obj):
         checkpoint_interval=checkpoint_interval)
 
   resnet_controller = orbit.Controller(
-      strategy,
-      runnable,
-      runnable if not flags_obj.skip_eval else None,
+      strategy=strategy,
+      trainer=runnable,
+      evaluator=runnable if not flags_obj.skip_eval else None,
       global_step=runnable.global_step,
       steps_per_loop=steps_per_loop,
       checkpoint_manager=checkpoint_manager,
27 changes: 12 additions & 15 deletions orbit/controller.py
@@ -74,10 +74,11 @@ class Controller:

   def __init__(
       self,
-      strategy: Optional[tf.distribute.Strategy] = None,
+      *,  # Makes all args keyword only.
+      global_step: tf.Variable,
       trainer: Optional[runner.AbstractTrainer] = None,
       evaluator: Optional[runner.AbstractEvaluator] = None,
-      global_step: Optional[tf.Variable] = None,
+      strategy: Optional[tf.distribute.Strategy] = None,
       # Train related
       steps_per_loop: Optional[int] = None,
       checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
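
The bare `*` in the new signature is plain Python: every parameter after it can only be passed by name, so stale positional call sites fail fast with a `TypeError` instead of silently binding to the reordered parameters. A minimal sketch of the effect on callers, using an illustrative no-op trainer that is not part of this commit:

import orbit
import tensorflow as tf

class NoOpTrainer(orbit.AbstractTrainer):
  """Illustrative stub; a real trainer would run an actual training loop."""

  def __init__(self, step: tf.Variable):
    self.global_step = step

  def train(self, num_steps):
    self.global_step.assign_add(num_steps)

step = orbit.utils.create_global_step()
trainer = NoOpTrainer(step)

# Positional arguments are now rejected by the interpreter itself.
try:
  orbit.Controller(step, trainer)
except TypeError as err:
  print(err)  # __init__() takes 1 positional argument but 3 were given

# Every argument must be spelled out by keyword.
controller = orbit.Controller(
    global_step=step, trainer=trainer, steps_per_loop=10)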
@@ -93,13 +94,6 @@ def __init__(
         recent checkpoint during this `__init__` method.
 
     Args:
-      strategy: An instance of `tf.distribute.Strategy`. If not provided, the
-        strategy will be initialized from the current in-scope strategy using
-        `tf.distribute.get_strategy()`.
-      trainer: An instance of `orbit.AbstractTrainer`, which implements the
-        inner training loop.
-      evaluator: An instance of `orbit.AbstractEvaluator`, which implements
-        evaluation.
       global_step: An integer `tf.Variable` storing the global training step
         number. Usually this can be obtained from the `iterations` property of
         the model's optimizer (e.g. `trainer.optimizer.iterations`). In cases
@@ -109,6 +103,13 @@ def __init__(
         recommended to create the `tf.Variable` inside the distribution strategy
         scope, with `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA` (see
         also `orbit.utils.create_global_step()`).
+      trainer: An instance of `orbit.AbstractTrainer`, which implements the
+        inner training loop.
+      evaluator: An instance of `orbit.AbstractEvaluator`, which implements
+        evaluation.
+      strategy: An instance of `tf.distribute.Strategy`. If not provided, the
+        strategy will be initialized from the current in-scope strategy using
+        `tf.distribute.get_strategy()`.
       steps_per_loop: The number of steps to run in each inner loop of training
         (passed as the `num_steps` parameter of `trainer.train`).
       checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
@@ -137,7 +138,6 @@ def __init__(
     """
     if trainer is None and evaluator is None:
       raise ValueError("`trainer` and `evaluator` should not both be `None`.")
-
     if trainer is not None:
       if steps_per_loop is None:
         raise ValueError(
@@ -155,9 +155,7 @@ def __init__(
             f"`summary interval` ({summary_interval}) must be a multiple "
             f"of `steps_per_loop` ({steps_per_loop}).")
 
-    if global_step is None:
-      raise ValueError("`global_step` is required.")
-    elif not isinstance(global_step, tf.Variable):
+    if not isinstance(global_step, tf.Variable):
       raise ValueError("`global_step` must be a `tf.Variable`.")
 
     self.trainer = trainer
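
A variable that satisfies this check is typically created inside the distribution strategy scope, as the docstring above recommends. A minimal sketch, assuming the `orbit.utils.create_global_step()` helper referenced in the docstring:

import orbit
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # the default no-op strategy here
with strategy.scope():
  # Roughly equivalent by hand:
  # tf.Variable(0, dtype=tf.int64, trainable=False,
  #             aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
  global_step = orbit.utils.create_global_step()

assert isinstance(global_step, tf.Variable)  # passes the new check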
@@ -185,8 +183,7 @@ def __init__(
       self.eval_summary_manager = utils.SummaryManager(
           eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
 
-    if self.global_step is not None:
-      tf.summary.experimental.set_step(self.global_step)
+    tf.summary.experimental.set_step(self.global_step)
 
     # Restores the model if needed.
     if self.checkpoint_manager is not None:
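
With `global_step` now guaranteed to be a real `tf.Variable`, the controller can register it unconditionally as the default summary step. A short sketch of what `tf.summary.experimental.set_step` does for downstream summary calls; the log directory is illustrative:

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
tf.summary.experimental.set_step(global_step)

writer = tf.summary.create_file_writer("/tmp/summaries")  # illustrative path
with writer.as_default():
  # No explicit `step=` argument needed: the default step set above is used.
  tf.summary.scalar("loss", 0.25)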
