attention_decoder.py from a public repo
import tensorflow as tf
from keras import backend as K
from keras import regularizers, constraints, initializers, activations
from keras.layers.recurrent import Recurrent
from keras.engine import InputSpec


tfPrint = lambda d, T: tf.Print(input_=T, data=[T, tf.shape(T)], message=d)


def time_distributed_dense(x, w, b=None, dropout=None,
                           input_dim=None, output_dim=None, timesteps=None):
    '''Apply y.w + b for every temporal slice y of x.
    '''
    if not input_dim:
        # won't work with TensorFlow
        input_dim = K.shape(x)[2]
    if not timesteps:
        # won't work with TensorFlow
        timesteps = K.shape(x)[1]
    if not output_dim:
        # won't work with TensorFlow
        output_dim = K.shape(w)[1]

    if dropout:
        # apply the same dropout pattern at every timestep
        ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim)))
        dropout_matrix = K.dropout(ones, dropout)
        expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps)
        x *= expanded_dropout_matrix

    # collapse the time dimension and the batch dimension together
    x = K.reshape(x, (-1, input_dim))

    x = K.dot(x, w)
    if b is not None:
        x = x + b
    # reshape back to a 3D tensor
    x = K.reshape(x, (-1, timesteps, output_dim))
    return x
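
# A minimal, commented sketch (not part of the original file) of how
# time_distributed_dense could be exercised on its own; the shapes and
# variable names here are purely illustrative assumptions:
#
#     import numpy as np
#     x = K.placeholder(shape=(None, 10, 8))   # (batch, timesteps, input_dim)
#     w = K.variable(np.random.rand(8, 4))     # shared kernel (input_dim, output_dim)
#     b = K.variable(np.zeros(4))              # shared bias (output_dim,)
#     y = time_distributed_dense(x, w, b=b,
#                                input_dim=8, output_dim=4, timesteps=10)
#     # y is a (batch, 10, 4) tensor: one dense projection shared across time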


class AttentionDecoder(Recurrent):

    def __init__(self, units, output_dim,
                 activation='tanh',
                 return_probabilities=False,
                 name='AttentionDecoder',
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """
        Implements an AttentionDecoder that takes in a sequence encoded by an
        encoder and outputs the decoded states.
        :param units: dimension of the hidden state and the attention matrices
        :param output_dim: the number of labels in the output space

        references:
            Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio.
            "Neural machine translation by jointly learning to align and translate."
            arXiv preprint arXiv:1409.0473 (2014).
        """
        self.units = units
        self.output_dim = output_dim
        self.return_probabilities = return_probabilities
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        # no separate recurrent_regularizer argument: the kernel regularizer
        # is reused for the recurrent weights
        self.recurrent_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        # likewise, the kernel constraint is reused for the recurrent weights
        self.recurrent_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        super(AttentionDecoder, self).__init__(**kwargs)
        self.name = name
        self.return_sequences = True  # must return sequences

    def build(self, input_shape):
        """
        See Appendix 2 of Bahdanau 2014, arXiv:1409.0473
        for model details that correspond to the matrices here.
        """

        self.batch_size, self.timesteps, self.input_dim = input_shape

        if self.stateful:
            super(AttentionDecoder, self).reset_states()

        self.states = [None, None]  # y, s

        """
        Matrices for creating the context vector
        """
        self.V_a = self.add_weight(shape=(self.units,),
                                   name='V_a',
                                   initializer=self.kernel_initializer,
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
        self.W_a = self.add_weight(shape=(self.units, self.units),
                                   name='W_a',
                                   initializer=self.kernel_initializer,
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
        self.U_a = self.add_weight(shape=(self.input_dim, self.units),
                                   name='U_a',
                                   initializer=self.kernel_initializer,
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
        self.b_a = self.add_weight(shape=(self.units,),
                                   name='b_a',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)
        """
        Matrices for the r (reset) gate
        """
        self.C_r = self.add_weight(shape=(self.input_dim, self.units),
                                   name='C_r',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_r = self.add_weight(shape=(self.units, self.units),
                                   name='U_r',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_r = self.add_weight(shape=(self.output_dim, self.units),
                                   name='W_r',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_r = self.add_weight(shape=(self.units, ),
                                   name='b_r',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        """
        Matrices for the z (update) gate
        """
        self.C_z = self.add_weight(shape=(self.input_dim, self.units),
                                   name='C_z',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_z = self.add_weight(shape=(self.units, self.units),
                                   name='U_z',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_z = self.add_weight(shape=(self.output_dim, self.units),
                                   name='W_z',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_z = self.add_weight(shape=(self.units, ),
                                   name='b_z',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)
        """
        Matrices for the proposal
        """
        self.C_p = self.add_weight(shape=(self.input_dim, self.units),
                                   name='C_p',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_p = self.add_weight(shape=(self.units, self.units),
                                   name='U_p',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_p = self.add_weight(shape=(self.output_dim, self.units),
                                   name='W_p',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_p = self.add_weight(shape=(self.units, ),
                                   name='b_p',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)
        """
        Matrices for making the final prediction vector
        """
        self.C_o = self.add_weight(shape=(self.input_dim, self.output_dim),
                                   name='C_o',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_o = self.add_weight(shape=(self.units, self.output_dim),
                                   name='U_o',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_o = self.add_weight(shape=(self.output_dim, self.output_dim),
                                   name='W_o',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_o = self.add_weight(shape=(self.output_dim, ),
                                   name='b_o',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # For creating the initial state:
        self.W_s = self.add_weight(shape=(self.input_dim, self.units),
                                   name='W_s',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)

        self.input_spec = [
            InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))]
        self.built = True

    def call(self, x):
        # store the whole sequence so we can "attend" to it at each timestep
        self.x_seq = x

        # apply a dense layer over the time dimension of the sequence
        # do it here because it doesn't depend on any previous steps
        # therefore we can save computation time:
        self._uxpb = time_distributed_dense(self.x_seq, self.U_a, b=self.b_a,
                                            input_dim=self.input_dim,
                                            timesteps=self.timesteps,
                                            output_dim=self.units)

        return super(AttentionDecoder, self).call(x)

    def get_initial_state(self, inputs):
        # apply the matrix on the first time step to get the initial s0.
        s0 = activations.tanh(K.dot(inputs[:, 0], self.W_s))

        # as in keras.layers.recurrent: initialize a zero vector of shape
        # (batch_size, output_dim)
        y0 = K.zeros_like(inputs)  # (samples, timesteps, input_dims)
        y0 = K.sum(y0, axis=(1, 2))  # (samples, )
        y0 = K.expand_dims(y0)  # (samples, 1)
        y0 = K.tile(y0, [1, self.output_dim])  # (samples, output_dim)

        return [y0, s0]
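
    # The recurrence in `step` below follows Bahdanau et al. (2014):
    # attention energies  e_tj = V_a^T tanh(W_a s_{t-1} + U_a h_j + b_a),
    # attention weights   a_tj = exp(e_tj) / sum_k exp(e_tk),
    # context vector      c_t  = sum_j a_tj * h_j,
    # which feed a GRU-style cell (reset gate r_t, update gate z_t,
    # proposal state s~_t) and a softmax readout producing y_t.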

    def step(self, x, states):

        ytm, stm = states

        # repeat the hidden state to the length of the sequence
        _stm = K.repeat(stm, self.timesteps)

        # now multiply the weight matrix with the repeated hidden state
        _Wxstm = K.dot(_stm, self.W_a)

        # calculate the attention probabilities
        # this relates how much other timesteps contributed to this one.
        et = K.dot(activations.tanh(_Wxstm + self._uxpb),
                   K.expand_dims(self.V_a))
        at = K.exp(et)
        at_sum = K.sum(at, axis=1)
        at_sum_repeated = K.repeat(at_sum, self.timesteps)
        at /= at_sum_repeated  # tensor of shape (batchsize, timesteps, 1)

        # calculate the context vector
        context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1)

        # ~~~> calculate the new hidden state
        # first calculate the "r" (reset) gate:
        rt = activations.sigmoid(
            K.dot(ytm, self.W_r)
            + K.dot(stm, self.U_r)
            + K.dot(context, self.C_r)
            + self.b_r)

        # now calculate the "z" (update) gate
        zt = activations.sigmoid(
            K.dot(ytm, self.W_z)
            + K.dot(stm, self.U_z)
            + K.dot(context, self.C_z)
            + self.b_z)

        # calculate the proposal hidden state:
        s_tp = activations.tanh(
            K.dot(ytm, self.W_p)
            + K.dot((rt * stm), self.U_p)
            + K.dot(context, self.C_p)
            + self.b_p)

        # new hidden state:
        st = (1 - zt) * stm + zt * s_tp

        yt = activations.softmax(
            K.dot(ytm, self.W_o)
            + K.dot(stm, self.U_o)
            + K.dot(context, self.C_o)
            + self.b_o)

        if self.return_probabilities:
            return at, [yt, st]
        else:
            return yt, [yt, st]

    def compute_output_shape(self, input_shape):
        """
        For Keras internal compatibility checking.
        """
        if self.return_probabilities:
            return (None, self.timesteps, self.timesteps)
        else:
            return (None, self.timesteps, self.output_dim)

    def get_config(self):
        """
        For rebuilding models at load time.
        """
        config = {
            'output_dim': self.output_dim,
            'units': self.units,
            'return_probabilities': self.return_probabilities
        }
        base_config = super(AttentionDecoder, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
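
Below is a minimal usage sketch (not part of the original file) showing how this layer might be wired into a model. It assumes a Keras 2.0-era environment where keras.layers.recurrent.Recurrent is available; the sequence length, feature size, and unit counts are illustrative only.

from keras.models import Sequential
from keras.layers import LSTM

# Hypothetical shapes: input sequences of length 20 with 50 features per step,
# decoded into one of 6 labels at every timestep.
model = Sequential()
model.add(LSTM(150, input_shape=(20, 50), return_sequences=True))  # encoder
model.add(AttentionDecoder(150, 6))  # units=150, output_dim=6
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()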