Skip to content

Commit

Permalink
fix flags to force_v2_in_keras_compile (tensorflow#7287)
Browse files Browse the repository at this point in the history
  • Loading branch information
tfboyd authored Jul 24, 2019
1 parent 829190e commit d09994b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions official/resnet/keras/keras_imagenet_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def benchmark_1_gpu_no_dist_strat_run_eagerly(self):
FLAGS.batch_size = 64
self._run_and_report_benchmark()

def benchmark_1_gpu_force_dist_strat_run_eagerly(self):
"""No dist strat but forced ds tf.compile path and force eager."""
def benchmark_1_gpu_no_dist_strat_force_v2_run_eagerly(self):
"""Forced v2 execution in tf.compile path and force eager."""
self._setup()

FLAGS.num_gpus = 1
Expand All @@ -274,11 +274,11 @@ def benchmark_1_gpu_force_dist_strat_run_eagerly(self):
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_force_dist_strat_run_eagerly')
FLAGS.batch_size = 64
FLAGS.force_run_distributed = True
FLAGS.force_v2_in_keras_compile = True
self._run_and_report_benchmark()

def benchmark_1_gpu_force_dist_strat(self):
"""No dist strat but forced ds tf.compile path."""
def benchmark_1_gpu_no_dist_strat_force_v2(self):
"""No dist strat but forced v2 execution tf.compile path."""
self._setup()

FLAGS.num_gpus = 1
Expand All @@ -287,7 +287,7 @@ def benchmark_1_gpu_force_dist_strat(self):
FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_force_dist_strat')
FLAGS.batch_size = 128
FLAGS.force_run_distributed = True
FLAGS.force_v2_in_keras_compile = True
self._run_and_report_benchmark()

def benchmark_1_gpu_no_dist_strat_run_eagerly_fp16(self):
Expand Down

0 comments on commit d09994b

Please sign in to comment.