Make LSTMCell use Defuns to speed up static graph builds, add compiled flag.

Change: 145703555
ebrevdo authored and tensorflower-gardener committed Jan 26, 2017
1 parent c76c5c4 commit 13a1a9a
Showing 11 changed files with 391 additions and 161 deletions.
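
Usage sketch (not part of the commit): the new compiled flag can be exercised roughly the way the test helper _CreateMultiLSTMCellOps in this diff does it. This is a minimal example against the TF 1.x contrib API touched here; the shapes, seeds, and session setup are illustrative assumptions, not taken from the commit.

import tensorflow as tf

from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl

batch_size, max_time, input_depth = 16, 20, 12
num_units, num_layers = 32, 2

# Build a stacked LSTM whose per-cell computation is wrapped in a Defun
# (compiled=True), keeping the static graph small and making the cell
# eligible for XLA compilation.
inputs = tf.random_uniform((max_time, batch_size, input_depth), seed=1)
cell = core_rnn_cell_impl.MultiRNNCell(
    [core_rnn_cell_impl.LSTMCell(num_units, compiled=True)
     for _ in range(num_layers)])
initial_state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(
    cell=cell, inputs=inputs, initial_state=initial_state, time_major=True)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run([outputs, final_state])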
4 changes: 2 additions & 2 deletions configure
@@ -168,10 +168,10 @@ done

if [ "$TF_ENABLE_XLA" == "1" ]; then
# Update Bazel build configuration.
perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
sed -i -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
else
# Update Bazel build configuration.
perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
sed -i -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
fi


2 changes: 2 additions & 0 deletions tensorflow/contrib/rnn/BUILD
@@ -71,6 +71,7 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
xla_enabled = True,
)

cuda_py_tests(
@@ -91,6 +92,7 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
xla_enabled = True,
)

cuda_py_tests(
165 changes: 156 additions & 9 deletions tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function

import functools
import itertools
import sys

# TODO: #6568 Remove this hack that makes dlopen() not crash.
@@ -33,20 +34,56 @@

from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.util import nest

# pylint: enable=protected-access


def _CreateMultiLSTMCellOps(batch_size, num_units, input_depth,
num_layers, max_time, compiled):
with variable_scope.variable_scope(
"root",
initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
inputs = random_ops.random_uniform(
(max_time, batch_size, input_depth), seed=1)
rnn_cell = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.LSTMCell(num_units, compiled=compiled)
for _ in range(num_layers)])
initial_state = rnn_cell.zero_state(
batch_size=batch_size, dtype=dtypes.float32)
outputs, final_state = rnn.dynamic_rnn(
cell=rnn_cell, inputs=inputs, initial_state=initial_state,
time_major=True)
flat_final_state = nest.flatten(final_state)
trainable_variables = variables_lib.trainable_variables()
outputs_grad = gradients_impl.gradients(
[outputs],
trainable_variables + [inputs] + nest.flatten(initial_state))
final_state_grad = gradients_impl.gradients(
flat_final_state,
trainable_variables + [inputs] + nest.flatten(initial_state))

return {"outputs": outputs,
"final_state": flat_final_state,
"outputs_grad": outputs_grad,
"final_state_grad": final_state_grad}


class RNNCellTest(test.TestCase):

def testLinear(self):
@@ -117,8 +154,8 @@ def testBasicLSTMCell(self):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 8])
g, out_m = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.BasicLSTMCell(
2, state_is_tuple=False)] * 2,
[core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
for _ in range(2)],
state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
@@ -165,7 +202,8 @@ def testBasicLSTMCellStateTupleType(self):
m0 = (array_ops.zeros([1, 2]),) * 2
m1 = (array_ops.zeros([1, 2]),) * 2
cell = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.BasicLSTMCell(2)] * 2, state_is_tuple=True)
[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
state_is_tuple=True)
self.assertTrue(isinstance(cell.state_size, tuple))
self.assertTrue(
isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
@@ -197,8 +235,8 @@ def testBasicLSTMCellWithStateTuple(self):
m0 = array_ops.zeros([1, 4])
m1 = array_ops.zeros([1, 4])
cell = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.BasicLSTMCell(
2, state_is_tuple=False)] * 2,
[core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
for _ in range(2)],
state_is_tuple=True)
g, (out_m0, out_m1) = cell(x, (m0, m1))
sess.run([variables_lib.global_variables_initializer()])
@@ -407,7 +445,8 @@ def testMultiRNNCell(self):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 4])
_, ml = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=False)(x, m)
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(ml, {
x.name: np.array([[1., 1.]]),
@@ -416,6 +455,48 @@ def testMultiRNNCell(self):
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])

def testMultiRNNCellWithLSTMCellAndXLA(self):
# TODO(b/34735319): Don't run this test if XLA is not available.
batch_size = 16
num_units = 32
input_depth = 12
num_layers = 2
max_time = 20

random_seed.set_random_seed(1234)
with self.test_session(graph=ops.Graph()) as sess:
xla_ops = _CreateMultiLSTMCellOps(
batch_size=batch_size, num_units=num_units,
input_depth=input_depth, num_layers=num_layers,
max_time=max_time,
compiled=True)
sess.run([variables_lib.global_variables_initializer()])
xla_results = sess.run(xla_ops)

random_seed.set_random_seed(1234)
with self.test_session(graph=ops.Graph()) as sess:
non_xla_ops = _CreateMultiLSTMCellOps(
batch_size=batch_size, num_units=num_units,
input_depth=input_depth, num_layers=num_layers,
max_time=max_time,
compiled=False)
sess.run([variables_lib.global_variables_initializer()])
non_xla_results = sess.run(non_xla_ops)

self.assertAllClose(non_xla_results["outputs"], xla_results["outputs"])

for xla_value, non_xla_value in zip(
xla_results["final_state"], non_xla_results["final_state"]):
self.assertAllClose(xla_value, non_xla_value)

for xla_g, non_xla_g in zip(
xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
self.assertAllClose(xla_g, non_xla_g)

for xla_g, non_xla_g in zip(
xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
self.assertAllClose(xla_g, non_xla_g)

def testMultiRNNCellWithStateTuple(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -427,11 +508,12 @@ def testMultiRNNCellWithStateTuple(self):
# Test incorrectness of state
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.GRUCell(2)] * 2,
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
state_is_tuple=True)(x, m_bad)

_, ml = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good)
[core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
state_is_tuple=True)(x, m_good)

sess.run([variables_lib.global_variables_initializer()])
res = sess.run(ml, {
@@ -490,7 +572,7 @@ def testBasicRNNCellMatch(self):
self.assertAllClose(res[1], res[3])


def basic_rnn_cell(inputs, state, num_units, scope=None):
def basic_rnn_cell(inputs, state, num_units, scope=None): # pylint: disable=invalid-name
if state is None:
if inputs is not None:
batch_size = inputs.get_shape()[0]
@@ -512,5 +594,70 @@ def basic_rnn_cell(inputs, state, num_units, scope=None):
return output, output


class BenchmarkLSTMCellXLA(test.Benchmark):

def benchmarkDynamicRNNWithMultiLSTMCell(self):
num_layers = 3
max_time = 50
print("benchmarkDynamicRNNWithMultiLSTMCell")
print("\t" +
"\t".join(["inter_th", "intra_th",
"batch_size", "num_units", "input_depth", "device",
"compiled", "wall_time"]))

warmup_run = True
for (threads,
device,
num_units,
batch_size,
input_depth,
compiled) in itertools.product(
[{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
["cpu", "gpu"],
[32, 512],
[1, 32, 256],
[32, 512],
[False, True]):
if threads["inter"] != 0:
# We only care about testing inter/intra op limitations on
# CPU with small batch size, to mimic embedded devices.
if device != "cpu" or batch_size != 1:
continue
if device == "cpu" and batch_size > 32:
continue
random_seed.set_random_seed(1234)
config = config_pb2.ConfigProto(
inter_op_parallelism_threads=threads["inter"],
intra_op_parallelism_threads=threads["intra"],
allow_soft_placement=False)
with session.Session(config=config, graph=ops.Graph()) as sess:
with ops.device("/%s:0" % device):
ops_dict = _CreateMultiLSTMCellOps(
batch_size=batch_size, num_units=num_units,
input_depth=input_depth, num_layers=num_layers,
max_time=max_time,
compiled=compiled)
sess.run([variables_lib.global_variables_initializer()])
all_ops = nest.flatten(ops_dict.values())
all_ops_group = control_flow_ops.group(*all_ops)
name_suffix = (
"inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
"_device_%s_xla_%s" % (
threads["inter"], threads["intra"],
batch_size, num_units, input_depth, device, compiled))
if warmup_run:
self.run_op_benchmark(
sess, all_ops_group, min_iters=30, name="ignore_warmup")
warmup_run = False
benchmark_results = self.run_op_benchmark(
sess, all_ops_group, min_iters=30,
name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
print("\t" +
"\t".join(["%s" % x for x in [
threads["inter"], threads["intra"],
batch_size, num_units, input_depth, device, compiled,
benchmark_results["wall_time"]]]))


if __name__ == "__main__":
test.main()
(Diffs for the remaining 8 changed files are not shown.)
