From 9c04dd80d2141284333ba22e838f991cdeacb5d7 Mon Sep 17 00:00:00 2001 From: se-yi Date: Wed, 5 Dec 2018 15:08:06 -0500 Subject: [PATCH] attention_decoder.py from a public repo --- models/attention_decoder.py | 335 ++++++++++++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 models/attention_decoder.py diff --git a/models/attention_decoder.py b/models/attention_decoder.py new file mode 100644 index 0000000..b9b81aa --- /dev/null +++ b/models/attention_decoder.py @@ -0,0 +1,335 @@ +import tensorflow as tf +from keras import backend as K +from keras import regularizers, constraints, initializers, activations +from keras.layers.recurrent import Recurrent +from keras.engine import InputSpec + + +tfPrint = lambda d, T: tf.Print(input_=T, data=[T, tf.shape(T)], message=d) + + +def time_distributed_dense(x, w, b=None, dropout=None, + input_dim=None, output_dim=None, timesteps=None): + '''Apply y.w + b for every temporal slice y of x. + ''' + if not input_dim: + # won't work with TensorFlow + input_dim = K.shape(x)[2] + if not timesteps: + # won't work with TensorFlow + timesteps = K.shape(x)[1] + if not output_dim: + # won't work with TensorFlow + output_dim = K.shape(w)[1] + + if dropout: + # apply the same dropout pattern at every timestep + ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim))) + dropout_matrix = K.dropout(ones, dropout) + expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps) + x *= expanded_dropout_matrix + + # collapse time dimension and batch dimension together + x = K.reshape(x, (-1, input_dim)) + + x = K.dot(x, w) + if b: + x = x + b + # reshape to 3D tensor + x = K.reshape(x, (-1, timesteps, output_dim)) + return x + +class AttentionDecoder(Recurrent): + + def __init__(self, units, output_dim, + activation='tanh', + return_probabilities=False, + name='AttentionDecoder', + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + """ + Implements an AttentionDecoder that takes in a sequence encoded by an + encoder and outputs the decoded states + :param units: dimension of the hidden state and the attention matrices + :param output_dim: the number of labels in the output space + + references: + Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. + "Neural machine translation by jointly learning to align and translate." + arXiv preprint arXiv:1409.0473 (2014). + """ + self.units = units + self.output_dim = output_dim + self.return_probabilities = return_probabilities + self.activation = activations.get(activation) + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + super(AttentionDecoder, self).__init__(**kwargs) + self.name = name + self.return_sequences = True # must return sequences + + def build(self, input_shape): + """ + See Appendix 2 of Bahdanau 2014, arXiv:1409.0473 + for model details that correspond to the matrices here. + """ + + self.batch_size, self.timesteps, self.input_dim = input_shape + + if self.stateful: + super(AttentionDecoder, self).reset_states() + + self.states = [None, None] # y, s + + """ + Matrices for creating the context vector + """ + + self.V_a = self.add_weight(shape=(self.units,), + name='V_a', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.W_a = self.add_weight(shape=(self.units, self.units), + name='W_a', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.U_a = self.add_weight(shape=(self.input_dim, self.units), + name='U_a', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.b_a = self.add_weight(shape=(self.units,), + name='b_a', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + """ + Matrices for the r (reset) gate + """ + self.C_r = self.add_weight(shape=(self.input_dim, self.units), + name='C_r', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.U_r = self.add_weight(shape=(self.units, self.units), + name='U_r', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.W_r = self.add_weight(shape=(self.output_dim, self.units), + name='W_r', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.b_r = self.add_weight(shape=(self.units, ), + name='b_r', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + + """ + Matrices for the z (update) gate + """ + self.C_z = self.add_weight(shape=(self.input_dim, self.units), + name='C_z', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.U_z = self.add_weight(shape=(self.units, self.units), + name='U_z', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.W_z = self.add_weight(shape=(self.output_dim, self.units), + name='W_z', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.b_z = self.add_weight(shape=(self.units, ), + name='b_z', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + """ + Matrices for the proposal + """ + self.C_p = self.add_weight(shape=(self.input_dim, self.units), + name='C_p', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.U_p = self.add_weight(shape=(self.units, self.units), + name='U_p', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.W_p = self.add_weight(shape=(self.output_dim, self.units), + name='W_p', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.b_p = self.add_weight(shape=(self.units, ), + name='b_p', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + """ + Matrices for making the final prediction vector + """ + self.C_o = self.add_weight(shape=(self.input_dim, self.output_dim), + name='C_o', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.U_o = self.add_weight(shape=(self.units, self.output_dim), + name='U_o', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.W_o = self.add_weight(shape=(self.output_dim, self.output_dim), + name='W_o', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + self.b_o = self.add_weight(shape=(self.output_dim, ), + name='b_o', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + + # For creating the initial state: + self.W_s = self.add_weight(shape=(self.input_dim, self.units), + name='W_s', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + + self.input_spec = [ + InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))] + self.built = True + + def call(self, x): + # store the whole sequence so we can "attend" to it at each timestep + self.x_seq = x + + # apply the a dense layer over the time dimension of the sequence + # do it here because it doesn't depend on any previous steps + # thefore we can save computation time: + self._uxpb = time_distributed_dense(self.x_seq, self.U_a, b=self.b_a, + input_dim=self.input_dim, + timesteps=self.timesteps, + output_dim=self.units) + + return super(AttentionDecoder, self).call(x) + + def get_initial_state(self, inputs): + # apply the matrix on the first time step to get the initial s0. + s0 = activations.tanh(K.dot(inputs[:, 0], self.W_s)) + + # from keras.layers.recurrent to initialize a vector of (batchsize, + # output_dim) + y0 = K.zeros_like(inputs) # (samples, timesteps, input_dims) + y0 = K.sum(y0, axis=(1, 2)) # (samples, ) + y0 = K.expand_dims(y0) # (samples, 1) + y0 = K.tile(y0, [1, self.output_dim]) + + return [y0, s0] + + def step(self, x, states): + + ytm, stm = states + + # repeat the hidden state to the length of the sequence + _stm = K.repeat(stm, self.timesteps) + + # now multiplty the weight matrix with the repeated hidden state + _Wxstm = K.dot(_stm, self.W_a) + + # calculate the attention probabilities + # this relates how much other timesteps contributed to this one. + et = K.dot(activations.tanh(_Wxstm + self._uxpb), + K.expand_dims(self.V_a)) + at = K.exp(et) + at_sum = K.sum(at, axis=1) + at_sum_repeated = K.repeat(at_sum, self.timesteps) + at /= at_sum_repeated # vector of size (batchsize, timesteps, 1) + + # calculate the context vector + context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1) + # ~~~> calculate new hidden state + # first calculate the "r" gate: + + rt = activations.sigmoid( + K.dot(ytm, self.W_r) + + K.dot(stm, self.U_r) + + K.dot(context, self.C_r) + + self.b_r) + + # now calculate the "z" gate + zt = activations.sigmoid( + K.dot(ytm, self.W_z) + + K.dot(stm, self.U_z) + + K.dot(context, self.C_z) + + self.b_z) + + # calculate the proposal hidden state: + s_tp = activations.tanh( + K.dot(ytm, self.W_p) + + K.dot((rt * stm), self.U_p) + + K.dot(context, self.C_p) + + self.b_p) + + # new hidden state: + st = (1-zt)*stm + zt * s_tp + + yt = activations.softmax( + K.dot(ytm, self.W_o) + + K.dot(stm, self.U_o) + + K.dot(context, self.C_o) + + self.b_o) + + if self.return_probabilities: + return at, [yt, st] + else: + return yt, [yt, st] + + def compute_output_shape(self, input_shape): + """ + For Keras internal compatability checking + """ + if self.return_probabilities: + return (None, self.timesteps, self.timesteps) + else: + return (None, self.timesteps, self.output_dim) + + def get_config(self): + """ + For rebuilding models on load time. + """ + config = { + 'output_dim': self.output_dim, + 'units': self.units, + 'return_probabilities': self.return_probabilities + } + base_config = super(AttentionDecoder, self).get_config() + return dict(list(base_config.items()) + list(config.items()))