Commit
Make the currently-internal variable scope shim for TF2 ignore `reuse=False`, because it needs to work with code that was written to include some `reuse=False` but now needs to be run multiple times.

PiperOrigin-RevId: 382130441
Tomer Kaftan authored and tensorflower-gardener committed Jun 29, 2021
1 parent b1b50c4 commit 0ba7292
Showing 2 changed files with 28 additions and 77 deletions.
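
For context, here is a minimal sketch (not part of this commit; `legacy_dense` is a made-up example) of the kind of TF1-style code the shim has to re-run. The body was written assuming it builds a graph exactly once, so `reuse=False` looked safe; under the TF2 shim the same body is re-executed on every call, and before this change the second execution would fail with "Variable ... already exists".

```python
# Illustrative only: legacy code that enters a variable scope with
# reuse=False and is expected to be re-executed on every model call.
import tensorflow as tf

def legacy_dense(x, units):
  # Written for one-shot TF1 graph construction, so the author used
  # reuse=False to guard against accidental double-creation.
  with tf.compat.v1.variable_scope("legacy_dense", reuse=False):
    w = tf.compat.v1.get_variable(
        "w", shape=[int(x.shape[-1]), units],
        initializer=tf.compat.v1.glorot_uniform_initializer())
    b = tf.compat.v1.get_variable(
        "b", shape=[units],
        initializer=tf.compat.v1.zeros_initializer())
  # With this commit, the shim treats reuse=False as AUTO_REUSE, so a
  # second execution returns the existing "legacy_dense/w" and
  # "legacy_dense/b" instead of raising ValueError.
  return tf.matmul(x, w) + b
```

When such a function runs inside the shim's variable store, the variables created on the first execution are tracked and handed back on later executions.
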
13 changes: 5 additions & 8 deletions keras/legacy_tf_layers/variable_scope_shim.py
@@ -134,6 +134,9 @@ class _EagerVariableStore(object):
interaction between `tf.function` `FuncGraph` internals, Keras
Functional Models, and TPUStrategy variable initialization.
Also, it always acts as if reuse is set to either "TRUE" or
tf.compat.v1.AUTO_REUSE
Attributes:
vars: a dictionary with string names (same as passed in GetVar) as keys and
the corresponding TensorFlow Variables as values.
@@ -171,9 +174,9 @@ def get_variable(
variable. Otherwise, we create a new one.
Set `reuse` to `True` when you only want to reuse existing Variables.
Set `reuse` to `False` when you only want to create new Variables.
Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
variables to be created if they don't exist or returned if they do.
In this shim, `reuse` of `False` will be treated as auto-reuse.
If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, we use a new
@@ -272,7 +275,7 @@ def custom_getter(getter, name, *args, **kwargs): return getter(name +
# lifted from a function-building graph into the eager context (that's why
# the following clause is not wrapped in an `init_scope`); lifted variables
# are tracked by the graph's `VariableStore`.
if tf.executing_eagerly():
if not reuse:
reuse = tf.compat.v1.AUTO_REUSE

# If a *_ref type is passed in an error would be triggered further down the
@@ -437,12 +440,6 @@ def _get_single_variable(

if name in self._vars:
# Here we handle the case when returning an existing variable.
if reuse is False: # pylint: disable=g-bool-id-comparison
err_msg = ("Variable %s already exists, disallowed."
" Did you mean to set reuse=True or "
"reuse=tf.AUTO_REUSE in VarScope?" % name)
# ResourceVariables don't have an op associated with so no traceback
raise ValueError(err_msg)
found_var = self._vars[name]
if not shape.is_compatible_with(found_var.get_shape()):
raise ValueError("Trying to share variable %s, but specified shape %s"
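
To summarize the semantic change in `variable_scope_shim.py`: `reuse=False` now falls through to auto-reuse instead of triggering the "already exists, disallowed" error that the deleted block above raised. A tiny illustrative helper (not part of the diff; `effective_reuse` is a made-up name) mirroring the updated check:

```python
import tensorflow as tf

def effective_reuse(reuse):
  # Mirrors the shim's updated branch: `if not reuse: reuse = AUTO_REUSE`.
  # Both None and False become AUTO_REUSE, so re-running scope-entering
  # code returns existing variables rather than raising.
  if not reuse:
    return tf.compat.v1.AUTO_REUSE
  return reuse  # True and AUTO_REUSE pass through unchanged

assert effective_reuse(False) == tf.compat.v1.AUTO_REUSE
assert effective_reuse(None) == tf.compat.v1.AUTO_REUSE
assert effective_reuse(True) is True
```
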
92 changes: 23 additions & 69 deletions keras/legacy_tf_layers/variable_scope_shim_test.py
@@ -246,6 +246,29 @@ def test_value(value):
test_value(13.) # Variable is reused hereafter.
test_value(17.)

@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeGetOrCreateReuseIgnoreFalse(self):
with self.cached_session():

def test_value(value):
x = tf.constant(value)
with tf.compat.v1.variable_scope(
"testVarScopeGetOrCreateReuse_bar",
reuse=False):
_ = tf.compat.v1.assign(tf.compat.v1.get_variable("var", []), x)
# We need to ignore reuse=False in the shim, because the
# code is expected to get rerun each time the user calls the shim.
with tf.compat.v1.variable_scope(
"testVarScopeGetOrCreateReuse_bar",
reuse=False):
_ = tf.compat.v1.get_variable("var", [])
self.assertEqual(value, self.evaluate(x))

test_value(42.) # Variable is created.
test_value(13.) # Variable is reused hereafter.
test_value(17.)

@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScope(self):
@@ -750,75 +773,6 @@ def creator(next_creator, **kwargs):

class VariableScopeMultithreadedTest(tf.test.TestCase):

@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testTwoThreadsDisjointScopeEntry(self):

def thread_fn(i, graph):
with graph.as_default():
with tf.compat.v1.variable_scope("foo"):
if i == 0:
v = tf.compat.v1.get_variable("v", [])
self.assertEqual("foo/v:0", v.name)
else:
# Any thread after the first one should fail to create variable
# with the same name.
with self.assertRaises(ValueError):
tf.compat.v1.get_variable("v", [])

graph = tf.compat.v1.get_default_graph()
threads = [
threading.Thread(target=thread_fn, args=(
i,
graph,
)) for i in range(2)
]

threads[0].start()
# Allow thread 0 to finish before starting thread 1.
threads[0].join()
threads[1].start()
threads[1].join()

@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testTwoThreadsNestedScopeEntry(self):

def thread_fn(i, graph, run_event, pause_event):
with graph.as_default():
with tf.compat.v1.variable_scope("foo"):
if i == 0:
v = tf.compat.v1.get_variable("v", [])
self.assertEqual("foo/v:0", v.name)
else:
# Any thread after the first one should fail to create variable
# with the same name.
with self.assertRaises(ValueError):
tf.compat.v1.get_variable("v", [])
pause_event.set()
run_event.wait()

graph = tf.compat.v1.get_default_graph()
run_events = [threading.Event() for _ in range(2)]
pause_events = [threading.Event() for _ in range(2)]
threads = [
threading.Thread(
target=thread_fn, args=(i, graph, run_events[i], pause_events[i]))
for i in range(2)
]

# Start first thread.
threads[0].start()
pause_events[0].wait()
# Start next thread once the first thread has paused.
threads[1].start()
pause_events[1].wait()
# Resume both threads.
run_events[0].set()
run_events[1].set()
threads[0].join()
threads[1].join()

@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testReenterMainScope(self):
