
Commit

Create attention_keras.py
bojone authored Jan 6, 2018
1 parent 012f6a1 commit 2656381
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions attention_keras.py
@@ -0,0 +1,74 @@
#! -*- coding: utf-8 -*-

from keras import backend as K
from keras.engine.topology import Layer

class Attention(Layer):
    """Multi-head attention layer; inputs are passed as a list [Q_seq, K_seq, V_seq],
    optionally followed by Q_len and V_len for masking padded positions."""

    def __init__(self, nb_head, size_per_head, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head*size_per_head
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        self.WQ = self.add_weight(name='WQ',
                                  shape=(input_shape[0][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WK = self.add_weight(name='WK',
                                  shape=(input_shape[1][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WV = self.add_weight(name='WV',
                                  shape=(input_shape[2][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        super(Attention, self).build(input_shape)

    def Mask(self, inputs, seq_len, mode='mul'):
        # Mask out the positions beyond each sequence's true length, either by
        # zeroing them ('mul') or by subtracting a large constant ('add') so
        # they vanish after the softmax.
        if seq_len is None:
            return inputs
        else:
            mask = K.one_hot(seq_len[:, 0], K.shape(inputs)[1])
            mask = 1 - K.cumsum(mask, 1)
            for _ in range(len(inputs.shape) - 2):
                mask = K.expand_dims(mask, 2)
            if mode == 'mul':
                return inputs * mask
            if mode == 'add':
                return inputs - (1 - mask) * 1e12
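    # Worked example of the masking step: for a row whose true length is 3 in a
    # batch padded to length 5, K.one_hot gives [0, 0, 0, 1, 0], K.cumsum gives
    # [0, 0, 0, 1, 1], so mask = [1, 1, 1, 0, 0]. The padded positions are then
    # zeroed out ('mul') or pushed to roughly -1e12 before the softmax ('add'),
    # so they receive effectively zero attention weight.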

    def call(self, x):
        # If only Q_seq, K_seq, V_seq are passed in, no mask is applied.
        # If Q_seq, K_seq, V_seq, Q_len, V_len are all passed in, the positions
        # beyond Q_len / V_len are masked out.
        if len(x) == 3:
            Q_seq, K_seq, V_seq = x
            Q_len, V_len = None, None
        elif len(x) == 5:
            Q_seq, K_seq, V_seq, Q_len, V_len = x
        # Linear projections of Q, K and V, reshaped to (batch, heads, steps, size_per_head)
        Q_seq = K.dot(Q_seq, self.WQ)
        Q_seq = K.reshape(Q_seq, (-1, K.shape(Q_seq)[1], self.nb_head, self.size_per_head))
        Q_seq = K.permute_dimensions(Q_seq, (0, 2, 1, 3))
        K_seq = K.dot(K_seq, self.WK)
        K_seq = K.reshape(K_seq, (-1, K.shape(K_seq)[1], self.nb_head, self.size_per_head))
        K_seq = K.permute_dimensions(K_seq, (0, 2, 1, 3))
        V_seq = K.dot(V_seq, self.WV)
        V_seq = K.reshape(V_seq, (-1, K.shape(V_seq)[1], self.nb_head, self.size_per_head))
        V_seq = K.permute_dimensions(V_seq, (0, 2, 1, 3))
        # Dot-product attention scores, masked along the key dimension, then softmax
        A = K.batch_dot(Q_seq, K_seq, axes=[3, 3])
        A = K.permute_dimensions(A, (0, 3, 2, 1))
        A = self.Mask(A, V_len, 'add')
        A = K.permute_dimensions(A, (0, 3, 2, 1))
        A = K.softmax(A)
        # Weighted sum of values, merged back to (batch, steps, output_dim), then masked
        O_seq = K.batch_dot(A, V_seq, axes=[3, 2])
        O_seq = K.permute_dimensions(O_seq, (0, 2, 1, 3))
        O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))
        O_seq = self.Mask(O_seq, Q_len, 'mul')
        return O_seq

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.output_dim)
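A minimal usage sketch, not part of this commit (the vocabulary size of 20000, the 128-dimensional embedding, and the 8x16 head configuration are assumed for illustration): the layer is called on a list [Q_seq, K_seq, V_seq], so plain self-attention simply passes the same tensor three times, and with 8 heads of 16 dimensions each the output width is 8 * 16 = 128.

from keras.models import Model
from keras.layers import Input, Embedding, GlobalAveragePooling1D, Dense
from attention_keras import Attention  # this file

# Hypothetical sizes, for illustration only.
inputs = Input(shape=(None,), dtype='int32')
emb = Embedding(20000, 128)(inputs)          # vocab 20000, width 128 (assumed)
att = Attention(8, 16)([emb, emb, emb])      # self-attention: 8 heads * 16 dims = 128
pooled = GlobalAveragePooling1D()(att)
outputs = Dense(1, activation='sigmoid')(pooled)

model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])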
