
Commit

Clean up all matches to the TPUStrategy to easily trace things we need to fix.

PiperOrigin-RevId: 222188691
Sourabh Bajaj authored and tensorflower-gardener committed Nov 20, 2018
1 parent 65ab9a8 commit 3bb55e9
Showing 3 changed files with 17 additions and 13 deletions.
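For quick orientation, here is a minimal, hedged sketch of the pattern this commit consolidates: instead of each call site comparing the strategy's class name to the string 'TPUStrategy', the check lives in a single is_tpu_strategy helper, so TPU-specific branches can be traced (and eventually removed) from one place. The configure_before and configure_after functions below are illustrative placeholders, not TensorFlow APIs.

# Before: every call site re-implements the check via a class-name string match.
def configure_before(distribution_strategy):
  if type(distribution_strategy).__name__ == 'TPUStrategy':
    pass  # TPU-specific setup would go here

# After: one helper owns the check; call sites simply ask a question.
def is_tpu_strategy(strategy):
  """Returns True if `strategy` looks like a TPUStrategy (matched by class name)."""
  return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'

def configure_after(distribution_strategy):
  if is_tpu_strategy(distribution_strategy):
    pass  # TPU-specific setup would go here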
12 changes: 9 additions & 3 deletions tensorflow/python/keras/engine/distributed_training_utils.py
@@ -333,7 +333,7 @@ def configure_and_create_session(distribution_strategy):
# TODO(priyag): Throw error if a session already exists.
session_config = K.get_default_session_config()

-if type(distribution_strategy).__name__ == 'TPUStrategy':
+if is_tpu_strategy(distribution_strategy):
# TODO(priyag, yuefengz): Remove this workaround when Distribute
# Coordinator is integrated with keras and we can create a session from
# there.
@@ -379,7 +379,7 @@ def validate_inputs(x, y, distribution_strategy):
'Iterator. You must pass a `tf.data.Dataset` object or a '
'numpy array as input.')

-if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+if is_tpu_strategy(distribution_strategy):
for i in [x, y]:
if isinstance(i, dataset_ops.Dataset):
shapes = nest.flatten(i.output_shapes)
@@ -401,6 +401,12 @@ def global_batch_size_supported(distribution_strategy):
return strategy_name in ('TPUStrategy', 'CoreMirroredStrategy')


+# TODO(sourabhbajaj): Remove this once we use the same API for all strategies.
+def is_tpu_strategy(strategy):
+"""We're executing TPU Strategy."""
+return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'


def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
is_training=False):
"""Calculate the number of batches and steps/steps_per_epoch.
@@ -504,7 +510,7 @@ def get_cpu_device(distribution_strategy):
NotImplementedError: We currently don't support copying numpy data to
multiple hosts in the case of Cloud TPU pods.
"""
-if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+if is_tpu_strategy(distribution_strategy):
if distribution_strategy.extended.num_hosts > 1:
raise NotImplementedError('TPUDistributionStrategy does not '
'support numpy inputs when running on Cloud'
12 changes: 8 additions & 4 deletions tensorflow/python/keras/engine/training.py
@@ -986,7 +986,8 @@ def _distribution_standardize_user_data(self,
'when using DistributionStrategy.')

if (sample_weight is not None and sample_weight.all() and
-self._distribution_strategy.__class__.__name__ == 'TPUStrategy'):
+distributed_training_utils.is_tpu_strategy(
+self._distribution_strategy)):
raise NotImplementedError('`sample_weight` is currently not supported '
'when using TPUStrategy.')

@@ -1755,7 +1756,8 @@ def fit(self,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)
-elif training_distributed.should_run_experimental_loop(self):
+elif distributed_training_utils.is_tpu_strategy(
+self._distribution_strategy):
return training_distributed.experimental_fit_loop(
self,
x,
@@ -1916,7 +1918,8 @@ def evaluate(self,
batch_size=batch_size,
verbose=verbose,
steps=steps)
-elif training_distributed.should_run_experimental_loop(self):
+elif distributed_training_utils.is_tpu_strategy(
+self._distribution_strategy):
return training_distributed.experimental_test_loop(
self, iterator=x, verbose=verbose, steps=steps)
elif isinstance(x, iterator_ops.EagerIterator):
@@ -2026,7 +2029,8 @@ def predict(self,
if self.run_eagerly:
return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
-elif training_distributed.should_run_experimental_loop(self):
+elif distributed_training_utils.is_tpu_strategy(
+self._distribution_strategy):
return training_distributed.experimental_predict_loop(
self, x, verbose=verbose, steps=steps)
elif isinstance(x, iterator_ops.EagerIterator):
6 changes: 0 additions & 6 deletions tensorflow/python/keras/engine/training_distributed.py
@@ -674,9 +674,3 @@ def _per_device_aggregate_batch(batch_outs, model, mode):
total_batch_outs.append(np.concatenate(nest.flatten(nested_outs)))
return total_batch_outs
return batch_outs


-def should_run_experimental_loop(model):
-"""Whether to run the experimental loops in this file."""
-return (hasattr(model, '_distribution_strategy') and
-model._distribution_strategy.__class__.__name__ == 'TPUStrategy')
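The should_run_experimental_loop helper removed above hung the same class-name comparison off the model object; its call sites in training.py now go through distributed_training_utils.is_tpu_strategy instead. Note that the check remains a plain class-name string match, and the TODO in the new helper flags it for removal once all strategies share one API. Below is a self-contained sketch, using stand-in classes rather than TensorFlow ones, of how that name-based match behaves:

def is_tpu_strategy(strategy):
  """Same logic as the new helper: match on the class name string."""
  return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'

class TPUStrategy(object):  # stand-in whose class name matches exactly
  pass

class ExperimentalTPUStrategy(TPUStrategy):  # subclass with a different name
  pass

print(is_tpu_strategy(TPUStrategy()))              # True: exact name match
print(is_tpu_strategy(ExperimentalTPUStrategy()))  # False: name check ignores inheritance
print(is_tpu_strategy(None))                       # False: guarded by the None check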
