Skip to content

Commit

Permalink
Fix error messages in variables.py
Browse files Browse the repository at this point in the history
I've focused on the messages that are relevant to resource variables or migration rather than the 1.x-only symbols / reference variables.

Removes an error message for tf.Variable.__iter__. It now just reads from the variable and iterates over the result, which is what you'd expect it to do. I think that's better than trying to fix an error message that shouldn't exist in the first place.

Fixes tensorflow#23185.

PiperOrigin-RevId: 395965288
Change-Id: Id5ff6ed2545c4ef5e67acccf3fcbc4094d289e4f
  • Loading branch information
allenlavoie authored and tensorflower-gardener committed Sep 10, 2021
1 parent 28703de commit 9345aee
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 22 deletions.
3 changes: 1 addition & 2 deletions tensorflow/python/distribute/values_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ def testEqualityGraph(self):

def testIteration(self):
v = self.create_variable([1.])
with self.assertRaises(TypeError):
v.__iter__()
self.assertEqual([1.], list(iter(v)))

def testProperties(self):
v = self.create_variable()
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/python/kernel_tests/resource_variable_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,5 +1600,11 @@ def _create_and_delete_variable():
checker.report()
checker.assert_no_leak_if_all_possibly_except_one()

@test_util.run_v2_only
def testIterateVariable(self):
v = variables.Variable([1., 2.])
self.assertAllClose([1., 2.], list(iter(v)))


if __name__ == "__main__":
test.main()
17 changes: 14 additions & 3 deletions tensorflow/python/kernel_tests/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,25 @@ def testCyclicInitializer(self):
"test", cyclic)
self.assertIs(initial_value, cyclic)

def testIterable(self):
with self.assertRaisesRegex(TypeError, "not iterable"):
@test_util.run_deprecated_v1
def testIterableV1(self):
with self.assertRaisesRegex(TypeError, "not allowed in Graph"):
for _ in variables.Variable(0.0):
pass
with self.assertRaisesRegex(TypeError, "not iterable"):
with self.assertRaisesRegex(TypeError, "not allowed in Graph"):
for _ in variables.Variable([0.0, 1.0]):
pass

@test_util.run_v2_only
def testIterableV2(self):
with self.assertRaisesRegex(TypeError, "scalar tensor"):
for _ in variables.Variable(0.0):
pass
values = []
for v in variables.Variable([0.0, 1.0]):
values.append(v)
self.assertAllClose([0., 1.], values)

@test_util.run_deprecated_v1
def testAssignments(self):
with self.cached_session():
Expand Down
27 changes: 10 additions & 17 deletions tensorflow/python/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,8 +1038,8 @@ def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint
_ = name
if dtype and not dtype.is_compatible_with(v.dtype):
raise ValueError(
"Incompatible type conversion requested to type '%s' for variable "
"of type '%s'" % (dtype.name, v.dtype.name))
f"Incompatible type conversion requested to type '{dtype.name}' for "
f"variable of type '{v.dtype.name}' (Variable: {v}).")
if as_ref:
return v._ref() # pylint: disable=protected-access
else:
Expand Down Expand Up @@ -1082,8 +1082,9 @@ def _run_op(a, *args, **kwargs):

def __hash__(self):
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
raise TypeError("Variable is unhashable. "
"Instead, use tensor.ref() as the key.")
raise TypeError(
"Variable is unhashable. "
f"Instead, use variable.ref() as the key. (Variable: {self})")
else:
return id(self)

Expand All @@ -1106,18 +1107,8 @@ def __ne__(self, other):
return self is not other

def __iter__(self):
"""Dummy method to prevent iteration.
Do not call.
NOTE(mrry): If we register __getitem__ as an overloaded operator,
Python will valiantly attempt to iterate over the variable's Tensor from 0
to infinity. Declaring this method prevents this unintended behavior.
Raises:
TypeError: when invoked.
"""
raise TypeError("'Variable' object is not iterable.")
"""When executing eagerly, iterates over the value of the variable."""
return iter(self.read_value())

# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
Expand Down Expand Up @@ -1783,7 +1774,9 @@ def _init_from_args(self,
# Ensure that we weren't lifted into the eager context.
if context.executing_eagerly():
raise RuntimeError(
"RefVariable not supported when eager execution is enabled. ")
"Reference variables are not supported when eager execution is "
"enabled. Please run `tf.compat.v1.enable_resource_variables()` to "
"switch to resource variables.")
with ops.name_scope(name, "Variable",
[] if init_from_fn else [initial_value]) as name:

Expand Down

0 comments on commit 9345aee

Please sign in to comment.