Skip to content

Commit

Permalink
Enable stateful RNN with cntk backend (keras-team#7272)
Browse files Browse the repository at this point in the history
* support stateful LSTM via a user function

* use variable to workaround keras naming issue; fix style issue

* revert useless code; update test

* add doc string

* update doc string

* Style fix
  • Loading branch information
souptc authored and fchollet committed Jul 8, 2017
1 parent f3f6e4f commit 09b97e9
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 59 deletions.
87 changes: 77 additions & 10 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def constant(value, dtype=None, shape=None, name=None):
const = C.constant(np_value,
dtype=dtype,
name=_prepare_name(name, 'constant'))
const._keras_shape = shape
const._keras_shape = const.shape
const._uses_learning_phase = False
return const

Expand Down Expand Up @@ -1216,6 +1216,13 @@ def rnn(step_function, inputs, initial_states,
if num_time_step is None and not has_seq_axis(inputs):
num_time_step = inputs.shape[0]

initial = []
for s in initial_states:
if _get_dynamic_axis_num(s) == 0:
initial.append(C.user_function(ConvertToBatch(s)))
else:
initial.append(s)

need_convert = not has_seq_axis(inputs)
if need_convert:
inputs = C.to_sequence(inputs)
Expand All @@ -1233,7 +1240,7 @@ def rnn(step_function, inputs, initial_states,
constants[j] = C.sequence.broadcast_as(constants[j], inputs)
j += 1

states = tuple(initial_states)
states = tuple(initial)

with C.default_options(axis_offset=1):
def _recurrence(x, states):
Expand All @@ -1256,20 +1263,21 @@ def _recurrence(x, states):

final_output, final_states = _recurrence(inputs, states)
last_output = C.sequence.last(final_output)
last_states = final_states
last_states = [C.sequence.last(s) for s in final_states]

if need_convert:
final_output = C.sequence.unpack(final_output, 0, no_mask_output=True)
last_states = [
C.sequence.unpack(
s, 0, no_mask_output=True) for s in last_states]
if num_time_step is not None and num_time_step is not C.FreeDimension:
final_output = _reshape_sequence(final_output, num_time_step)
last_states = [
_reshape_sequence(
_, num_time_step) for _ in last_states]

return last_output, final_output, last_states
f_stats = []
for l_s, i_s in zip(last_states, initial_states):
if _get_dynamic_axis_num(i_s) == 0 and _get_dynamic_axis_num(l_s) == 1:
f_stats.append(C.user_function(ConvertToStatic(l_s, batch_size=i_s.shape[0])))
else:
f_stats.append(l_s)

return last_output, final_output, f_stats


def has_seq_axis(x):
Expand Down Expand Up @@ -2136,6 +2144,65 @@ def backward(self, state, root_gradients):
(num_old_batch,) + self.from_shape))


class ConvertToBatch(C.ops.functions.UserFunction):
    """Exposes the first static axis of a tensor as the CNTK batch axis.

    The declared output drops the leading dimension of the input shape
    and carries the default batch axis as its only dynamic axis. The
    underlying data is handed through untouched in both directions.
    This may be replaced by a native CNTK operation later.

    # Arguments
        input: a cntk variable (parameter/constant).
        name: name of this node.
    """

    def __init__(self, input, name='convert_to_batch'):
        super(ConvertToBatch, self).__init__(
            [input], as_numpy=False, name=name)

    def infer_outputs(self):
        # Output shape is the input shape minus its first dimension;
        # that dimension is represented by the batch axis instead.
        source = self.inputs[0]
        output = C.output_variable(
            source.shape[1:],
            source.dtype,
            [C.Axis.default_batch_axis()])
        return [output]

    def forward(self, arguments, device=None, outputs_to_retain=None):
        # No state is kept for the backward pass, hence the leading None.
        result = C.cntk_py.Value(arguments.data())
        return None, result

    def backward(self, state, root_gradients):
        # Gradients flow back unchanged, re-wrapped in a new Value.
        gradient = C.cntk_py.Value(root_gradients.data())
        return gradient


class ConvertToStatic(C.ops.functions.UserFunction):
    """Exposes the CNTK batch axis of a tensor as a leading static axis.

    The declared output has shape `(batch_size,) + input.shape` and no
    dynamic axes. The underlying data is handed through untouched in
    both directions. This may be replaced by a native CNTK operation
    later.

    # Arguments
        input: a cntk tensor which has batch axis.
        batch_size: size of batch axis.
        name: name of this node.
    """

    def __init__(self, input, batch_size, name='convert_to_static'):
        super(ConvertToStatic, self).__init__(
            [input], as_numpy=False, name=name)
        # Static output shape: the batch size prepended to the input shape.
        leading = (batch_size,)
        self.target_shape = leading + input.shape

    def infer_outputs(self):
        # An empty dynamic-axes list: the output is fully static.
        source = self.inputs[0]
        output = C.output_variable(
            self.target_shape,
            source.dtype,
            [])
        return [output]

    def forward(self, arguments, device=None, outputs_to_retain=None):
        # No state is kept for the backward pass, hence the leading None.
        result = C.cntk_py.Value(arguments.data())
        return None, result

    def backward(self, state, root_gradients):
        # Gradients flow back unchanged, re-wrapped in a new Value.
        gradient = C.cntk_py.Value(root_gradients.data())
        return gradient


class LambdaFunc(C.ops.functions.UserFunction):
def __init__(self,
arg,
Expand Down
2 changes: 0 additions & 2 deletions keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ def __init__(self, return_sequences=False,
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
if K.backend() == 'cntk' and stateful:
raise ValueError('Stateful RNN is not currently supported with CNTK.')

self.stateful = stateful
self.unroll = unroll
Expand Down
79 changes: 40 additions & 39 deletions tests/keras/layers/convolutional_recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,46 +43,47 @@ def test_convolutional_recurrent():
if data_format == 'channels_first' or return_sequences:
continue

# cntk doesn't support statefulness on LSTM yet, will enable it on cntk later
# Tests for statefulness
model = Sequential()
kwargs = {'data_format': data_format,
'return_sequences': return_sequences,
'filters': filters,
'kernel_size': (num_row, num_col),
'stateful': True,
'batch_input_shape': inputs.shape,
'padding': 'same'}
layer = convolutional_recurrent.ConvLSTM2D(**kwargs)

model.add(layer)
model.compile(optimizer='sgd', loss='mse')
out1 = model.predict(np.ones_like(inputs))

# train once so that the states change
model.train_on_batch(np.ones_like(inputs),
np.random.random(out1.shape))
out2 = model.predict(np.ones_like(inputs))

# if the state is not reset, output should be different
assert(out1.max() != out2.max())

# check that output changes after states are reset
# (even though the model itself didn't change)
layer.reset_states()
out3 = model.predict(np.ones_like(inputs))
assert(out2.max() != out3.max())

# check that container-level reset_states() works
model.reset_states()
out4 = model.predict(np.ones_like(inputs))
assert_allclose(out3, out4, atol=1e-5)

# check that the call to `predict` updated the states
out5 = model.predict(np.ones_like(inputs))
assert(out4.max() != out5.max())

# cntk doesn't support eval convolution with static
# variable, will enable it later
if K.backend() != 'cntk':
# Tests for statefulness
model = Sequential()
kwargs = {'data_format': data_format,
'return_sequences': return_sequences,
'filters': filters,
'kernel_size': (num_row, num_col),
'stateful': True,
'batch_input_shape': inputs.shape,
'padding': 'same'}
layer = convolutional_recurrent.ConvLSTM2D(**kwargs)

model.add(layer)
model.compile(optimizer='sgd', loss='mse')
out1 = model.predict(np.ones_like(inputs))

# train once so that the states change
model.train_on_batch(np.ones_like(inputs),
np.random.random(out1.shape))
out2 = model.predict(np.ones_like(inputs))

# if the state is not reset, output should be different
assert(out1.max() != out2.max())

# check that output changes after states are reset
# (even though the model itself didn't change)
layer.reset_states()
out3 = model.predict(np.ones_like(inputs))
assert(out2.max() != out3.max())

# check that container-level reset_states() works
model.reset_states()
out4 = model.predict(np.ones_like(inputs))
assert_allclose(out3, out4, atol=1e-5)

# check that the call to `predict` updated the states
out5 = model.predict(np.ones_like(inputs))
assert(out4.max() != out5.max())

# check regularizers
kwargs = {'data_format': data_format,
'return_sequences': return_sequences,
Expand Down
10 changes: 2 additions & 8 deletions tests/keras/layers/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def test_dynamic_behavior(layer_class):


@rnn_test
@pytest.mark.skipif(K.backend() == 'cntk', reason='Stateful is not supported with CNTK')
def test_stateful_invalid_use(layer_class):
layer = layer_class(units,
stateful=True,
Expand Down Expand Up @@ -101,7 +100,7 @@ def test_implementation_mode(layer_class):

@rnn_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason="cntk does not support stateful RNN yet")
reason="cntk does not support mask on RNN yet")
def test_statefulness(layer_class):
model = Sequential()
model.add(embeddings.Embedding(embedding_num, embedding_dim,
Expand Down Expand Up @@ -197,8 +196,7 @@ def test_masking_layer():

@rnn_test
def test_from_config(layer_class):
# cntk does not support stateful yet.
stateful_flags = (False, True) if K.backend() != 'cntk' else (False,)
stateful_flags = (False, True)
for stateful in stateful_flags:
l1 = layer_class(units=1, stateful=stateful)
l2 = layer_class.from_config(l1.get_config())
Expand Down Expand Up @@ -249,8 +247,6 @@ def test_specify_initial_state_non_keras_tensor(layer_class):


@rnn_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason="cntk does not support stateful RNN yet")
def test_reset_states_with_values(layer_class):
num_states = 2 if layer_class is recurrent.LSTM else 1

Expand Down Expand Up @@ -299,8 +295,6 @@ def test_specify_state_with_masking(layer_class):


@rnn_test
@pytest.mark.skipif((K.backend() == 'cntk'),
reason="cntk does not support stateful RNN yet")
def test_return_state(layer_class):
num_states = 2 if layer_class is recurrent.LSTM else 1

Expand Down

0 comments on commit 09b97e9

Please sign in to comment.