forked from lucidrains/vit-pytorch
maxvit intent to build (lucidrains#211)
complete hybrid mbconv + block / grid efficient self attention MaxViT
1 parent 946b19b · commit c7bb5fc
Showing 7 changed files with 317 additions and 7 deletions.
@@ -0,0 +1,270 @@

from functools import partial

import torch
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x)) + x

# MBConv

class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

class MBConvResidual(nn.Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

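# Dropsample implements stochastic depth: during training each example's branch output
# is zeroed out with probability `prob`, and the survivors are rescaled by 1 / (1 - prob)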
class Dropsample(nn.Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # per-example keep mask of shape (batch, 1, 1, 1)
        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)

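# MBConv block as a plain nn.Sequential: a pointwise conv expands the channels, a depthwise
# 3x3 conv (stride 2 when downsampling) mixes spatially, squeeze-excitation gates the channels,
# and a final pointwise conv projects back down; when input and output shapes match, the whole
# block is wrapped in a residual with per-sample dropout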
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.SiLU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# attention related classes

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative positional bias

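        # a learned bias table with (2 * window_size - 1) ** 2 entries per head covers every possible
        # relative offset between two positions inside a window (e.g. window_size = 7 gives a
        # 13 x 13 = 169 entry table); rel_pos_indices maps each (query, key) pair to its row in that table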
        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

        # flatten

        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # project for queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale

        q = q * self.scale

        # sim

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias

        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)

        # combine heads out

        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)

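# MaxViT: a convolutional stem, then stages that each stack MBConv with block (window) attention
# followed by grid attention, finished by global average pooling and a linear classification head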
class MaxViT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

        # convolutional stem

        dim_conv_stem = default(dim_conv_stem, dim)

        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # variables

        num_stages = len(depth)

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        # shorthand for window size for efficient block - grid like attention

        w = window_size

        # iterate through stages

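        # each stage stacks MBConv (downsampling on its first block), block attention over
        # non-overlapping w x w windows, then grid attention over a uniformly strided grid
        # spanning the whole feature map, so every block mixes local and global context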
        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                block = nn.Sequential(
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )

                self.layers.append(block)

        # mlp head out

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        x = self.conv_stem(x)

        for stage in self.layers:
            x = stage(x)

        return self.mlp_head(x)
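
A minimal usage sketch of the class above; the hyperparameter values and the 224 x 224 input size are illustrative choices, not part of this commit. The only constraint is that the resolution stays divisible by window_size after the stride-2 stem and the stride-2 MBConv at the start of each stage.

import torch

model = MaxViT(
    num_classes = 1000,
    dim = 64,                  # stage 1 width, doubled at each subsequent stage
    depth = (2, 2),            # two stages, two blocks each
    dim_head = 32,
    window_size = 7,
    dropout = 0.1
)

images = torch.randn(1, 3, 224, 224)  # 224 -> 112 (stem) -> 56 -> 28, both divisible by 7

logits = model(images)                # shape (1, 1000)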