layers.py
import numpy as np
import tensorflow as tf
from bert4keras.layers import Input, Embedding, Masking, Dense, Lambda, LayerNormalization, Dropout, Concatenate1D, MultiHeadAttention
from bert4keras.layers import RelativePositionEmbedding, Add, FeedForward
from keras import initializers, activations
from bert4keras.backend import sequence_masking, recompute_grad
from bert4keras.backend import K
class RelativePositionEmbeddingT5(RelativePositionEmbedding):
    """Google T5's relative position encoding.
    From the paper: https://arxiv.org/abs/1910.10683
    """
    def __init__(
        self,
        input_dim,
        output_dim,
        max_distance=128,
        bidirectional=True,
        embeddings_initializer='zeros',
        last_one=False,
        **kwargs
    ):
        super(RelativePositionEmbeddingT5,
              self).__init__(input_dim, output_dim, **kwargs)
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.last_one = last_one

    def compute_position_ids(self, inputs):
        """T5's relative position bucketing (translated directly from the
        official T5 source code).
        """
        q, v = inputs
        # Compute the relative position differences
        q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
        q_idxs = K.expand_dims(q_idxs, 1)
        v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
        v_idxs = K.expand_dims(v_idxs, 0)
        pos_ids = v_idxs - q_idxs
        # Post-processing: map each relative position to a bucket id
        num_buckets, max_distance = self.input_dim, self.max_distance
        ret = 0
        n = -pos_ids
        if self.bidirectional:
            num_buckets //= 2
            ret += K.cast(K.less(n, 0), 'int32') * num_buckets
            n = K.abs(n)
        else:
            n = K.maximum(n, 0)
        # Now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = K.less(n, max_exact)
        val_if_large = max_exact + K.cast(
            K.log(K.cast(n, K.floatx()) / max_exact) /
            np.log(max_distance / max_exact) * (num_buckets - max_exact),
            'int32',
        )
        val_if_large = K.minimum(val_if_large, num_buckets - 1)
        ret += K.switch(is_small, n, val_if_large)
        if self.last_one:
            # Keep only the bucket ids of the last query position
            # (used when decoding with a cache).
            return ret[-1:]
        return ret

    def get_config(self):
        config = {
            'max_distance': self.max_distance,
            'bidirectional': self.bidirectional,
            'last_one': self.last_one,
        }
        base_config = super(RelativePositionEmbeddingT5, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
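# Illustrative sketch (not part of the original file): a pure-numpy mirror of
# the bucketing above, handy for eyeballing which bucket a relative offset
# falls into. `t5_bucket_demo` is a hypothetical helper, not a bert4keras API;
# `offset` plays the role of pos_ids = v_idx - q_idx in the layer above.
def t5_bucket_demo(offset, num_buckets=32, max_distance=128, bidirectional=True):
    """Map a single relative position offset to its T5 bucket id."""
    ret, n = 0, -offset
    if bidirectional:
        num_buckets //= 2
        if n < 0:
            # Opposite direction uses the upper half of the buckets
            ret += num_buckets
        n = abs(n)
    else:
        n = max(n, 0)
    max_exact = num_buckets // 2
    if n < max_exact:
        # Small offsets (0..max_exact-1) each get their own bucket
        return ret + n
    # Larger offsets are placed logarithmically up to max_distance
    val_if_large = max_exact + int(
        np.log(n / max_exact) / np.log(max_distance / max_exact) *
        (num_buckets - max_exact)
    )
    return ret + min(val_if_large, num_buckets - 1)

# With the defaults (32 buckets, max_distance=128, bidirectional):
# t5_bucket_demo(-16) == 10, t5_bucket_demo(-200) == 15 (saturated), and
# t5_bucket_demo(16) == 26 (the other direction's half of the buckets).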
class MultiHeadAttentionCache(MultiHeadAttention):
    """When computing qw, truncate the rotary position encodings so that
    only the last positions (matching the query length) are applied.
    """
    def pay_attention_to(self, inputs, mask=None, **kwargs):
        """Standard dot-product multi-head attention.
        a_bias: bias added to the attention matrix.
            Different attention biases correspond to different applications.
        p_bias: positional bias inside the attention.
            Usually used to specify the type of relative position encoding.
        Note: pay_attention_to is split out as a separate method so that
        subclasses can define different forms of attention; it must return
        o with shape (batch_size, seq_len, heads, head_size).
        """
        (qw, kw, vw), n = inputs[:3], 3
        q_mask, v_mask = mask
        a_bias, p_bias = kwargs.get('a_bias'), kwargs.get('p_bias')
        if a_bias:
            a_bias = inputs[n]
            n += 1
        if p_bias == 'rotary':
            cos_pos = K.repeat_elements(inputs[n][..., None, 1::2], 2, -1)
            sin_pos = K.repeat_elements(inputs[n][..., None, ::2], 2, -1)
            qw2 = K.stack([-qw[..., 1::2], qw[..., ::2]], 4)
            qw2 = K.reshape(qw2, K.shape(qw))
            # During cached decoding the query can be shorter than the key,
            # so only the last q_len rotary positions are applied to qw.
            q_len = K.shape(qw2)[1]
            qw = qw * cos_pos[:, -q_len:] + qw2 * sin_pos[:, -q_len:]
            kw2 = K.stack([-kw[..., 1::2], kw[..., ::2]], 4)
            kw2 = K.reshape(kw2, K.shape(kw))
            kw = kw * cos_pos + kw2 * sin_pos
        # Attention
        a = tf.einsum('bjhd,bkhd->bhjk', qw, kw)
        # Handle the position encodings
        if p_bias == 'typical_relative':
            position_bias = inputs[n]
            a = a + tf.einsum('bjhd,jkd->bhjk', qw, position_bias)
        elif p_bias == 't5_relative':
            position_bias = K.permute_dimensions(inputs[n], (2, 0, 1))
            a = a + K.expand_dims(position_bias, 0)
        # Attention (continued)
        if self.attention_scale:
            a = a / self.key_size**0.5
        if a_bias is not None:
            a = a + a_bias
        a = sequence_masking(a, v_mask, '-inf', -1)
        A = K.softmax(a)
        if self.attention_dropout:
            A = Dropout(self.attention_dropout)(A)
        # Final output
        o = tf.einsum('bhjk,bkhd->bjhd', A, vw)
        if p_bias == 'typical_relative':
            o = o + tf.einsum('bhjk,jkd->bjhd', A, position_bias)
        return o, a
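

if __name__ == '__main__':
    # Minimal usage sketch (an assumption for illustration, not part of the
    # original file): wire the T5 relative position bias into the cached
    # attention layer the way bert4keras wires it for T5. It assumes a
    # bert4keras version whose MultiHeadAttention.call accepts a_bias/p_bias
    # keyword arguments; all shapes and hyperparameters are made up.
    from bert4keras.backend import keras

    num_heads, head_size, hidden = 8, 64, 512
    x = Input(shape=(None, hidden))
    # (q_len, v_len, num_heads): one scalar bias per head and bucket
    pos_bias = RelativePositionEmbeddingT5(
        input_dim=32, output_dim=num_heads, max_distance=128
    )([x, x])
    # p_bias='t5_relative' tells pay_attention_to to read the extra input
    # as a T5-style relative position bias
    o = MultiHeadAttentionCache(heads=num_heads, head_size=head_size)(
        [x, x, x, pos_bias], p_bias='t5_relative'
    )
    model = keras.models.Model(x, o)
    model.summary()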