multi_head_attention.py
import math
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
"""
Args:
d_model: dimension of embeddings
n_heads: number of self attention heads
            dropout: dropout probability applied to the attention weights
"""
super().__init__()
        assert d_model % n_heads == 0  # d_model must divide evenly across heads
self.d_model = d_model # 512 dim
self.n_heads = n_heads # 8 heads
self.d_key = d_model // n_heads # assume d_value equals d_key | 512/8=64
self.Wq = nn.Linear(d_model, d_model) # query weights
self.Wk = nn.Linear(d_model, d_model) # key weights
self.Wv = nn.Linear(d_model, d_model) # value weights
self.Wo = nn.Linear(d_model, d_model) # output weights
self.dropout = nn.Dropout(p=dropout) # initialize dropout layer

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None):
"""
Args:
query: query vector (batch_size, q_length, d_model)
key: key vector (batch_size, k_length, d_model)
            value: value vector (batch_size, v_length, d_model)
            mask: optional mask (e.g. padding or causal decoder mask); positions where mask == 0 are blocked
Returns:
output: attention values (batch_size, q_length, d_model)
attn_probs: softmax scores (batch_size, n_heads, q_length, k_length)
"""
batch_size = key.size(0)
# calculate query, key, and value tensors
Q = self.Wq(query) # (32, 10, 512) x (512, 512) = (32, 10, 512)
K = self.Wk(key) # (32, 10, 512) x (512, 512) = (32, 10, 512)
V = self.Wv(value) # (32, 10, 512) x (512, 512) = (32, 10, 512)
# split each tensor into n-heads to compute attention
# query tensor
Q = Q.view(batch_size, # (32, 10, 512) -> (32, 10, 8, 64)
-1, # -1 = q_length
self.n_heads,
self.d_key
).permute(0, 2, 1, 3) # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, q_length, d_key)
# key tensor
K = K.view(batch_size, # (32, 10, 512) -> (32, 10, 8, 64)
-1, # -1 = k_length
self.n_heads,
self.d_key
).permute(0, 2, 1, 3) # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, k_length, d_key)
# value tensor
V = V.view(batch_size, # (32, 10, 512) -> (32, 10, 8, 64)
-1, # -1 = v_length
self.n_heads,
self.d_key
).permute(0, 2, 1, 3) # (32, 10, 8, 64) -> (32, 8, 10, 64) = (batch_size, n_heads, v_length, d_key)
# computes attention
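        # per head: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_key)) V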
# scaled dot product -> QK^{T}
scaled_dot_prod = torch.matmul(Q,
# (32, 8, 10, 64) x (32, 8, 64, 10) -> (32, 8, 10, 10) = (batch_size, n_heads, q_length, k_length)
K.permute(0, 1, 3, 2)
) / math.sqrt(self.d_key) # sqrt(64)
        # mask out positions where mask == 0 by filling with -1e10, so softmax assigns them ~zero weight
if mask is not None:
scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -1e10)
# apply softmax
attn_probs = torch.softmax(scaled_dot_prod, dim=-1)
# multiply by values to get attention
A = torch.matmul(self.dropout(attn_probs), V) # (32, 8, 10, 10) x (32, 8, 10, 64) -> (32, 8, 10, 64)
# (batch_size, n_heads, q_length, k_length) x (batch_size, n_heads, v_length, d_key) -> (batch_size, n_heads, q_length, d_key)
# reshape attention back to (32, 10, 512)
A = A.permute(0, 2, 1, 3).contiguous() # (32, 8, 10, 64) -> (32, 10, 8, 64)
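        # .contiguous() is needed because permute returns a non-contiguous view and .view requires contiguous memory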
A = A.view(batch_size, -1,
self.n_heads * self.d_key) # (32, 10, 8, 64) -> (32, 10, 8*64) -> (32, 10, 512) = (batch_size, q_length, d_model)
# push through the final weight layer
output = self.Wo(A) # (32, 10, 512) x (512, 512) = (32, 10, 512)
return output, attn_probs # return attn_probs for visualization of the scores
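

# --- Minimal usage sketch (not part of the original module) ---
# The batch size, sequence length, and causal-mask construction below are
# illustrative assumptions; they are not required by MultiHeadAttention.
if __name__ == "__main__":
    batch_size, seq_len, d_model = 32, 10, 512
    mha = MultiHeadAttention(d_model=d_model, n_heads=8, dropout=0.1)
    mha.eval()  # disable dropout so the sketch is deterministic

    x = torch.randn(batch_size, seq_len, d_model)  # self-attention: query = key = value

    # lower-triangular causal mask (1 = may attend, 0 = blocked), broadcast over batch and heads
    causal_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))

    out, attn = mha(x, x, x, mask=causal_mask)
    print(out.shape)   # torch.Size([32, 10, 512])
    print(attn.shape)  # torch.Size([32, 8, 10, 10])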