Implement gather/all_gather for DefaultStrategy and OneDeviceStrategy.
PiperOrigin-RevId: 335796948
Change-Id: I0fdbf629ae599f7d88967c23d3bb3c5dd1b0b0f6
w-xinyi authored and tensorflower-gardener committed Oct 7, 2020
1 parent ee4be8e commit 6876c21
Showing 3 changed files with 43 additions and 15 deletions.
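Illustration (not part of the diff): a minimal sketch of what this change enables, assuming a TF build where the public Strategy.gather and ReplicaContext.all_gather wrappers dispatch to the private hooks added below. With the default strategy there is a single replica, so both collectives reduce to an identity on the input value.

import tensorflow as tf

strategy = tf.distribute.get_strategy()   # default strategy outside any scope
x = tf.constant([[1.0, 2.0]])

# Cross-replica gather: one replica, so there is nothing to concatenate.
print(strategy.gather(x, axis=0))         # [[1., 2.]]

# In-replica all_gather from inside strategy.run.
def step(value):
  ctx = tf.distribute.get_replica_context()
  return ctx.all_gather(value, axis=0)

print(strategy.run(step, args=(x,)))      # [[1., 2.]]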
6 changes: 5 additions & 1 deletion tensorflow/python/distribute/distribute_lib.py
@@ -3193,7 +3193,7 @@ def batch_all_gather(strategy, *value_flat):
def grad_wrapper(*xs):
ys = self.merge_call(batch_all_gather, args=xs)
# The gradient of an all-gather is itself an all-gather.
return ys, lambda *dy_s: self.all_gather(dy_s, axis)
return ys, lambda *dy_s: self._all_gather(dy_s, axis)

return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))

@@ -3346,6 +3346,10 @@ def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations, experimental_hints
return value

def _gather_to_implementation(self, value, destinations, axis, experimental_hints):
del destinations, axis, experimental_hints
return value

def _update(self, var, fn, args, kwargs, group):
# The implementations of _update() and _update_non_slot() are identical
# except _update() passes `var` as the first argument to `fn()`.
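Illustration (not part of the diff): the one-line change above keeps grad_wrapper on the private _all_gather path and, as the comment notes, the backward pass of an all-gather is implemented as another all-gather of the upstream gradients. A hedged sketch of differentiating through ReplicaContext.all_gather under the default strategy, where both the forward and backward collectives are identities:

import tensorflow as tf

strategy = tf.distribute.get_strategy()   # default strategy: one replica

def step(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.distribute.get_replica_context().all_gather(x, axis=0)
    loss = tf.reduce_sum(y)
  # The gradient flows through the all-gather-based grad_wrapper above.
  return tape.gradient(loss, x)

grads = strategy.run(step, args=(tf.constant([[1.0, 2.0]]),))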
5 changes: 5 additions & 0 deletions tensorflow/python/distribute/one_device_strategy.py
@@ -383,6 +383,11 @@ def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
del reduce_op, destinations, experimental_hints
return value

def _gather_to_implementation(self, value, destinations, axis,
experimental_hints):
del destinations, axis, experimental_hints
return value

def _update(self, var, fn, args, kwargs, group):
# The implementations of _update() and _update_non_slot() are identical
# except _update() passes `var` as the first argument to `fn()`.
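Illustration (not part of the diff): OneDeviceStrategy gets the same pass-through hook. A hedged sketch of the property the tests below rely on, namely that with a single local replica experimental_local_results returns exactly one gathered value:

import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")

def replica_fn(value):
  return tf.distribute.get_replica_context().all_gather(value, axis=0)

local = strategy.experimental_local_results(
    strategy.run(replica_fn, args=(tf.constant([[1.0, 2.0]]),)))
assert len(local) == 1   # one local replica -> one local result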
47 changes: 33 additions & 14 deletions tensorflow/python/distribute/strategy_common_test.py
@@ -115,6 +115,9 @@ def fn():
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
@@ -238,8 +241,13 @@ def run():

def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
"""Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error."""
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')

if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replica only.')

def value_fn(ctx):
return constant_op.constant(
1, shape=(1, ctx.replica_id_in_sync_group + 1))
@@ -284,7 +292,8 @@ def testGatherRaiseDifferentRank(self, strategy, pure_eager):
"""Different rank: [1,], [1, 2] -> raise error."""
if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replicas.')
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')
def value_fn(ctx):
return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))
@@ -308,6 +317,9 @@ def run():
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
@@ -334,7 +346,7 @@ def replica_fn(per_replica_value):

all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
expect = array_ops.concat(all_value, axis=axis)
expected_result = [expect] * _get_num_devices_per_worker(strategy)
expected_result = [expect] * _get_num_replicas_per_client(strategy)

self.assertAllClose(result, expected_result)

@@ -409,7 +421,7 @@ def run(value):
if not pure_eager:
run = def_function.function(run)

expected_result = [expect] * _get_num_devices_per_worker(strategy)
expected_result = [expect] * _get_num_replicas_per_client(strategy)
result = strategy.experimental_local_results(
strategy.run(run, args=(per_replica_value,)))
self.assertAllEqual(result, expected_result)
@@ -443,7 +455,7 @@ def run(value):
if not pure_eager:
run = def_function.function(run)

expected_result = [expect] * _get_num_devices_per_worker(strategy)
expected_result = [expect] * _get_num_replicas_per_client(strategy)
result = strategy.experimental_local_results(
strategy.run(run, args=(per_replica_value,)))
self.assertAllEqual(result, expected_result)
@@ -469,7 +481,7 @@ def value_fn(ctx):
# 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))
raise ValueError('Add your own expect according to num_replicas_in sync')

expected_per_replica_1 = [expect_1] * _get_num_devices_per_worker(strategy)
expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)

value_2 = constant_op.constant([[[1, 2], [1, 2]]])

@@ -485,7 +497,7 @@ def value_fn(ctx):
# [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)
raise ValueError('Add your own expect according to num_replicas_in sync')

expected_per_replica_2 = [expect_2] * _get_num_devices_per_worker(strategy)
expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)

def run(value):
value_1 = array_ops.identity(value)
@@ -517,7 +529,7 @@ def run():

all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
expect = array_ops.concat(all_value, axis=axis)
expected_per_replica = [expect] * _get_num_devices_per_worker(strategy)
expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)

result = strategy.run(run)
for gathered_result in result:
@@ -527,9 +539,13 @@ def run():

def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
"""Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')

if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replica only.')

def value_fn(ctx):
return constant_op.constant(
1, shape=(1, ctx.replica_id_in_sync_group + 1))
@@ -571,7 +587,8 @@ def testAllGatherRaiseDifferentRank(self, strategy, pure_eager):
"""Different rank: [1,], [1, 2] -> raise error."""
if strategy.num_replicas_in_sync <= 1:
self.skipTest('Test for more than 1 replicas.')
if _get_num_devices_per_worker(strategy) > 1:
if isinstance(strategy, CollectiveAllReduceStrategy
) and _get_num_replicas_per_client(strategy) > 1:
self.skipTest('b/167331966')
def value_fn(ctx):
return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))
Expand Down Expand Up @@ -601,10 +618,12 @@ def _make_indexed_slices(values, indices, dense_shape):
return tensor


def _get_num_devices_per_worker(strategy):
"""Returns the number of workers in the current cluster for multi-worker."""
resolver = strategy.cluster_resolver
return max(nest.flatten(resolver.num_accelerators())[0], 1)
def _get_num_replicas_per_client(strategy):
if isinstance(strategy, CollectiveAllReduceStrategy):
resolver = strategy.cluster_resolver
return max(nest.flatten(resolver.num_accelerators())[0], 1)
else:
return strategy.num_replicas_in_sync


@combinations.generate(
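Note (not part of the diff): the renamed helper _get_num_replicas_per_client reflects what the expected local results actually count, namely replicas on the current client rather than devices per worker. A hedged illustration of the values it would return for the strategies in the combinations above; the exact numbers depend on the test cluster:

# default_strategy / one_device_strategy(_gpu)  -> strategy.num_replicas_in_sync == 1
# multi_worker_mirrored_2x1_cpu                 -> max(num_accelerators, 1) == 1
# multi_worker_mirrored_2x1_gpu                 -> one GPU per worker  -> 1
# multi_worker_mirrored_2x2_gpu                 -> two GPUs per worker -> 2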
