# attention.py (forked from huggingface/diffusers)
import logging

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def split_einsum(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM

    - Implements https://machinelearning.apple.com/research/neural-engine-transformers
    - Recommended for ANE
    - Marginally slower on GPU
    """
    mh_q = [
        q[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    k = k.transpose(1, 3)
    mh_k = [
        k[:, :, :,
          head_idx * dim_head:(head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]  # (bs, max_seq_length, 1, dim_head) * heads

    mh_v = [
        v[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn_weights = [
        torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * (dim_head**-0.5)
        for qi, ki in zip(mh_q, mh_k)
    ]  # (bs, max_seq_length, 1, max_seq_length) * heads

    if mask is not None:
        for head_idx in range(heads):
            attn_weights[head_idx] = attn_weights[head_idx] + mask

    attn_weights = [
        aw.softmax(dim=1) for aw in attn_weights
    ]  # (bs, max_seq_length, 1, max_seq_length) * heads

    attn = [
        torch.einsum("bkhq,bchk->bchq", wi, vi)
        for wi, vi in zip(attn_weights, mh_v)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn = torch.cat(attn, dim=1)  # (bs, dim, 1, max_seq_length)
    return attn
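
# Hedged usage sketch (added for illustration; not part of the upstream module).
# `_example_split_einsum` is a hypothetical helper: it assumes q, k and v arrive in the
# (bs, heads * dim_head, 1, seq_length) layout used throughout this file; batch size,
# head count and sequence length below are arbitrary.
def _example_split_einsum():
    bs, heads, dim_head, seq_len = 1, 8, 64, 77
    q = torch.randn(bs, heads * dim_head, 1, seq_len)
    k = torch.randn(bs, heads * dim_head, 1, seq_len)
    v = torch.randn(bs, heads * dim_head, 1, seq_len)
    out = split_einsum(q, k, v, mask=None, heads=heads, dim_head=dim_head)
    # Output keeps the same layout: (bs, heads * dim_head, 1, seq_length)
    assert out.shape == (bs, heads * dim_head, 1, seq_len)
    return out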

CHUNK_SIZE = 512


def split_einsum_v2(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM_V2

    - Implements https://machinelearning.apple.com/research/neural-engine-transformers
    - Recommended for ANE
    - Marginally slower on GPU
    - Chunks the query sequence to avoid large intermediate tensors and improves ANE performance
    """
    query_seq_length = q.size(3)
    num_chunks = query_seq_length // CHUNK_SIZE

    if num_chunks == 0:
        logger.info(
            "AttentionImplementations.SPLIT_EINSUM_V2: query sequence too short to chunk "
            f"({query_seq_length}<{CHUNK_SIZE}), falling back to AttentionImplementations.SPLIT_EINSUM (safe to ignore)")
        return split_einsum(q, k, v, mask, heads, dim_head)

    logger.info(
        "AttentionImplementations.SPLIT_EINSUM_V2: Splitting query sequence length of "
        f"{query_seq_length} into {num_chunks} chunks")

    mh_q = [
        q[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    # Chunk the query sequence for each head
    mh_q_chunked = [
        [h_q[..., chunk_idx * CHUNK_SIZE:(chunk_idx + 1) * CHUNK_SIZE]
         for chunk_idx in range(num_chunks)]
        for h_q in mh_q
    ]  # ((bs, dim_head, 1, QUERY_SEQ_CHUNK_SIZE) * num_chunks) * heads

    k = k.transpose(1, 3)
    mh_k = [
        k[:, :, :,
          head_idx * dim_head:(head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]  # (bs, max_seq_length, 1, dim_head) * heads

    mh_v = [
        v[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn_weights = [
        [
            torch.einsum("bchq,bkhc->bkhq", [qi_chunk, ki]) * (dim_head**-0.5)
            for qi_chunk in h_q_chunked
        ] for h_q_chunked, ki in zip(mh_q_chunked, mh_k)
    ]  # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

    attn_weights = [
        [aw_chunk.softmax(dim=1) for aw_chunk in aw_chunked]
        for aw_chunked in attn_weights
    ]  # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

    attn = [
        [
            torch.einsum("bkhq,bchk->bchq", wi_chunk, vi)
            for wi_chunk in wi_chunked
        ] for wi_chunked, vi in zip(attn_weights, mh_v)
    ]  # ((bs, dim_head, 1, chunk_size) * num_chunks) * heads

    attn = torch.cat([
        torch.cat(attn_chunked, dim=3) for attn_chunked in attn
    ], dim=1)  # (bs, dim, 1, max_seq_length)

    return attn
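
# Hedged consistency sketch (added for illustration; not part of the upstream module).
# Because the query sequence is chunked only along its own axis while the softmax runs
# over the key axis, SPLIT_EINSUM_V2 should match SPLIT_EINSUM up to floating point.
# `_example_split_einsum_v2` is a hypothetical helper; seq_len > CHUNK_SIZE is chosen
# so that the chunked path is actually exercised.
def _example_split_einsum_v2():
    bs, heads, dim_head, seq_len = 1, 8, 64, 2 * CHUNK_SIZE
    q = torch.randn(bs, heads * dim_head, 1, seq_len)
    k = torch.randn(bs, heads * dim_head, 1, seq_len)
    v = torch.randn(bs, heads * dim_head, 1, seq_len)
    ref = split_einsum(q, k, v, mask=None, heads=heads, dim_head=dim_head)
    out = split_einsum_v2(q, k, v, mask=None, heads=heads, dim_head=dim_head)
    assert torch.allclose(ref, out, atol=1e-5)
    return out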

def original(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.ORIGINAL

    - Not recommended for ANE
    - Recommended for GPU
    """
    bs = q.size(0)
    mh_q = q.view(bs, heads, dim_head, -1)
    mh_k = k.view(bs, heads, dim_head, -1)
    mh_v = v.view(bs, heads, dim_head, -1)

    attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k])
    attn_weights.mul_(dim_head**-0.5)

    if mask is not None:
        attn_weights = attn_weights + mask

    attn_weights = attn_weights.softmax(dim=3)

    attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v])
    attn = attn.contiguous().view(bs, heads * dim_head, 1, -1)
    return attn
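
# Hedged equivalence sketch (added for illustration; not part of the upstream module).
# All three implementations compute the same softmax attention on the
# (bs, heads * dim_head, 1, seq_length) layout; only the einsum decomposition (and the
# hardware each maps well to) differs, so their outputs should agree up to floating
# point. `_example_original` is a hypothetical helper with arbitrary shapes.
def _example_original():
    bs, heads, dim_head, seq_len = 1, 8, 64, 77
    q = torch.randn(bs, heads * dim_head, 1, seq_len)
    k = torch.randn(bs, heads * dim_head, 1, seq_len)
    v = torch.randn(bs, heads * dim_head, 1, seq_len)
    ref = split_einsum(q, k, v, mask=None, heads=heads, dim_head=dim_head)
    out = original(q, k, v, mask=None, heads=heads, dim_head=dim_head)
    assert torch.allclose(ref, out, atol=1e-5)
    return out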