Skip to content

Commit

Permalink
cntk backend: fix the reversed rnn bug (keras-team#7593)
Browse files Browse the repository at this point in the history
* fix the reversed rnn bug

* udpate error message.

* Fix error msg
  • Loading branch information
souptc authored and fchollet committed Aug 11, 2017
1 parent c2b844b commit 3537381
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,15 @@ def rnn(step_function, inputs, initial_states,
initial.append(s)

need_convert = not has_seq_axis(inputs)
if go_backwards and need_convert is False:
raise NotImplementedError('CNTK Backend: `go_backwards` is not supported with '
'variable-length sequences. Please specify a '
'static length for your sequences.')

if need_convert:
if go_backwards:
inputs = reverse(inputs, 1)

inputs = C.to_sequence(inputs)

j = 0
Expand All @@ -1327,6 +1335,8 @@ def rnn(step_function, inputs, initial_states,
j += 1

if mask is not None and not has_seq_axis(mask):
if go_backwards:
mask = reverse(mask, 1)
if len(int_shape(mask)) == 2:
mask = expand_dims(mask)
mask = C.to_sequence_like(mask, inputs)
Expand All @@ -1339,10 +1349,7 @@ def _recurrence(x, states, m):
place_holders = [C.placeholder(dynamic_axes=x.dynamic_axes) for _ in states]
past_values = []
for s, p in zip(states, place_holders):
past_values.append(
C.sequence.past_value(
p, s) if go_backwards is False else C.sequence.future_value(
p, s))
past_values.append(C.sequence.past_value(p, s))
new_output, new_states = step_function(
x, tuple(past_values) + tuple(constants))
if m is not None:
Expand Down

0 comments on commit 3537381

Please sign in to comment.