
Commit e351922

Add masking to GRU and LSTM

1 parent 3f3a1f0


2 files changed: +53, -15 lines

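In outline, the commit threads a per-timestep binary mask through theano.scan: each step zeroes the carried hidden state (and, for LSTM, the cell state) wherever the previous timestep was masked, and after the scan every masked position of the output is overwritten with mask_val. A minimal NumPy sketch of that post-scan rule (the function name apply_output_mask is illustrative, not from the diff):

import numpy as np

# Post-scan masking rule used by both layers in this commit:
# keep real outputs where mask == 1, force mask_val where mask == 0.
def apply_output_mask(outputs, mask, mask_val=0.0):
    return mask * outputs + (1 - mask) * mask_val

outputs = np.random.rand(5, 2, 4)   # (timesteps, nb_samples, output_dim)
mask = np.ones((5, 2, 1))
mask[0] = 0.0                       # first timestep is padding
masked = apply_output_mask(outputs, mask)
assert np.all(masked[0] == 0.0)
assert np.all(masked[1:] == outputs[1:])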

keras/layers/recurrent.py (+28, -14)
@@ -198,7 +198,7 @@ class GRU(Layer):
     def __init__(self, input_dim, output_dim=128,
         init='glorot_uniform', inner_init='orthogonal',
         activation='sigmoid', inner_activation='hard_sigmoid',
-        weights=None, truncate_gradient=-1, return_sequences=False):
+        weights=None, truncate_gradient=-1, return_sequences=False, mask_val=default_mask_val):
 
         super(GRU,self).__init__()
         self.input_dim = input_dim
@@ -211,6 +211,7 @@ def __init__(self, input_dim, output_dim=128,
         self.activation = activations.get(activation)
         self.inner_activation = activations.get(inner_activation)
         self.input = T.tensor3()
+        self.mask_val = shared_scalar(mask_val)
 
         self.W_z = self.init((self.input_dim, self.output_dim))
         self.U_z = self.inner_init((self.output_dim, self.output_dim))
@@ -234,29 +235,35 @@ def __init__(self, input_dim, output_dim=128,
         self.set_weights(weights)
 
     def _step(self,
-              xz_t, xr_t, xh_t,
+              xz_t, xr_t, xh_t, mask_tm1,
               h_tm1,
               u_z, u_r, u_h):
-        z = self.inner_activation(xz_t + T.dot(h_tm1, u_z))
-        r = self.inner_activation(xr_t + T.dot(h_tm1, u_r))
-        hh_t = self.activation(xh_t + T.dot(r * h_tm1, u_h))
-        h_t = z * h_tm1 + (1 - z) * hh_t
+        h_mask_tm1 = mask_tm1 * h_tm1
+        z = self.inner_activation(xz_t + T.dot(h_mask_tm1, u_z))
+        r = self.inner_activation(xr_t + T.dot(h_mask_tm1, u_r))
+        hh_t = self.activation(xh_t + T.dot(r * h_mask_tm1, u_h))
+        h_t = z * h_mask_tm1 + (1 - z) * hh_t
+        #return theano.printing.Print("h_t")(h_t)
         return h_t
 
     def get_output(self, train):
         X = self.get_input(train)
         X = X.dimshuffle((1,0,2))
+        mask, padded_mask = get_mask(X, self.mask_val, steps_back=1)
 
         x_z = T.dot(X, self.W_z) + self.b_z
         x_r = T.dot(X, self.W_r) + self.b_r
         x_h = T.dot(X, self.W_h) + self.b_h
         outputs, updates = theano.scan(
             self._step,
-            sequences=[x_z, x_r, x_h],
+            sequences=[x_z, x_r, x_h, padded_mask],
             outputs_info=T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
             non_sequences=[self.U_z, self.U_r, self.U_h],
             truncate_gradient=self.truncate_gradient
         )
+
+        outputs = mask * outputs + (1 - mask) * self.mask_val
+
         if self.return_sequences:
             return outputs.dimshuffle((1,0,2))
         return outputs[-1]
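get_mask is defined elsewhere on this branch (it is not part of this diff); from its call site it presumably returns the mask aligned to X plus a copy shifted back one step (steps_back=1), so each scan step receives the mask of the previous timestep. The masked GRU step then behaves as in this NumPy sketch (a sketch only: sigmoid stands in for the configured activations, and mask_tm1 is assumed to have shape (nb_samples, 1) so it broadcasts over units):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# NumPy sketch of the masked GRU step above. Where mask_tm1 == 0 the
# previous hidden state is zeroed, so the recurrence restarts after padding.
def gru_step(xz_t, xr_t, xh_t, mask_tm1, h_tm1, u_z, u_r, u_h):
    h_mask_tm1 = mask_tm1 * h_tm1
    z = sigmoid(xz_t + h_mask_tm1 @ u_z)            # update gate
    r = sigmoid(xr_t + h_mask_tm1 @ u_r)            # reset gate
    hh_t = np.tanh(xh_t + (r * h_mask_tm1) @ u_h)   # candidate state
    return z * h_mask_tm1 + (1 - z) * hh_t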
@@ -302,13 +309,14 @@ class LSTM(Layer):
     def __init__(self, input_dim, output_dim=128,
         init='glorot_uniform', inner_init='orthogonal',
         activation='tanh', inner_activation='hard_sigmoid',
-        weights=None, truncate_gradient=-1, return_sequences=False):
+        weights=None, truncate_gradient=-1, return_sequences=False, mask_val=default_mask_val):
 
         super(LSTM,self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.truncate_gradient = truncate_gradient
         self.return_sequences = return_sequences
+        self.mask_val = shared_scalar(mask_val)
 
         self.init = initializations.get(init)
         self.inner_init = initializations.get(inner_init)
@@ -343,19 +351,23 @@ def __init__(self, input_dim, output_dim=128,
         self.set_weights(weights)
 
     def _step(self,
-              xi_t, xf_t, xo_t, xc_t,
+              xi_t, xf_t, xo_t, xc_t, mask_tm1,
               h_tm1, c_tm1,
               u_i, u_f, u_o, u_c):
-        i_t = self.inner_activation(xi_t + T.dot(h_tm1, u_i))
-        f_t = self.inner_activation(xf_t + T.dot(h_tm1, u_f))
-        c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1, u_c))
-        o_t = self.inner_activation(xo_t + T.dot(h_tm1, u_o))
+        h_mask_tm1 = mask_tm1 * h_tm1
+        c_mask_tm1 = mask_tm1 * c_tm1
+
+        i_t = self.inner_activation(xi_t + T.dot(h_mask_tm1, u_i))
+        f_t = self.inner_activation(xf_t + T.dot(h_mask_tm1, u_f))
+        c_t = f_t * c_mask_tm1 + i_t * self.activation(xc_t + T.dot(h_mask_tm1, u_c))
+        o_t = self.inner_activation(xo_t + T.dot(h_mask_tm1, u_o))
         h_t = o_t * self.activation(c_t)
         return h_t, c_t
 
     def get_output(self, train):
         X = self.get_input(train)
         X = X.dimshuffle((1,0,2))
+        mask, padded_mask = get_mask(X, self.mask_val, steps_back=1)
 
         xi = T.dot(X, self.W_i) + self.b_i
         xf = T.dot(X, self.W_f) + self.b_f
@@ -364,14 +376,16 @@ def get_output(self, train):
 
         [outputs, memories], updates = theano.scan(
             self._step,
-            sequences=[xi, xf, xo, xc],
+            sequences=[xi, xf, xo, xc, padded_mask],
             outputs_info=[
                 T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1),
                 T.unbroadcast(alloc_zeros_matrix(X.shape[1], self.output_dim), 1)
             ],
             non_sequences=[self.U_i, self.U_f, self.U_o, self.U_c],
             truncate_gradient=self.truncate_gradient
         )
+
+        outputs = mask * outputs + (1 - mask) * self.mask_val
         if self.return_sequences:
             return outputs.dimshuffle((1,0,2))
         return outputs[-1]
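The LSTM step masks both carried states before using them, which is the only structural difference from the GRU version; same caveats as the GRU sketch above:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# NumPy sketch of the masked LSTM step: both h_tm1 and c_tm1 are zeroed
# where the previous timestep was padding (mask_tm1 of shape (nb_samples, 1)).
def lstm_step(xi_t, xf_t, xo_t, xc_t, mask_tm1,
              h_tm1, c_tm1, u_i, u_f, u_o, u_c):
    h_mask_tm1 = mask_tm1 * h_tm1
    c_mask_tm1 = mask_tm1 * c_tm1
    i_t = sigmoid(xi_t + h_mask_tm1 @ u_i)          # input gate
    f_t = sigmoid(xf_t + h_mask_tm1 @ u_f)          # forget gate
    c_t = f_t * c_mask_tm1 + i_t * np.tanh(xc_t + h_mask_tm1 @ u_c)
    o_t = sigmoid(xo_t + h_mask_tm1 @ u_o)          # output gate
    return o_t * np.tanh(c_t), c_t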

tests/manual/check_masked_recurrent.py (+25, -1)
@@ -6,7 +6,7 @@
 from keras.models import Sequential
 from keras.layers.core import Dense, Activation, Merge, Dropout, TimeDistributedDense
 from keras.layers.embeddings import Embedding
-from keras.layers.recurrent import SimpleRNN, SimpleDeepRNN
+from keras.layers.recurrent import SimpleRNN, SimpleDeepRNN, LSTM, GRU
 from keras.layers.core import default_mask_val
 import theano
 
@@ -15,6 +15,7 @@
 # (nb_samples, timesteps, dimensions)
 X = np.random.random_integers(1, 4, size=(500000,15))
 
+print("About to compile the first model")
 model = Sequential()
 model.add(Embedding(5, 4, zero_is_mask=True))
 model.add(TimeDistributedDense(4,4)) # obviously this is redundant. Just testing.
@@ -108,3 +109,26 @@
 if score < uniform_score*0.9:
     raise Exception('Somehow learned to copy timestep 0 despite masking 1, score %f' % score)
 
+
+# Another testing approach: just initialize models and make sure that
+# prepending zeros doesn't affect their output
+print("About to compile the second model")
+model2 = Sequential()
+model2.add(Embedding(5, 4, zero_is_mask=True))
+model2.add(TimeDistributedDense(4,4)) # obviously this is redundant. Just testing.
+model2.add(LSTM(4,4, return_sequences=True))
+model2.add(GRU(4,4, activation='softmax', return_sequences=True))
+model2.add(SimpleDeepRNN(4,4, depth=2, activation='relu', return_sequences=True))
+model2.add(SimpleRNN(4,4, activation='relu', return_sequences=False))
+model2.compile(loss='categorical_crossentropy',
+    optimizer='rmsprop', theano_mode=theano.compile.mode.FAST_RUN)
+print("Compiled model2")
+
+X2 = np.random.random_integers(1, 4, size=(1,5))
+ref = model2.predict(X2)
+for pre_zeros in range(1,10):
+    padded = np.concatenate((np.zeros((1, pre_zeros)), X2), axis=1)
+    pred = model2.predict(padded)
+    if not np.allclose(ref, pred):
+        raise Exception("Different result after left-padding %d zeros. Ref: %s, Pred: %s" % (pre_zeros, ref, pred))
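The invariant this second test checks follows directly from the step functions: the first real timestep after padding sees a zeroed carried state (its predecessor was masked), and the masked positions themselves are overwritten with mask_val, so left-padding cannot leak into the result. A self-contained NumPy illustration using a bare-bones tanh recurrence instead of the full GRU/LSTM (all names here are illustrative):

import numpy as np

# Left-padding invariance: running a masked recurrence over [pad, x1, x2]
# ends in the same state as running it over [x1, x2] alone.
rng = np.random.RandomState(0)
W, U = rng.randn(3, 4), rng.randn(4, 4)

def run(xs, mask):
    h = np.zeros(4)
    outs = []
    for t in range(len(xs)):
        m_tm1 = mask[t - 1] if t > 0 else 1.0  # mask of the previous step
        h = np.tanh(xs[t] @ W + (m_tm1 * h) @ U)
        outs.append(mask[t] * h)               # masked steps emit the mask value (0 here)
    return outs

xs = rng.randn(2, 3)
pad = rng.randn(1, 3)   # nonzero padding row: masking must neutralize it
ref = run(xs, mask=[1.0, 1.0])
padded = run(np.vstack([pad, xs]), mask=[0.0, 1.0, 1.0])
assert np.allclose(ref[-1], padded[-1])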
