Skip to content

Commit

Permalink
Set default training value to False when exporting layer/model call…
Browse files Browse the repository at this point in the history
…s to SavedModel.

PiperOrigin-RevId: 262481484
  • Loading branch information
k-w-w authored and tensorflower-gardener committed Aug 9, 2019
1 parent 9c04d7c commit b1e40a2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
51 changes: 51 additions & 0 deletions tensorflow/python/keras/saving/saved_model/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import save as tf_save
from tensorflow.python.util import tf_inspect


class LayerWithLearningPhase(keras.engine.base_layer.Layer):
Expand Down Expand Up @@ -384,6 +385,56 @@ def parse_and_predict(examples):
self.assertAllClose(model.predict(input_arr), outputs['predictions'])
self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs'])

def testTrainingDefaults(self):
def assert_training_default(fn, default_value):
arg_spec = tf_inspect.getfullargspec(fn)
index = len(arg_spec.args) - arg_spec.args.index('training')
self.assertEqual(arg_spec.defaults[-index], default_value)

class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):

def call(self, inputs, training):
return tf_utils.smart_cond(
training, lambda: inputs * 0, lambda: array_ops.identity(inputs))

class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):

def call(self, inputs, training=True):
return tf_utils.smart_cond(
training, lambda: inputs * 0, lambda: array_ops.identity(inputs))

class Model(keras.models.Model):

def __init__(self):
super(Model, self).__init__()
self.layer_with_training_default_none = LayerWithLearningPhase()
self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()

def call(self, inputs):
x = self.layer_with_training_default_none(inputs)
x += self.layer_with_training_default_true(inputs)
x += self.layer_with_required_training_arg(inputs, False)
return x

model = Model()
# Build and set model inputs
model.predict(np.ones([1, 3]).astype('float32'))
saved_model_dir = self._save_model_dir()
model.save(saved_model_dir, save_format='tf')
load = tf_load.load(saved_model_dir)

assert_training_default(load.__call__, False)
assert_training_default(
load.layer_with_training_default_none.__call__, False)
assert_training_default(
load.layer_with_training_default_true.__call__, True)

# Assert that there are no defaults for layer with required training arg
arg_spec = tf_inspect.getfullargspec(
load.layer_with_required_training_arg.__call__)
self.assertFalse(arg_spec.defaults) # defaults is None or empty


class TestLayerCallTracing(test.TestCase):

Expand Down
11 changes: 10 additions & 1 deletion tensorflow/python/keras/saving/saved_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,27 @@ def replace_training_and_call(training):
# Create arg spec for decorated function. If 'training' is not defined in the
# args of the original arg spec, then add it to kwonlyargs.
arg_spec = tf_inspect.getfullargspec(original_call)
defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

kwonlyargs = arg_spec.kwonlyargs
kwonlydefaults = arg_spec.kwonlydefaults or {}
# Add training arg if it does not exist, or set the default training value.
if 'training' not in arg_spec.args:
kwonlyargs.append('training')
kwonlydefaults['training'] = default_training_value
else:
index = arg_spec.args.index('training')
training_default_index = len(arg_spec.args) - index
if (arg_spec.defaults and
len(arg_spec.defaults) >= training_default_index and
defaults[-training_default_index] is None):
defaults[-training_default_index] = default_training_value

decorator_argspec = tf_inspect.FullArgSpec(
args=arg_spec.args,
varargs=arg_spec.varargs,
varkw=arg_spec.varkw,
defaults=arg_spec.defaults,
defaults=defaults,
kwonlyargs=kwonlyargs,
kwonlydefaults=kwonlydefaults,
annotations=arg_spec.annotations)
Expand Down

0 comments on commit b1e40a2

Please sign in to comment.