Skip to content

Commit

Permalink
Update MobileViT
Browse files Browse the repository at this point in the history
  • Loading branch information
chinhsuanwu committed Dec 29, 2021
1 parent 891b92e commit f2414b2
Showing 1 changed file with 62 additions and 54 deletions.
116 changes: 62 additions & 54 deletions vit_pytorch/mobile_vit.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,27 @@
"""
An implementation of MobileViT Model as defined in:
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
Arxiv: https://arxiv.org/abs/2110.02178
Origin Code: https://github.com/murufeng/awesome_lightweight_networks
"""

import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Reduce

def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v

# helpers

def conv_bn_relu(inp, oup, kernel, stride=1):
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
nn.SiLU()
)


def conv_1x1_bn(inp, oup):
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
nn.SiLU()
)

# classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
Expand All @@ -44,10 +31,11 @@ def __init__(self, dim, fn):
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.ffn = nn.Sequential(
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
Expand All @@ -56,8 +44,7 @@ def __init__(self, dim, hidden_dim, dropout=0.):
)

def forward(self, x):
return self.ffn(x)

return self.net(x)

class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
Expand All @@ -76,23 +63,28 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):

def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
q, k, v = map(lambda t: rearrange(
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)


class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
"""Transformer block described in ViT.
Paper: https://arxiv.org/abs/2010.11929
Based on: https://github.com/lucidrains/vit-pytorch
"""

def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
]))

def forward(self, x):
Expand All @@ -102,17 +94,24 @@ def forward(self, x):
return x

class MV2Block(nn.Module):
def __init__(self, inp, oup, stride=1, expand_ratio=4):
super(MV2Block, self).__init__()
"""MV2 block described in MobileNetV2.
Paper: https://arxiv.org/pdf/1801.04381
Based on: https://github.com/tonylins/pytorch-mobilenet-v2
"""

def __init__(self, inp, oup, stride=1, expansion=4):
super().__init__()
self.stride = stride
assert stride in [1, 2]

hidden_dim = round(inp * expand_ratio)
self.identity = stride == 1 and inp == oup
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup

if expand_ratio == 1:
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
Expand All @@ -126,7 +125,8 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
# pw-linear
Expand All @@ -136,8 +136,7 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):

def forward(self, x):
out = self.conv(x)

if self.identity:
if self.use_res_connect:
out = out + x
return out

Expand All @@ -146,13 +145,13 @@ def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropou
super().__init__()
self.ph, self.pw = patch_size

self.conv1 = conv_bn_relu(channel, channel, kernel_size)
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv2 = conv_1x1_bn(channel, dim)

self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

self.conv3 = conv_1x1_bn(dim, channel)
self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

def forward(self, x):
y = x.clone()
Expand All @@ -163,28 +162,34 @@ def forward(self, x):

# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

# Fusion
x = self.conv3(x)
x = torch.cat((x, y), 1)
x = self.conv4(x)
return x


class MobileViT(nn.Module):
"""MobileViT.
Paper: https://arxiv.org/abs/2110.02178
Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
"""

def __init__(
self,
image_size,
dims,
channels,
num_classes,
expansion = 4,
kernel_size = 3,
patch_size = (2, 2),
depths = (2, 4, 3)
expansion=4,
kernel_size=3,
patch_size=(2, 2),
depths=(2, 4, 3)
):
super().__init__()
assert len(dims) == 3, 'dims must be a tuple of 3'
Expand All @@ -196,28 +201,31 @@ def __init__(

init_dim, *_, last_dim = channels

self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
self.conv1 = conv_nxn_bn(3, init_dim, stride=2)

self.stem = nn.ModuleList([])
self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))

self.trunk = nn.ModuleList([])
self.trunk.append(nn.ModuleList([
MV2Block(channels[3], channels[4], 2, expansion),
MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
MobileViTBlock(dims[0], depths[0], channels[5],
kernel_size, patch_size, int(dims[0] * 2))
]))

self.trunk.append(nn.ModuleList([
MV2Block(channels[5], channels[6], 2, expansion),
MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
MobileViTBlock(dims[1], depths[1], channels[7],
kernel_size, patch_size, int(dims[1] * 4))
]))

self.trunk.append(nn.ModuleList([
MV2Block(channels[7], channels[8], 2, expansion),
MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
MobileViTBlock(dims[2], depths[2], channels[9],
kernel_size, patch_size, int(dims[2] * 4))
]))

self.to_logits = nn.Sequential(
Expand Down

0 comments on commit f2414b2

Please sign in to comment.