Install _distributed_container only at variable creation
With a following change, we're going to create a new DistributedVariable as the
return value of assign*() and scatter*(). Installing _distributed_container
multiple times would be messy.

PiperOrigin-RevId: 307753201
Change-Id: I3c87abc301ea32b0169034324a108d6967229889
crccw authored and tensorflower-gardener committed Apr 22, 2020
1 parent 4137c32 commit 79abfee
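
For readers skimming the diff below, here is a minimal sketch of the change in miniature. It uses stand-in classes rather than the real DistributedVariable or the mirrored-variable creator, and only illustrates the move from a weakref installed in __init__ to a strong reference installed once by the creator:

import weakref


class FakeValue(object):
  """Stand-in for a single per-replica variable."""


class FakeContainer(object):
  """Stand-in for DistributedVariable; it just holds its per-replica values."""

  def __init__(self, values):
    self.values = values


def install_old_style(container):
  # Old behavior (removed in this commit): a weakref installed from inside
  # DistributedVariable.__init__; callers dereference it by calling
  # v._distributed_container().
  for v in container.values:
    v._distributed_container = weakref.ref(container)


def install_new_style(container):
  # New behavior: the creator installs a strong reference exactly once, right
  # after creating the container; callers read v._distributed_container
  # directly.
  for v in container.values:
    v._distributed_container = container

In the real diff, install_new_style corresponds to the loop added to create_mirrored_variable, and install_old_style to the loop removed from DistributedVariable.__init__.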
Showing 4 changed files with 28 additions and 28 deletions.
17 changes: 10 additions & 7 deletions tensorflow/python/distribute/values.py
@@ -433,10 +433,6 @@ def __init__(self, strategy, values, aggregation):
self._aggregation = aggregation
super(DistributedVariable, self).__init__(values)
self._common_name = self._primary.name.split(":")[0]
# Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle.
for v in values:
v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
# tf.keras keeps track of variables initialized using this attribute. When
# tf.keras gets the default session, it initializes all uninitialized vars.
# We need to make _keras_initialized a member of DistributedVariable because
@@ -774,6 +770,13 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
value_list = real_mirrored_creator(**kwargs)
var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
result = var_cls(strategy, value_list, aggregation)
# Install the created DistributedVariable as the _distributed_container property
# of the underlying variables, to make it easy to map back to the container.
for v in result.values:
# Hold a strong reference to prevent the container from being GC-ed. After
# v = v.assign(), the user code may no longer hold references to the
# original container, since v.assign() returns a new DistributedVariable.
v._distributed_container = result # pylint: disable=protected-access

# Add the wrapped variable to the requested collections.
# The handling of eager mode and the global step matches
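
The GC comment above is the crux of the change: with the old weakref, once user code rebinds v to the result of an assign-style op and drops its last strong reference to the original container, the container can be collected. A purely illustrative snippet using plain weakref (no TensorFlow involved) shows that failure mode:

import weakref


class Value(object):
  """Illustrative stand-in for a per-replica variable."""


class Container(object):
  """Illustrative stand-in for DistributedVariable."""

  def __init__(self, value):
    self.value = value


v = Value()
container = Container(v)
v._distributed_container = weakref.ref(container)  # old, weakref-based pointer

# Simulate "v = v.assign(...)": the op returns a *new* container and the user
# code stops referencing the original one.
container = Container(v)

# Nothing holds the old container alive anymore, so the weakref is dead.
print(v._distributed_container())  # typically prints None in CPython

With the strong reference installed at creation time, the back-pointer itself keeps the original container alive in this situation.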
@@ -1240,10 +1243,10 @@ def regroup(values, wrap_class=PerReplica, always_wrap=False):
# pylint: disable=protected-access
assert not isinstance(v0, MirroredVariable), (
"ids = %s, values = %s" % ([id(v) for v in values], values))
distributed_container = v0._distributed_container()
distributed_container = v0._distributed_container
assert distributed_container is not None
for v in values[1:]:
assert distributed_container is v._distributed_container()
assert distributed_container is v._distributed_container
return distributed_container
# pylint: enable=protected-access

@@ -1331,7 +1334,7 @@ def value_container(val):
# DistributedVariable has _distributed_container defined
# but we don't want to return it.
not isinstance(val, DistributedVariable)):
container = val._distributed_container() # pylint: disable=protected-access
container = val._distributed_container # pylint: disable=protected-access
if container is not None:
return container
return val
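
The regroup and value_container call sites above, like the optimizer call sites further down, make the same mechanical switch: read v._distributed_container as a plain attribute instead of calling it as a weakref. The consumer-side pattern now amounts to roughly the following sketch (the helper name is illustrative, not TensorFlow API):

def container_or_self(val):
  # Roughly what value_container() does after this change: if the value has a
  # back-pointer to its distributed container, return the container itself;
  # otherwise return the value unchanged. No weakref dereference is needed.
  container = getattr(val, "_distributed_container", None)
  if container is not None:
    return container
  return val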
35 changes: 16 additions & 19 deletions tensorflow/python/distribute/values_test.py
@@ -365,7 +365,7 @@ def _make_mirrored():
return mirrored


class RegroupAndSelectDeviceTest(test.TestCase):
class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):

def _is_per_replica(self, result, expected, klass=values.PerReplica):
self.assertIsInstance(result, klass)
@@ -448,12 +448,20 @@ def testWrapAListOfTwoTuples(self):
self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
self._is_per_replica(result[1], ("2", "4"), values.PerReplica)

def testMirroredContainer(self):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
mirrored = _make_mirrored()
result = values.regroup(mirrored.values)
self.assertIs(mirrored, result)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_one_cpu,
],
mode=["graph", "eager"],
))
def testMirroredContainer(self, distribution):
with distribution.scope():
v = variable_scope.variable(
1., aggregation=variable_scope.VariableAggregation.SUM)
self.assertTrue(values.is_distributed_variable(v))
self.assertTrue(values.is_distributed_variable(values.regroup(v.values)))

def testSameId(self):
foo = object()
@@ -479,18 +487,7 @@ def testOneDevice(self):
result = values.regroup((_nested_value("1"),))
# On one device regroup() and select_replica() are basically identity.
self.assertEqual(_nested_value("1"), result)
self.assertEqual(_nested_value("1"),
values.select_replica(0, result))

# The one exception has to do with MirroredVariables.
d = "/device:CPU:0"
with ops.device(d):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
mirrored = values.MirroredVariable(None, (v,),
variable_scope.VariableAggregation.SUM)
result = values.regroup((v,))
self.assertIs(mirrored, result)
self.assertEqual(_nested_value("1"), values.select_replica(0, result))

def testNamedTuple(self):

2 changes: 1 addition & 1 deletion tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -1274,7 +1274,7 @@ def _var_key(var):
# pylint: disable=protected-access
# Get the distributed variable if it exists.
if hasattr(var, "_distributed_container"):
var = var._distributed_container()
var = var._distributed_container
if var._in_graph_mode:
return var._shared_name
return var._unique_id
2 changes: 1 addition & 1 deletion tensorflow/python/training/optimizer.py
@@ -759,7 +759,7 @@ def get_slot(self, var, name):
if hasattr(var, "_distributed_container"):
# NOTE: If this isn't patched, then there is no `handle` in
# `_resource_apply_dense`.
distributed_container = var._distributed_container()
distributed_container = var._distributed_container
assert distributed_container is not None
if ops.executing_eagerly_outside_functions():
key = distributed_container._unique_id
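
Both optimizer changes above follow the same idea: per-variable optimizer state is keyed by the distributed container, so every replica component of one distributed variable resolves to the same entry. A hedged sketch of that keying, with illustrative names rather than the real _shared_name/_unique_id machinery:

def state_key(var):
  # Resolve a per-replica component to its distributed container (if any) and
  # key optimizer state by the container, so all replicas of one distributed
  # variable share the same state entry.
  container = getattr(var, "_distributed_container", None)
  if container is not None:
    var = container
  return id(var)  # the real code keys by _shared_name or _unique_id


# Usage sketch: if v0 and v1 are two replica components that share one
# container, then state_key(v0) == state_key(v1), so a slots dict keyed this
# way hands back the same slot entry for either component.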