Skip to content

Commit 271e2b7

Browse files
committed
Add masking support to the activation layer
1 parent 2384aa4 commit 271e2b7

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

keras/activations.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ def softmax(x):
1111
def time_distributed_softmax(x, mask_val=default_mask_val):
1212
xshape = x.shape
1313
X = x.reshape((xshape[0] * xshape[1], xshape[2]))
14-
mask = get_mask(X, mask_val)
15-
r = mask * T.nnet.softmax(X) + (1 - mask) * mask_val
16-
return r.reshape(xshape)
14+
return T.nnet.softmax(X).reshape(xshape)
1715

1816
def softplus(x):
1917
return T.nnet.softplus(x)

keras/layers/core.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,17 @@ class Activation(Layer):
168168
'''
169169
Apply an activation function to an output.
170170
'''
171-
def __init__(self, activation, target=0, beta=0.1):
171+
def __init__(self, activation, target=0, beta=0.1, mask_val=default_mask_val):
172172
super(Activation,self).__init__()
173173
self.activation = activations.get(activation)
174174
self.target = target
175175
self.beta = beta
176+
self.mask_val = shared_scalar(mask_val)
176177

177178
def get_output(self, train):
178179
X = self.get_input(train)
179-
return self.activation(X)
180+
mask = get_mask(X, self.mask_val)
181+
return mask * self.activation(X) + (1 - mask) * self.mask_val
180182

181183
def get_config(self):
182184
return {"name":self.__class__.__name__,

tests/manual/check_masked_recurrent.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,12 @@
117117
model2.add(TimeDistributedDense(4,4))
118118
model2.add(Activation('time_distributed_softmax'))
119119
model2.add(LSTM(4,4, return_sequences=True))
120+
model2.add(Activation('tanh'))
120121
model2.add(GRU(4,4, activation='softmax', return_sequences=True))
121122
model2.add(SimpleDeepRNN(4,4, depth=2, activation='relu', return_sequences=True))
122123
model2.add(SimpleRNN(4,4, activation='relu', return_sequences=False))
123124
model2.compile(loss='categorical_crossentropy',
124-
optimizer='rmsprop', theano_mode=theano.compile.mode.FAST_RUN)
125+
optimizer='rmsprop', theano_mode=theano.compile.mode.FAST_COMPILE)
125126
print("Compiled model2")
126127

127128
X2 = np.random.random_integers(1, 4, size=(2,5))

0 commit comments

Comments
 (0)