diff --git a/tutorials/image/cifar10_estimator/cifar10_main.py b/tutorials/image/cifar10_estimator/cifar10_main.py index 5a01b4c82ee..8bd94e203ef 100644 --- a/tutorials/image/cifar10_estimator/cifar10_main.py +++ b/tutorials/image/cifar10_estimator/cifar10_main.py @@ -344,11 +344,10 @@ def _experiment_fn(run_config, hparams): train_steps = hparams.train_steps eval_steps = num_eval_examples // hparams.eval_batch_size - - num_workers = run_config.num_worker_replicas - + classifier = tf.estimator.Estimator( - model_fn=get_model_fn(num_gpus, variable_strategy, num_workers), + model_fn=get_model_fn(num_gpus, variable_strategy, + run_config.num_worker_replicas or 1), config=run_config, params=hparams)