# attention_module.py
from torch import FloatTensor
from torch.autograd import Variable
from torch.nn.functional import sigmoid, softmax


def mask3d(value, sizes):
    """Mask entries in value with 0 based on sizes.

    Args
    ----
    value: Tensor of size (B, N, D)
        Tensor to be masked.
    sizes: list of int
        List giving the number of valid values for each item
        in the batch. Positions beyond each size will be masked.

    Returns
    -------
    value:
        Masked value.
    """
    v_mask = 0
    v_unmask = 1
    mask = value.data.new(value.size()).fill_(v_unmask)
    n = mask.size(1)
    for i, size in enumerate(sizes):
        if size < n:
            mask[i, size:, :] = v_mask
    return Variable(mask) * value
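
# A minimal usage sketch for mask3d (illustrative shapes, not part of the
# original module; assumes `import torch`): a batch of 2 sequences padded to
# length 4 with feature size 3, where entries past each sequence's true
# length are zeroed.
#
#     x = Variable(torch.randn(2, 4, 3))
#     masked = mask3d(x, sizes=[4, 2])
#     # masked[1, 2:, :] is now all zeros; masked[0] is unchanged.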


def fill_context_mask(mask, sizes, v_mask, v_unmask):
    """Fill an attention mask in place for a variable-length context.

    Args
    ----
    mask: Tensor of size (B, M, N)
        Tensor to fill with mask values.
    sizes: list[int]
        List giving the size of the context for each item in
        the batch. Positions beyond each size will be masked.
    v_mask: float
        Value to use for masked positions.
    v_unmask: float
        Value to use for unmasked positions.

    Returns
    -------
    mask:
        Filled with values in {v_mask, v_unmask}.
    """
    mask.fill_(v_unmask)
    n_context = mask.size(2)
    for i, size in enumerate(sizes):
        if size < n_context:
            mask[i, :, size:] = v_mask
    return mask
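
# A minimal usage sketch for fill_context_mask (illustrative shapes, not part
# of the original module; assumes `import torch`): build an additive softmax
# mask for a batch of 2 single-query items over contexts of true lengths
# 5 and 3, padded to 5.
#
#     m = torch.FloatTensor(2, 1, 5)  # (B, M, N), uninitialized
#     m = fill_context_mask(m, sizes=[5, 3], v_mask=float('-inf'), v_unmask=0)
#     # m[1, :, 3:] == -inf and everything else == 0; adding m to the scores
#     # before softmax drives the weights of padded positions to zero.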


def dot(a, b):
    """Compute the dot product between pairs of vectors in 3D Variables.

    Args
    ----
    a: Variable of size (B, M, D)
    b: Variable of size (B, N, D)

    Returns
    -------
    c: Variable of size (B, M, N)
        c[i,j,k] = dot(a[i,j], b[i,k])
    """
    return a.bmm(b.transpose(1, 2))
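
# For reference (illustrative, not part of the original module): for every
# batch index i, query index j, and context index k,
#
#     dot(a, b)[i, j, k] == torch.dot(a[i, j], b[i, k])
#
# which is exactly the pairwise score used by attend(..., score='dot').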


def attend(query, context, value=None, score='dot', normalize='softmax',
           context_sizes=None, context_mask=None, return_weight=False):
    """Attend to value (or context) by scoring each query and context.

    Args
    ----
    query: Variable of size (B, M, D1)
        Batch of M query vectors.
    context: Variable of size (B, N, D2)
        Batch of N context vectors.
    value: Variable of size (B, N, P), default=None
        If given, the output vectors will be weighted
        combinations of the value vectors.
        Otherwise, the context vectors will be used.
    score: str or callable, default='dot'
        If score == 'dot', scores are computed
        as the dot product between context and
        query vectors. This requires D1 == D2.
        Otherwise, score should be a callable:

            query     context      score
            (B,M,D1)  (B,N,D2)  -> (B,M,N)

    normalize: str, default='softmax'
        One of 'softmax', 'sigmoid', or 'identity'.
        Name of the function used to map scores to weights.
    context_sizes: list[int], default=None
        List giving the size of the context for each item
        in the batch, used to compute a context_mask.
    context_mask: Tensor of size (B, M, N), default=None
        A Tensor used to mask the context. Masked
        and unmasked entries should be filled
        appropriately for the normalization function.
        If neither context_mask nor context_sizes is given,
        the context is assumed to have a fixed size.
    return_weight: bool, default=False
        If True, return the attention weight Tensor.

    Returns
    -------
    output: Variable of size (B, M, P)
        If return_weight is False.
    weight, output: Variable of size (B, M, N), Variable of size (B, M, P)
        If return_weight is True.

    About
    -----
    Attention is used to focus processing on a particular region of input.
    This function implements the most common attention mechanism [1, 2, 3],
    which produces an output by taking a weighted combination of value vectors
    with weights computed by a scoring function operating over pairs of query
    and context vectors.

    Given a query vector `q`, context vectors `c_1,...,c_n`, and value vectors
    `v_1,...,v_n`, the attention score of `q` with `c_i` is

        s_i = f(q, c_i)

    Frequently, `f` is the dot product between query and context vectors:

        s_i = q^T c_i

    The scores are passed through a normalization function `g`,
    normally the softmax function:

        w_i = g(s_1,...,s_n)_i

    Finally, the output is computed as a weighted
    combination of the values with the normalized scores:

        z = sum_{i=1}^n w_i * v_i

    In many applications [4, 5] the context and value vectors are the same,
    `v_i = c_i`.

    Sizes
    -----
    This function accepts batches of size `B` containing
    `M` query vectors of dimension `D1`,
    `N` context vectors of dimension `D2`,
    and optionally `N` value vectors of dimension `P`.

    Variable Length Contexts
    ------------------------
    If the number of context vectors varies within a batch, a context
    position can be ignored by forcing the corresponding weight to be zero.

    In the case of the softmax, this can be achieved by adding negative
    infinity to the corresponding score before normalization.
    Similarly, for elementwise normalization functions the weights can
    be multiplied by an appropriate {0, 1} mask after normalization.

    To facilitate the above behavior, a context mask, with entries
    in `{-inf, 0}` or `{0, 1}` depending on the normalization function,
    can be passed to this function. The mask should have size `(B, M, N)`.
    Alternatively, a list can be passed giving the size of the context for
    each item in the batch. Appropriate masks will be created from these lists.

    Note that the size of the output does not depend on the number of context
    vectors. Because of this, masked context positions are truly unaccounted
    for in the output.

    References
    ----------
    [1](https://arxiv.org/abs/1410.5401)
    @article{graves2014neural,
      title={Neural turing machines},
      author={Graves, Alex and Wayne, Greg and Danihelka, Ivo},
      journal={arXiv preprint arXiv:1410.5401},
      year={2014}
    }

    [2](https://arxiv.org/abs/1503.08895)
    @inproceedings{sukhbaatar2015end,
      title={End-to-end memory networks},
      author={Sukhbaatar, Sainbayar and Weston, Jason and Fergus, Rob and others},
      booktitle={Advances in Neural Information Processing Systems},
      pages={2440--2448},
      year={2015}
    }

    [3](https://distill.pub/2016/augmented-rnns/)
    @article{olah2016attention,
      title={Attention and augmented recurrent neural networks},
      author={Olah, Chris and Carter, Shan},
      journal={Distill},
      volume={1},
      number={9},
      pages={e1},
      year={2016}
    }

    [4](https://arxiv.org/abs/1409.0473)
    @article{bahdanau2014neural,
      title={Neural machine translation by jointly learning to align and translate},
      author={Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
      journal={arXiv preprint arXiv:1409.0473},
      year={2014}
    }

    [5](https://arxiv.org/abs/1506.03134)
    @inproceedings{vinyals2015pointer,
      title={Pointer networks},
      author={Vinyals, Oriol and Fortunato, Meire and Jaitly, Navdeep},
      booktitle={Advances in Neural Information Processing Systems},
      pages={2692--2700},
      year={2015}
    }
    """
    q, c, v = query, context, value
    if v is None:
        v = c

    batch_size_q, n_q, dim_q = q.size()
    batch_size_c, n_c, dim_c = c.size()
    batch_size_v, n_v, dim_v = v.size()
    if not (batch_size_q == batch_size_c == batch_size_v):
        msg = 'batch size mismatch (query: {}, context: {}, value: {})'
        raise ValueError(msg.format(q.size(), c.size(), v.size()))
    batch_size = batch_size_q

    # Compute scores
    if score == 'dot':
        s = dot(q, c)
    elif callable(score):
        s = score(q, c)
    else:
        raise ValueError(f'unknown score function: {score}')

    # Normalize scores and mask contexts
    if normalize == 'softmax':
        if context_mask is not None:
            s = context_mask + s
        elif context_sizes is not None:
            context_mask = s.data.new(batch_size, n_q, n_c)
            context_mask = fill_context_mask(context_mask,
                                             sizes=context_sizes,
                                             v_mask=float('-inf'),
                                             v_unmask=0)
            s = context_mask + s
        s_flat = s.view(batch_size * n_q, n_c)
        w_flat = softmax(s_flat, dim=1)
        w = w_flat.view(batch_size, n_q, n_c)
    elif normalize == 'sigmoid' or normalize == 'identity':
        w = sigmoid(s) if normalize == 'sigmoid' else s
        if context_mask is not None:
            w = context_mask * w
        elif context_sizes is not None:
            context_mask = s.data.new(batch_size, n_q, n_c)
            context_mask = fill_context_mask(context_mask,
                                             sizes=context_sizes,
                                             v_mask=0,
                                             v_unmask=1)
            w = context_mask * w
    else:
        raise ValueError(f'unknown normalize function: {normalize}')

    # Combine
    z = w.bmm(v)
    if return_weight:
        return w, z
    return z
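

# A minimal smoke-test sketch (not part of the original module); the shapes
# and sizes below are illustrative assumptions. It runs attend() over a padded
# batch with variable-length contexts and checks that masked context positions
# receive zero attention weight under the default softmax normalization.
if __name__ == '__main__':
    import torch

    B, M, N, D, P = 2, 3, 5, 4, 6
    query = Variable(torch.randn(B, M, D))
    context = Variable(torch.randn(B, N, D))
    value = Variable(torch.randn(B, N, P))

    # The second batch item has only 3 valid context positions out of 5.
    weight, output = attend(query, context, value,
                            context_sizes=[5, 3], return_weight=True)
    print('weight size:', tuple(weight.size()))  # (B, M, N)
    print('output size:', tuple(output.size()))  # (B, M, P)

    # Padded positions of the second item should carry exactly zero weight.
    assert weight.data[1, :, 3:].abs().max() == 0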