Update attention_tf.py
bojone authored Jan 6, 2018
1 parent b837c48 commit 012f6a1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions attention_tf.py
@@ -44,10 +44,13 @@ def Mask(inputs, seq_len, mode='mul'):
 inputs is a tensor of rank 2 or higher, i.e. shaped (batch_size, ..., input_size).
 Only the last dimension goes through the matrix multiplication, i.e. the output is a tensor shaped (batch_size, ..., ouput_size).
 '''
-def Dense(inputs, ouput_size, seq_len=None):
+def Dense(inputs, ouput_size, bias=True, seq_len=None):
     input_size = int(inputs.shape[-1])
     W = tf.Variable(tf.random_uniform([input_size, ouput_size], -0.05, 0.05))
-    b = tf.Variable(tf.random_uniform([ouput_size], -0.05, 0.05))
+    if bias:
+        b = tf.Variable(tf.random_uniform([ouput_size], -0.05, 0.05))
+    else:
+        b = 0
     outputs = tf.matmul(tf.reshape(inputs, (-1, input_size)), W) + b
     outputs = tf.reshape(outputs, \
                          tf.concat([tf.shape(inputs)[:-1], [ouput_size]], 0)
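
For reference, a minimal sketch of how the new bias switch behaves in use (TF 1.x graph mode; the placeholder shape and variable names are assumptions for illustration, not part of the commit):

import tensorflow as tf

# hypothetical input of shape (batch_size, seq_len, input_size)
x = tf.placeholder(tf.float32, shape=(None, 10, 128))
h_with_bias = Dense(x, 64)              # default bias=True: outputs = x.W + b, the old behavior
h_no_bias = Dense(x, 64, bias=False)    # new path: b = 0, so outputs = x.W only
# both results keep the leading shape, i.e. (batch_size, 10, 64); masking is skipped since seq_len=None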
@@ -61,13 +64,13 @@ def Dense(inputs, ouput_size, seq_len=None):
 '''
 def Attention(Q, K, V, nb_head, size_per_head, Q_len=None, V_len=None):
     # apply separate linear maps to Q, K and V
-    Q = Dense(Q, nb_head * size_per_head)
+    Q = Dense(Q, nb_head * size_per_head, False)
     Q = tf.reshape(Q, (-1, tf.shape(Q)[1], nb_head, size_per_head))
     Q = tf.transpose(Q, [0, 2, 1, 3])
-    K = Dense(K, nb_head * size_per_head)
+    K = Dense(K, nb_head * size_per_head, False)
     K = tf.reshape(K, (-1, tf.shape(K)[1], nb_head, size_per_head))
     K = tf.transpose(K, [0, 2, 1, 3])
-    V = Dense(V, nb_head * size_per_head)
+    V = Dense(V, nb_head * size_per_head, False)
     V = tf.reshape(V, (-1, tf.shape(V)[1], nb_head, size_per_head))
     V = tf.transpose(V, [0, 2, 1, 3])
     # compute the dot products, then mask, then softmax
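Taken together, the three changed call sites turn the Q/K/V maps into pure linear projections with no additive bias. A hedged usage sketch of the updated Attention (shapes and placeholder names are assumptions, not shown in the diff):

import tensorflow as tf

Q_in = tf.placeholder(tf.float32, (None, 20, 128))  # hypothetical query sequence
K_in = tf.placeholder(tf.float32, (None, 30, 128))  # hypothetical key sequence
V_in = K_in                                         # self-keyed values, for illustration
O = Attention(Q_in, K_in, V_in, nb_head=8, size_per_head=16)
# assuming the usual multi-head layout, the heads are concatenated back,
# giving O a shape of (batch_size, 20, 8 * 16)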
