Don't cache zero tensors in graph at all
PiperOrigin-RevId: 205885372
Akshay Modi authored and tensorflower-gardener committed Jul 24, 2018
1 parent ee0bd6e · commit 57d051e
Showing 2 changed files with 28 additions and 31 deletions.
tensorflow/python/eager/backprop.py (16 changes: 11 additions & 5 deletions)
@@ -599,15 +599,18 @@ def _fast_fill(value, shape, dtype):
 
 
 def _zeros(shape, dtype):
-  """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
-  device = context.context().device_name
+  """Helper to return (possibly cached) zero tensors in eager mode."""
   if dtype == dtypes.variant:
     # TODO(apassos): need to save enough information about variant tensors to do
     # a zeros
     return None
-  # pylint: disable=protected-access
-  cache_key = shape, dtype, device, context.context()._eager_context.mode
-  # pylint: enable=protected-access
+
+  ctx = context.context()
+  if not ctx.executing_eagerly():
+    return array_ops.zeros(shape, dtype)
+
+  device = ctx.device_name
+  cache_key = shape, dtype, device
   cached = _zeros_cache.get(cache_key)
   if cached is None:
     cached = _fast_fill(0, shape, dtype)
@@ -616,6 +619,9 @@ def _zeros(shape, dtype):
 
 
 def _ones(shape, dtype):
+  if not context.context().executing_eagerly():
+    return array_ops.ones(shape, dtype)
+
   if shape == ():  # pylint: disable=g-explicit-bool-comparison
     return constant_op.constant(1, dtype=dtype)
   return _fast_fill(1, shape, dtype)
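The net effect of the backprop.py change is easiest to see outside of TensorFlow internals. Below is an illustrative sketch only, not the actual implementation: `make_zeros`, `executing_eagerly`, and `device_name` are stand-ins for `array_ops.zeros` / `_fast_fill` and the eager context queried in backprop.py. Zeros are cached per (shape, dtype, device) only while executing eagerly; graph mode always builds a fresh op, so no cached tensor is ever tied to a particular graph.

# Illustrative sketch, not the TensorFlow implementation.
_zeros_cache = {}

def _zeros_sketch(shape, dtype, executing_eagerly, device_name, make_zeros):
  if not executing_eagerly:
    # Graph mode: never cache. The resulting tensor belongs to the
    # currently active graph and must not outlive it.
    return make_zeros(shape, dtype)
  # Eager mode: tensors are graph-independent, so a (shape, dtype, device)
  # keyed cache is safe and avoids refilling the same zeros repeatedly.
  cache_key = (shape, dtype, device_name)
  cached = _zeros_cache.get(cache_key)
  if cached is None:
    cached = make_zeros(shape, dtype)
    _zeros_cache[cache_key] = cached
  return cached

Repeated eager calls with the same key return the same cached object; graph-mode calls always build a new value.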
tensorflow/python/eager/backprop_test.py (43 changes: 17 additions & 26 deletions)
@@ -925,32 +925,23 @@ def fn(x, y):
                                  'did you forget to return a value from fn?'):
       val_and_grads_fn(x, y)
 
-  def testZerosCacheDoesntLeakAcrossModes(self):
-    with ops.Graph().as_default():
-      t = random_ops.random_normal(shape=[100, 2])
-      x = random_ops.random_normal(shape=[100, 4])
-      dy = random_ops.random_normal(shape=[100, 4])
-      with backprop.GradientTape() as gradient_tape:
-        gradient_tape.watch(x)
-        x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
-        y1 = x1 ** 2.
-        y = array_ops.concat([y1, t], axis=1)
-
-      dx = gradient_tape.gradient(y, x, output_gradients=dy)
-      with self.test_session() as sess:
-        sess.run(variables.global_variables_initializer())
-        sess.run(dx)
-
-    t = random_ops.random_normal(shape=[100, 2])
-    x = random_ops.random_normal(shape=[100, 4])
-    dy = random_ops.random_normal(shape=[100, 4])
-    with backprop.GradientTape() as gradient_tape:
-      gradient_tape.watch(x)
-      x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
-      y1 = x1 ** 2.
-      y = array_ops.concat([y1, t], axis=1)
-
-    dx = gradient_tape.gradient(y, x, output_gradients=dy)
+  def testZerosCacheDoesntLeakAcrossGraphs(self):
+    with context.graph_mode():
+      def get_grad():
+        with ops.Graph().as_default(), self.test_session():
+          t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
+          x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
+          with backprop.GradientTape() as gt:
+            tape.watch(x)
+            x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+            y1 = x1**2
+            y = array_ops.concat([y1, t], axis=1)
+          return self.evaluate(gt.gradient(y, x))
+
+      grad1 = get_grad()
+      grad2 = get_grad()
+
+    self.assertAllEqual(grad1, grad2)
 
 
 if __name__ == '__main__':
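The replacement test builds the same gradient twice, each time in a fresh graph, which under the old cache key could hand the second graph a zeros tensor owned by the first. As a rough standalone illustration (not part of this commit; assumes the TF 1.x-style graph API in use at the time), TensorFlow rejects any attempt to combine tensors from two different graphs, which is exactly the failure a cached graph tensor would provoke:

import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
  cached_zero = tf.zeros([2, 2])  # this symbolic tensor is owned by g1

g2 = tf.Graph()
with g2.as_default():
  x = tf.ones([2, 2])
  try:
    x + cached_zero  # mixing tensors that belong to g1 and g2
  except ValueError as err:
    # Ops must take all their inputs from a single graph, so a zeros tensor
    # cached from one graph cannot be reused in another.
    print('cross-graph use rejected:', err)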
