Skip to content

Commit

Permalink
Remove the instance check of _UnreadVariable (private tf api) in Keras.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 351194027
Change-Id: I373b1fe4df9f68d0ceb019f22222f05496a57375
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Jan 11, 2021
1 parent e8f2243 commit bdea938
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 23 deletions.
5 changes: 0 additions & 5 deletions tensorflow/python/keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
Expand Down Expand Up @@ -2877,12 +2876,8 @@ def __setattr__(self, name, value):
# TODO(b/125122625): This won't pick up on any variables added to a
# list/dict after creation.
for val in nest.flatten(value, expand_composites=True):
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
# no longer return True for isinstance Variable checks.
if not isinstance(val, tf_variables.Variable):
continue
if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access
continue

# Users may add extra weights/variables
# simply by assigning them to attributes (invalid for graph networks)
Expand Down
13 changes: 0 additions & 13 deletions tensorflow/python/keras/engine/base_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,19 +1139,6 @@ def test_attribute_reassignment(self):
del l.a
self.assertEqual([], l._self_tracked_trackables)

def test_assign_op_not_tracked_as_variable(self):

class LayerWithAssignAttr(base_layer.Layer):

def build(self, input_shape):
self.v = variables.Variable(1.)
self.v_assign = self.v.assign_add(2.)

layer = LayerWithAssignAttr()
layer.build((10, 10))

self.assertEqual([layer.v], layer.variables)

def test_layer_class_not_tracked_as_sublayer(self):
# See https://github.com/tensorflow/tensorflow/issues/27431 for details.

Expand Down
5 changes: 0 additions & 5 deletions tensorflow/python/keras/engine/base_layer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
Expand Down Expand Up @@ -2240,12 +2239,8 @@ def __setattr__(self, name, value):
# TODO(b/125122625): This won't pick up on any variables added to a
# list/dict after creation.
for val in nest.flatten(value):
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
# no longer return True for isinstance Variable checks.
if not isinstance(val, tf_variables.Variable):
continue
if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access
continue

# Users may add extra weights/variables
# simply by assigning them to attributes (invalid for graph networks)
Expand Down

0 comments on commit bdea938

Please sign in to comment.