From f2414b2c1b3772c141f015af03f5d69e31dde5be Mon Sep 17 00:00:00 2001
From: chinhsuanwu
Date: Thu, 30 Dec 2021 05:52:23 +0800
Subject: [PATCH] Update MobileViT

---
 vit_pytorch/mobile_vit.py | 116 ++++++++++++++++++++------------------
 1 file changed, 62 insertions(+), 54 deletions(-)

diff --git a/vit_pytorch/mobile_vit.py b/vit_pytorch/mobile_vit.py
index b8b7253..34b933e 100644
--- a/vit_pytorch/mobile_vit.py
+++ b/vit_pytorch/mobile_vit.py
@@ -1,40 +1,27 @@
-"""
-An implementation of MobileViT Model as defined in:
-MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
-Arxiv: https://arxiv.org/abs/2110.02178
-Origin Code: https://github.com/murufeng/awesome_lightweight_networks
-"""
-
 import torch
 import torch.nn as nn
 
 from einops import rearrange
 from einops.layers.torch import Reduce
 
-def _make_divisible(v, divisor, min_value=None):
-    if min_value is None:
-        min_value = divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
+# helpers
 
-def conv_bn_relu(inp, oup, kernel, stride=1):
+def conv_1x1_bn(inp, oup):
     return nn.Sequential(
-        nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
         nn.BatchNorm2d(oup),
-        nn.ReLU6(inplace=True)
+        nn.SiLU()
     )
 
-
-def conv_1x1_bn(inp, oup):
+def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
     return nn.Sequential(
-        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
         nn.BatchNorm2d(oup),
-        nn.ReLU6(inplace=True)
+        nn.SiLU()
     )
 
+# classes
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -44,10 +31,11 @@ def __init__(self, dim, fn):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)
 
+
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout=0.):
         super().__init__()
-        self.ffn = nn.Sequential(
+        self.net = nn.Sequential(
             nn.Linear(dim, hidden_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
@@ -56,8 +44,7 @@ def __init__(self, dim, hidden_dim, dropout=0.):
         )
 
     def forward(self, x):
-        return self.ffn(x)
-
+        return self.net(x)
 
 class Attention(nn.Module):
     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
@@ -76,7 +63,8 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
 
     def forward(self, x):
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(
+            t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
         attn = self.attend(dots)
@@ -84,15 +72,19 @@ def forward(self, x):
         out = rearrange(out, 'b p h n d -> b p n (h d)')
         return self.to_out(out)
 
-
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    """Transformer block described in ViT.
+    Paper: https://arxiv.org/abs/2010.11929
+    Based on: https://github.com/lucidrains/vit-pytorch
+    """
+
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
         super().__init__()
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
             ]))
 
     def forward(self, x):
@@ -102,17 +94,24 @@ def forward(self, x):
         return x
 
 class MV2Block(nn.Module):
-    def __init__(self, inp, oup, stride=1, expand_ratio=4):
-        super(MV2Block, self).__init__()
+    """MV2 block described in MobileNetV2.
+    Paper: https://arxiv.org/pdf/1801.04381
+    Based on: https://github.com/tonylins/pytorch-mobilenet-v2
+    """
+
+    def __init__(self, inp, oup, stride=1, expansion=4):
+        super().__init__()
+        self.stride = stride
         assert stride in [1, 2]
 
-        hidden_dim = round(inp * expand_ratio)
-        self.identity = stride == 1 and inp == oup
+        hidden_dim = int(inp * expansion)
+        self.use_res_connect = self.stride == 1 and inp == oup
 
-        if expand_ratio == 1:
+        if expansion == 1:
             self.conv = nn.Sequential(
                 # dw
-                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
+                          1, groups=hidden_dim, bias=False),
                 nn.BatchNorm2d(hidden_dim),
                 nn.SiLU(),
                 # pw-linear
@@ -126,7 +125,8 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
                 nn.BatchNorm2d(hidden_dim),
                 nn.SiLU(),
                 # dw
-                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
+                          1, groups=hidden_dim, bias=False),
                 nn.BatchNorm2d(hidden_dim),
                 nn.SiLU(),
                 # pw-linear
@@ -136,8 +136,7 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
 
     def forward(self, x):
         out = self.conv(x)
-
-        if self.identity:
+        if self.use_res_connect:
             out = out + x
         return out
 
@@ -146,13 +145,13 @@ def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropou
         super().__init__()
         self.ph, self.pw = patch_size
 
-        self.conv1 = conv_bn_relu(channel, channel, kernel_size)
+        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
         self.conv2 = conv_1x1_bn(channel, dim)
 
-        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
+        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
 
         self.conv3 = conv_1x1_bn(dim, channel)
-        self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
+        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
 
     def forward(self, x):
         y = x.clone()
@@ -163,9 +162,11 @@ def forward(self, x):
 
         # Global representations
         _, _, h, w = x.shape
-        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
+        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
+                      ph=self.ph, pw=self.pw)
         x = self.transformer(x)
-        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
+                      h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
 
         # Fusion
         x = self.conv3(x)
@@ -173,18 +174,22 @@ def forward(self, x):
         x = self.conv4(x)
         return x
 
-
 class MobileViT(nn.Module):
+    """MobileViT.
+    Paper: https://arxiv.org/abs/2110.02178
+    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
+    """
+
     def __init__(
         self,
        image_size,
         dims,
         channels,
         num_classes,
-        expansion = 4,
-        kernel_size = 3,
-        patch_size = (2, 2),
-        depths = (2, 4, 3)
+        expansion=4,
+        kernel_size=3,
+        patch_size=(2, 2),
+        depths=(2, 4, 3)
     ):
         super().__init__()
         assert len(dims) == 3, 'dims must be a tuple of 3'
@@ -196,28 +201,31 @@ def __init__(
 
         init_dim, *_, last_dim = channels
 
-        self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
+        self.conv1 = conv_nxn_bn(3, init_dim, stride=2)
 
         self.stem = nn.ModuleList([])
         self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
         self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
         self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
         self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
-        
+
         self.trunk = nn.ModuleList([])
         self.trunk.append(nn.ModuleList([
             MV2Block(channels[3], channels[4], 2, expansion),
-            MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
+            MobileViTBlock(dims[0], depths[0], channels[5],
+                           kernel_size, patch_size, int(dims[0] * 2))
         ]))
 
         self.trunk.append(nn.ModuleList([
             MV2Block(channels[5], channels[6], 2, expansion),
-            MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
+            MobileViTBlock(dims[1], depths[1], channels[7],
+                           kernel_size, patch_size, int(dims[1] * 4))
         ]))
 
         self.trunk.append(nn.ModuleList([
             MV2Block(channels[7], channels[8], 2, expansion),
-            MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
+            MobileViTBlock(dims[2], depths[2], channels[9],
+                           kernel_size, patch_size, int(dims[2] * 4))
         ]))
 
         self.to_logits = nn.Sequential(
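For reviewers who want to sanity-check the updated module, a minimal usage sketch follows. The constructor arguments match the signature shown in the hunks above; the concrete dims/channels/image_size values are illustrative (roughly an XS-sized configuration) and are not part of this diff.

import torch
from vit_pytorch.mobile_vit import MobileViT

# Illustrative hyperparameters, not taken from the patch: dims for the three
# MobileViT blocks, channels for the stem/trunk stages (indexed 0..9 plus the
# final projection width), and a 256x256 input.
mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
preds = mbvit_xs(img)  # expected shape: (1, 1000)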