Skip to content

Commit

Permalink
Apply layer normalization to LSTMCell and class CoupledInputForgetGat…
Browse files Browse the repository at this point in the history
…eLSTMCell tensorflow#9600 (tensorflow#9839)

* Apply layer normalization to CoupledInputForgetGateLSTMCell. (Review required tensorflow#9600)

* changed variable name _g, _b => _norm_gain, _norm_shift

* Add layer normalization reference.

* Add a unit test that checks the layer normalization to LSTMCell.

* Add unit test verifies LSTM Layer Normalization.

The results of LSTMCell and LayerNormBasicLSTMCell should be the same.

* Fix bugs on rnn cells.

* Add LayerNormLSTMCell on contrib.rnn

* Apply changes on rnn_cell_test.
Fix bugs on rnn_cell.
Add layer_norm parameter on _linear function.

* Bug fix : add missing import

* Add custom _linear function inside the LayerNormLSTMCell.

* Sanity check fix : RNNCell => LSTMCell

* Sanity check fix again

* remove state_is_tuple argument

* remove num_unit_shards and num_proj_shards arguments.

* remove state_is_tuple in LayerNormLSTMCell

* remove state_is_tuple in core_rnn_cell_test.py

* fix LayerNormLSTMCell test

* keep rnn_cell_impl.py unmodified.

* @ebrevdo your feedback is applied :)
  • Loading branch information
chris-chris authored and martinwicke committed Nov 6, 2017
1 parent 17ce984 commit 19559c0
Show file tree
Hide file tree
Showing 3 changed files with 402 additions and 6 deletions.
42 changes: 42 additions & 0 deletions tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.framework import test_util
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell



# pylint: enable=protected-access
Expand Down Expand Up @@ -358,6 +361,45 @@ def testLSTMCellVariables(self):
self.assertEquals(variables[2].op.name,
"root/lstm_cell/projection/kernel")

def testLSTMCellLayerNorm(self):
  """Smoke-tests LayerNormLSTMCell with projection enabled.

  Only output/state shapes and cross-batch consistency are checked; the
  numeric values are not compared against a reference.
  """
  with self.test_session() as sess:
    num_units = 2
    num_proj = 3
    batch_size = 1
    input_size = 4
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      x = array_ops.zeros([batch_size, input_size])
      c = array_ops.zeros([batch_size, num_units])
      # With a projection, the hidden state h has size num_proj, not
      # num_units.
      h = array_ops.zeros([batch_size, num_proj])
      state = rnn_cell_impl.LSTMStateTuple(c, h)
      cell = contrib_rnn_cell.LayerNormLSTMCell(
          num_units=num_units,
          num_proj=num_proj,
          forget_bias=1.0,
          layer_norm=True,
          norm_gain=1.0,
          norm_shift=0.0)
      g, out_m = cell(x, state)
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([g, out_m], {
          x.name: np.ones((batch_size, input_size)),
          c.name: 0.1 * np.ones((batch_size, num_units)),
          h.name: 0.1 * np.ones((batch_size, num_proj))
      })
      self.assertEqual(len(res), 2)
      # The numbers in results were not calculated, this is mostly just a
      # smoke test.
      self.assertEqual(res[0].shape, (batch_size, num_proj))
      self.assertEqual(res[1][0].shape, (batch_size, num_units))
      self.assertEqual(res[1][1].shape, (batch_size, num_proj))
      # Every batch row is fed identical inputs, so the outputs and states
      # must match across rows.  NOTE(review): with batch_size == 1 this
      # loop never executes; increase batch_size for the check to bite.
      for i in range(1, batch_size):
        self.assertLess(
            float(np.linalg.norm((res[0][0, :] - res[0][i, :]))), 1e-6)
        self.assertLess(
            float(np.linalg.norm((res[1][0, :] - res[1][i, :]))), 1e-6)

def testOutputProjectionWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
Expand Down
44 changes: 44 additions & 0 deletions tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
Expand Down Expand Up @@ -1275,6 +1276,49 @@ def testBasicLSTMCellWithStateTuple(self):
self.assertAllClose(res[2].c, expected_c1, 1e-5)
self.assertAllClose(res[2].h, expected_h1, 1e-5)


def testBasicLSTMCellWithStateTupleLayerNorm(self):
  """Two stacked LayerNormLSTMCells should match LayerNormBasicLSTMCell.

  The expected constants below are the reference outputs of
  LayerNormBasicLSTMCell for the same inputs and initializer.
  """
  # NOTE(review): this relies on `contrib_rnn_cell` being imported in this
  # file; the import is not visible in this diff hunk — confirm it exists.
  with self.test_session() as sess:
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      x = array_ops.zeros([1, 2])
      c0 = array_ops.zeros([1, 2])
      h0 = array_ops.zeros([1, 2])
      state0 = rnn_cell_impl.LSTMStateTuple(c0, h0)
      c1 = array_ops.zeros([1, 2])
      h1 = array_ops.zeros([1, 2])
      state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
      # Two layers with identical layer-norm settings (gain 1, shift 0)
      # so the stack is directly comparable to LayerNormBasicLSTMCell.
      cell = rnn_cell_impl.MultiRNNCell(
          [contrib_rnn_cell.LayerNormLSTMCell(
              2,
              layer_norm=True,
              norm_gain=1.0,
              norm_shift=0.0) for _ in range(2)])
      h, (s0, s1) = cell(x, (state0, state1))
      sess.run([variables.global_variables_initializer()])
      res = sess.run([h, s0, s1], {
          x.name: np.array([[1., 1.]]),
          c0.name: 0.1 * np.asarray([[0, 1]]),
          h0.name: 0.1 * np.asarray([[2, 3]]),
          c1.name: 0.1 * np.asarray([[4, 5]]),
          h1.name: 0.1 * np.asarray([[6, 7]]),
      })

      # Reference values produced by LayerNormBasicLSTMCell.
      expected_h = np.array([[-0.38079708, 0.38079708]])
      expected_h0 = np.array([[-0.38079708, 0.38079708]])
      expected_c0 = np.array([[-1.0, 1.0]])
      expected_h1 = np.array([[-0.38079708, 0.38079708]])
      expected_c1 = np.array([[-1.0, 1.0]])

      self.assertEqual(len(res), 3)
      self.assertAllClose(res[0], expected_h, 1e-5)
      self.assertAllClose(res[1].c, expected_c0, 1e-5)
      self.assertAllClose(res[1].h, expected_h0, 1e-5)
      self.assertAllClose(res[2].c, expected_c1, 1e-5)
      self.assertAllClose(res[2].h, expected_h1, 1e-5)

def testBasicLSTMCellWithDropout(self):

def _is_close(x, y, digits=4):
Expand Down
Loading

0 comments on commit 19559c0

Please sign in to comment.