# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from functorch.dim import dims, dimlists, softmax, cat
import math


class Linear(nn.Linear):
    def forward(self, input):
        ci, co = dims()
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
        return result.order(b, co)
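

# A usage sketch for the Linear above, assuming the functorch first-class
# dimension prototype is importable; the shapes are illustrative only. Because
# `b = dimlists()` binds to however many leading dimensions remain, the same
# layer handles any number of batch dimensions:
#
#     lin = Linear(16, 32)
#     out = lin(torch.randn(4, 8, 16))     # -> tensor of shape (4, 8, 32)
#     out = lin(torch.randn(2, 4, 8, 16))  # -> tensor of shape (2, 4, 8, 32)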


class BertSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads,
                 attention_probs_dropout_prob, position_embedding_type=None,
                 max_position_embeddings=None, linear=Linear):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)

        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size)
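
    # Worked example of the relative-distance table above (illustrative numbers,
    # not from any particular configuration): with max_position_embeddings = 4
    # the table has 2 * 4 - 1 = 7 rows, one per query/key distance d in [-3, 3],
    # and forward() looks distance d up at row (max_position_embeddings - 1 + d),
    # i.e. d = -3 maps to row 0 and d = 3 maps to row 6.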
    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, v normally
        # the meaning of a linear layer is well understood, so no need to use explicit dimensions
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # introduce values that represent each dimension. dimensions are 'first class'
        # because they are actual python values introduced here
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # bind the positional dimensions in q, k, and v against
        # our values. the sizes of each dimension are determined by this binding,
        # and when a dimension is used twice (e.g. batch), its size is checked
        # for consistency across both uses.
        # The group (heads, features) splits apart a single positional dimension
        # into two dimensions. Since heads.size * features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]

        # this option allows the model to attend not just to the elements of the current
        # sequence but also to previously computed keys and values passed in as additional tokens.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]
            # cat introduces a new dimension, extended_key_sequence, because the
            # concatenated result is longer than the original key_sequence
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer product and summation are explicitly represented here,
        # and like einsum, will be pattern matched to an efficient matrix multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)

        # relative positional embeddings give a unique embedding based on the distance between
        # the query and key tokens in the sequence, e.g.
        #  0  1  2  3
        # -1  0  1  2
        # -2 -1  0  1
        # -3 -2 -1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object when used as a tensor is the range of indices
            # along that dimension, so we can directly subtract the two dimensions to get a
            # 2D tensor of (query_sequence x key_sequence) with the distance between them
            distance = query_sequence - key_sequence
            assert key_sequence.size <= self.max_position_embeddings
            # we can then use that as an indirect index into the embedding table to look up
            # the features for each distance. this is just a `gather` primitive op. The
            # resulting tensor has all the dimensions of the index expression
            # (query_sequence x key_sequence), plus the dimension of the weight table that
            # was not indirectly accessed (bound here to `features`).
            # this form of indirect indexing is more straightforward than either advanced
            # indexing or torch.gather, which both have a lot of dependencies on the
            # positions of the indexing tensors.
            positional_embedding = self.distance_embedding.weight[self.max_position_embeddings - 1 + distance, features]

            if self.position_embedding_type == "relative_key":
                # these were einsum ops in the original positional-embedding code because
                # they are not easy to fit to existing matmul operators, even though they
                # are degenerate matmuls
                relative_position_scores = (q * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (q * positional_embedding).sum(features)
                relative_position_scores_key = (k * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(attention_probs, p=self.dropout_prob)

        # similarly, we can replace the matmul with a direct listing of the outer product,
        # which makes it clear we are weighting the values v across all keys with the
        # attention scores.
        context_layer = (attention_probs * v).sum(key_sequence)

        # finally, we convert back to a standard tensor by describing the layout of dimensions.
        # working in reverse of the binding above, the (heads, features) group flattens the
        # two dimensions back into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
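

# A minimal smoke-test sketch, assuming the functorch first-class dimension
# prototype is importable; the sizes are illustrative and not taken from any
# particular BERT configuration.
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = BertSelfAttention(
        hidden_size=64,
        num_attention_heads=8,
        attention_probs_dropout_prob=0.0,
    )
    # an ordinary positional tensor of shape (batch, sequence, hidden_size) goes in ...
    hidden_states = torch.randn(2, 16, 64)
    out = attn(hidden_states)
    # ... and an ordinary positional tensor of the same shape comes back out
    print(out.shape)  # expected: torch.Size([2, 16, 64])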