Commit 8633943
Make is_resource_variable() a tf.__internal__ API.
PiperOrigin-RevId: 351249930
Change-Id: I4c8aa09d5584531c723f0f4919cbf5f30080f705
qlzh727 authored and tensorflower-gardener committed Jan 11, 2021
1 parent a0e2499 commit 8633943
Showing 5 changed files with 11 additions and 21 deletions.
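
With this change the helper becomes reachable as tf.__internal__.ops.is_resource_variable. A minimal usage sketch, assuming a TensorFlow build that includes this export (roughly TF 2.5 and later):

import tensorflow as tf

# In TF2, tf.Variable creates a resource variable by default, so the
# helper reports True for it.
v = tf.Variable(1.0)
print(tf.__internal__.ops.is_resource_variable(v))  # True

# A plain tensor is not a resource variable.
t = tf.constant(1.0)
print(tf.__internal__.ops.is_resource_variable(t))  # False
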
14 changes: 0 additions & 14 deletions tensorflow/python/keras/layers/recurrent_v2.py
@@ -37,7 +37,6 @@
 from tensorflow.python.ops import gen_cudnn_rnn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import sysconfig
@@ -419,19 +418,6 @@ def __init__(self,
     if _use_new_code():
       self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
 
-  def build(self, input_shape):
-    super(GRU, self).build(input_shape)
-
-    if not all(isinstance(v, resource_variable_ops.ResourceVariable)
-               for v in self.weights):
-      # Non-resource variables, such as DistributedVariables and
-      # AutoCastVariables, do not work properly with the implementation
-      # selector, which is used when cuDNN is used. However, by chance, such
-      # variables happen to work in LSTM, so this check is only needed for GRU.
-      # TODO(b/136512020): Make non-resource variables work with the
-      # implementation selector.
-      self._could_use_gpu_kernel = False
-
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # The input should be dense, padded with zeros. If a ragged input is fed
     # into the layer, it is padded and the row lengths are used for masking.
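
The deleted build() override gated GRU's cuDNN-capable code path on every weight being a ResourceVariable. Code outside the TensorFlow tree that wants a similar guard no longer needs the private resource_variable_ops module; a rough sketch using the newly exported helper (illustrative only: is_resource_variable is broader than the strict isinstance check above, since it also accepts objects carrying the _should_act_as_resource_variable marker):

import tensorflow as tf

def weights_are_resource_vars(layer):
  # Approximates the deleted guard: the implementation selector used for
  # the cuDNN kernel only works when the weights are resource variables.
  return all(
      tf.__internal__.ops.is_resource_variable(w) for w in layer.weights)
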
2 changes: 1 addition & 1 deletion tensorflow/python/keras/mixed_precision/BUILD
@@ -280,7 +280,7 @@ cuda_py_test(
     ],
 )
 
-tf_py_test(
+cuda_py_test(
     name = "layer_correctness_test",
     size = "medium",
     srcs = ["layer_correctness_test.py"],
10 changes: 4 additions & 6 deletions tensorflow/python/keras/tests/tracking_util_test.py
@@ -39,7 +39,6 @@
 from tensorflow.python.module import module
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
@@ -273,7 +272,7 @@ def testSaveRestore(self):
     # Optimizer slot variables are created when the original variable is
     # restored.
     self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
-    dummy_var = resource_variable_ops.ResourceVariable([1.])
+    dummy_var = variables_lib.Variable([1.])
     on_create_optimizer.minimize(loss=dummy_var.read_value,
                                  var_list=[dummy_var])
     status.assert_existing_objects_matched()
@@ -459,8 +458,8 @@ class Model(training.Model):

       def __init__(self):
         super(Model, self).__init__()
-        self.w = resource_variable_ops.ResourceVariable(0.0)
-        self.b = resource_variable_ops.ResourceVariable(0.0)
+        self.w = variables_lib.Variable(0.0)
+        self.b = variables_lib.Variable(0.0)
         self.vars = [self.w, self.b]
 
       def call(self, x):
@@ -874,8 +873,7 @@ def testLoadFromNameBasedSaver(self):
     self._check_sentinels(root)
     # Check that there is no error when keys are missing from the name-based
     # checkpoint.
-    root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable(
-        [1.])
+    root.not_in_name_checkpoint = variables_lib.Variable([1.])
     status = object_saver.restore(save_path)
     with self.assertRaises(AssertionError):
       status.assert_existing_objects_matched()
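
The test edits above swap direct ResourceVariable construction for variables_lib.Variable. Under TF2 (v2 behavior enabled), tf.Variable constructs a resource variable by default, so the tests exercise the same variable type. A quick eager-mode check, assuming TF2 defaults:

import tensorflow as tf

v = tf.Variable([1.])
# tf.Variable dispatches to ResourceVariable under v2 behavior, so the
# internal helper agrees.
print(tf.__internal__.ops.is_resource_variable(v))  # True
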
2 changes: 2 additions & 0 deletions tensorflow/python/ops/resource_variable_ops.py
@@ -55,6 +55,7 @@
 from tensorflow.python.util import _pywrap_utils
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
 
 acd.register_read_only_resource_op("ReadVariableOp")
 acd.register_read_only_resource_op("VariableShape")
@@ -2211,6 +2212,7 @@ def _from_proto_fn(v, import_scope=None):
     from_proto=_from_proto_fn)
 
 
+@tf_export("__internal__.ops.is_resource_variable", v1=[])
 def is_resource_variable(var):
   """Returns True if `var` is to be considered a ResourceVariable."""
   return isinstance(var, BaseResourceVariable) or hasattr(
       var, "_should_act_as_resource_variable")
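
The added tf_export("__internal__.ops.is_resource_variable", v1=[]) decorator registers the function in the TF2 namespace only; the empty v1 list means the symbol gets no TF1 export. A simplified, hypothetical sketch of the decorator pattern, not TensorFlow's actual implementation (which lives in tensorflow.python.util.tf_export):

# Hypothetical mini tf_export, for illustration only.
_V2_EXPORTS = {}

def tf_export(name, v1=None):
  def decorator(func):
    _V2_EXPORTS[name] = func  # later surfaced as tf.<name>
    # v1=[] signals: do not expose this symbol in the TF1 namespace.
    return func
  return decorator
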
4 changes: 4 additions & 0 deletions
@@ -4,4 +4,8 @@ tf_module {
name: "broadcast_weights"
argspec: "args=[\'weights\', \'values\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_resource_variable"
argspec: "args=[\'var\'], varargs=None, keywords=None, defaults=None"
}
}
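
The golden .pbtxt entry records the public argspec, which must match the live Python signature for TensorFlow's API compatibility tests to pass. A rough illustration of where those fields come from, using the standard inspect module rather than TensorFlow's internal tooling:

import inspect

def is_resource_variable(var):
  ...

spec = inspect.getfullargspec(is_resource_variable)
print(spec.args)      # ['var']  -> args=['var']
print(spec.varargs)   # None     -> varargs=None
print(spec.varkw)     # None     -> keywords=None
print(spec.defaults)  # None     -> defaults=None
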
