forked from lucidrains/vit-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmax_vit.py
288 lines (212 loc) · 8.11 KB
/
max_vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
from functools import partial
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
# helpers
def exists(val):
    """Return True when *val* is anything other than ``None`` (falsy values count as present)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val* if it is set, otherwise the fallback *d*.

    Only ``None`` counts as "unset" (via ``exists``), so falsy values such as
    ``0`` or ``''`` are returned unchanged.
    """
    if exists(val):
        return val
    return d
def cast_tuple(val, length = 1):
    """Return *val* unchanged if it is already a tuple; otherwise a tuple of *length* copies of it."""
    if isinstance(val, tuple):
        return val
    return (val,) * length
# helper classes
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper computing ``fn(LayerNorm(x)) + x``.

    Args:
        dim: size of the last dimension of ``x`` (the axis ``nn.LayerNorm`` normalises).
        fn: callable applied to the normalised input; its output must be
            broadcast-compatible with ``x`` for the residual sum.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        branch = self.fn(normed)
        return branch + x
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size (last dimension of ``x``).
        mult: hidden-width multiplier; the hidden size is ``int(dim * mult)``.
        dropout: dropout probability applied after the activation and again
            after the output projection.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        # one flat Sequential so parameters stay named net.0.* / net.3.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
# MBConv
class SqueezeExcitation(nn.Module):
    """Channel gating for a ``b c h w`` feature map.

    The gate mean-pools the spatial dims to ``b c``, passes that through a
    bias-free bottleneck MLP (Linear -> SiLU -> Linear) of width
    ``int(dim * shrinkage_rate)``, applies a sigmoid, and reshapes to
    ``b c 1 1`` so the per-channel weights broadcast over ``h`` and ``w``.

    NOTE(review): when ``dim * shrinkage_rate < 1`` the bottleneck width is 0.
    """

    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        bottleneck = int(dim * shrinkage_rate)
        squeeze = Reduce('b c h w -> b c', 'mean')
        excite = [
            nn.Linear(dim, bottleneck, bias = False),
            nn.SiLU(),
            nn.Linear(bottleneck, dim, bias = False),
            nn.Sigmoid(),
        ]
        to_broadcast = Rearrange('b c -> b c 1 1')
        self.gate = nn.Sequential(squeeze, *excite, to_broadcast)

    def forward(self, x):
        weights = self.gate(x)
        return x * weights
class MBConvResidual(nn.Module):
    """Residual wrapper around an MBConv body: ``Dropsample(fn(x)) + x``.

    Args:
        fn: the wrapped block; its output must have the same shape as ``x``.
        dropout: probability handed to ``Dropsample`` for the branch output.
    """

    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x
class Dropsample(nn.Module):
    """Stochastic depth for a residual branch: drop whole samples at random.

    During training each sample along dim 0 is kept with probability
    ``1 - prob`` (otherwise its entire tensor is zeroed). Survivors are rescaled
    by ``1 / (1 - prob)`` so the expected value is unchanged. In eval mode, or
    when ``prob == 0``, the input is returned unchanged.

    Args:
        prob: probability of dropping a sample. ``prob >= 1`` drops every sample.
    """

    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        if self.prob == 0. or (not self.training):
            return x

        # Rescaling by 1 / (1 - prob) would divide by zero here; everything is dropped anyway.
        if self.prob >= 1.:
            return torch.zeros_like(x)

        # One uniform draw per sample, shaped (b, 1, ..., 1) so it broadcasts over
        # every non-batch dim. The legacy `torch.FloatTensor(shape_tuple, device=...)`
        # constructor treats the tuple as data (yielding a length-4 vector that
        # broadcasts along the last axis) and rejects non-CPU devices, hence torch.rand.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, device = x.device) > self.prob
        return x * keep_mask / (1 - self.prob)
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Build a mobile inverted-bottleneck conv block with squeeze-excitation.

    Layout: 1x1 expand -> BN -> GELU -> 3x3 depthwise (stride 2 when
    ``downsample``) -> BN -> GELU -> SqueezeExcitation -> 1x1 project -> BN.
    The hidden width is ``int(expansion_rate * dim_out)``, i.e. it is based on
    the *output* width.

    A residual connection (``MBConvResidual``, which uses ``dropout``) is added
    only when input and output shapes match: ``dim_in == dim_out`` and no
    downsampling. Otherwise ``dropout`` is unused.

    Returns:
        An ``nn.Module`` mapping ``b dim_in h w`` to ``b dim_out h' w'``. When
        ``downsample`` is set, ``h'``/``w'`` are ``ceil(h / 2)``/``ceil(w / 2)``;
        otherwise they equal ``h``/``w``.
    """
    hidden = int(expansion_rate * dim_out)
    if downsample:
        stride = 2
    else:
        stride = 1

    layers = [
        # expand channels
        nn.Conv2d(dim_in, hidden, 1),
        nn.BatchNorm2d(hidden),
        nn.GELU(),
        # depthwise spatial mixing (groups == channels); carries the stride
        nn.Conv2d(hidden, hidden, 3, stride = stride, padding = 1, groups = hidden),
        nn.BatchNorm2d(hidden),
        nn.GELU(),
        SqueezeExcitation(hidden, shrinkage_rate = shrinkage_rate),
        # project back down; no activation after the final BN
        nn.Conv2d(hidden, dim_out, 1),
        nn.BatchNorm2d(dim_out),
    ]
    block = nn.Sequential(*layers)

    shape_preserving = dim_in == dim_out and not downsample
    if not shape_preserving:
        return block
    return MBConvResidual(block, dropout = dropout)
# attention related classes
class Attention(nn.Module):
    """Multi-head self-attention within windows, with a learned relative position bias.

    ``forward`` takes ``x`` laid out as ``b x y w1 w2 d``: ``x``/``y`` index
    windows and ``w1``/``w2`` index positions inside a window. Attention runs
    independently over the ``w1 * w2`` tokens of each window, and the output
    has the same layout as the input.

    Args:
        dim: feature size ``d``; must be divisible by ``dim_head``.
        dim_head: width of each head; the head count is ``dim // dim_head``.
        dropout: dropout on the attention weights and after the output projection.
        window_size: side length of the square window. The relative-position
            index table is built for ``window_size x window_size`` windows, so
            ``w1``/``w2`` presumably must equal it (TODO confirm with callers).
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # Relative positional bias: one learned scalar per head for each (dy, dx)
        # offset between two positions in a window. Each offset component lies in
        # [-(w - 1), w - 1], so there are (2w - 1) ** 2 distinct offsets.
        span = 2 * window_size - 1
        self.rel_pos_bias = nn.Embedding(span ** 2, self.heads)

        coords = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(coords, coords, indexing = 'ij'))  # (2, w, w)
        grid = rearrange(grid, 'c i j -> (i j) c')                         # (w*w, 2)
        offsets = grid[:, None, :] - grid[None, :, :]                      # (w*w, w*w, 2)
        offsets = offsets + (window_size - 1)                              # shift into [0, 2w - 2]
        # flatten (dy, dx) row-major into a single index into the embedding table
        flat_index = offsets[..., 0] * span + offsets[..., 1]
        self.register_buffer('rel_pos_indices', flat_index, persistent = False)

    def forward(self, x):
        _, n_win_x, n_win_y, win_h, win_w, _ = x.shape
        heads = self.heads

        # every window becomes its own batch entry of win_h * win_w tokens
        tokens = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h = heads)
            for t in self.to_qkv(tokens).chunk(3, dim = -1)
        )

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # bias lookup is (n, n, heads); move heads first so it broadcasts over the batch
        pos_bias = rearrange(self.rel_pos_bias(self.rel_pos_indices), 'i j h -> h i j')
        attn = self.attend(sim + pos_bias)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = win_h, w2 = win_w)
        out = self.to_out(out)

        # fold the flattened windows back into the (x, y) window grid
        return rearrange(out, '(b x y) ... -> b x y ...', x = n_win_x, y = n_win_y)
class MaxViT(nn.Module):
    """MaxViT image classifier.

    The model is a conv stem, then stages of [MBConv -> block-window attention
    -> grid-window attention] blocks, then a mean-pool + LayerNorm + Linear head.

    Args:
        num_classes: number of output logits.
        dim: base width; stage ``i`` has width ``dim * 2 ** i``.
        depth: tuple with the number of blocks in each stage; its length sets
            the number of stages.
        dim_head: per-head width for every ``Attention`` layer.
        dim_conv_stem: output width of the conv stem (defaults to ``dim``).
        window_size: side length ``w`` of the attention windows. The spatial
            size entering each block's attention must be divisible by ``w``, as
            required by the window ``Rearrange``s.
        mbconv_expansion_rate: forwarded to ``MBConv`` as ``expansion_rate``.
        mbconv_shrinkage_rate: forwarded to ``MBConv`` as ``shrinkage_rate``.
        dropout: dropout used by the attention and feed-forward layers (it is
            not passed to ``MBConv``).
        channels: number of input image channels.
    """

    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be tuple of integers indicating number of transformer blocks at that stage'

        # convolutional stem: a stride-2 3x3 conv, then a 3x3 conv at the same width

        dim_conv_stem = default(dim_conv_stem, dim)

        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # stage widths: dim, 2*dim, 4*dim, ...; each stage maps the previous width to its own

        num_stages = len(depth)
        dims = (dim_conv_stem, *((2 ** i) * dim for i in range(num_stages)))
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        # shorthand for the window size used by both block- and grid-like attention

        w = window_size

        # iterate through stages; only the first block of a stage changes width and downsamples

        for (layer_dim_in, layer_dim), layer_depth in zip(dim_pairs, depth):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                block = nn.Sequential(
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate
                    ),
                    # block-like attention: each window is a contiguous w x w patch
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    # grid-like attention: each window gathers w x w positions strided across the whole map
                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )

                self.layers.append(block)

        # mlp head out: global mean pool over h, w, then LayerNorm + Linear to logits

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        """Map an image batch ``b channels H W`` to logits ``b num_classes``."""
        x = self.conv_stem(x)

        for stage in self.layers:
            x = stage(x)

        return self.mlp_head(x)