forked from lucidrains/vit-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b483b16
commit daf3abb
Showing
4 changed files
with
215 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import torch | ||
from torch import nn, einsum | ||
|
||
from einops import rearrange | ||
from einops.layers.torch import Rearrange, Reduce | ||
|
||
# helpers | ||
|
||
def cast_tuple(val, depth): | ||
return val if isinstance(val, tuple) else ((val,) * depth) | ||
|
||
# classes | ||
|
||
class ChanNorm(nn.Module): | ||
def __init__(self, dim, eps = 1e-5): | ||
super().__init__() | ||
self.eps = eps | ||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) | ||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) | ||
|
||
def forward(self, x): | ||
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt() | ||
mean = torch.mean(x, dim = 1, keepdim = True) | ||
return (x - mean) / (std + self.eps) * self.g + self.b | ||
|
||
class PreNorm(nn.Module): | ||
def __init__(self, dim, fn): | ||
super().__init__() | ||
self.norm = ChanNorm(dim) | ||
self.fn = fn | ||
|
||
def forward(self, x, **kwargs): | ||
return self.fn(self.norm(x), **kwargs) | ||
|
||
class FeedForward(nn.Module): | ||
def __init__(self, dim, mlp_mult = 4, dropout = 0.): | ||
super().__init__() | ||
self.net = nn.Sequential( | ||
nn.Conv2d(dim, dim * mlp_mult, 1), | ||
nn.GELU(), | ||
nn.Dropout(dropout), | ||
nn.Conv2d(dim * mlp_mult, dim, 1), | ||
nn.Dropout(dropout) | ||
) | ||
def forward(self, x): | ||
return self.net(x) | ||
|
||
class Attention(nn.Module): | ||
def __init__(self, dim, heads = 8, dropout = 0.): | ||
super().__init__() | ||
assert (dim % heads) == 0, 'dimension must be divisible by number of heads' | ||
dim_head = dim // heads | ||
self.heads = heads | ||
self.scale = dim_head ** -0.5 | ||
|
||
self.attend = nn.Softmax(dim = -1) | ||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias = False) | ||
|
||
self.to_out = nn.Sequential( | ||
nn.Conv2d(dim, dim, 1), | ||
nn.Dropout(dropout) | ||
) | ||
|
||
def forward(self, x): | ||
b, c, h, w, heads = *x.shape, self.heads | ||
|
||
qkv = self.to_qkv(x).chunk(3, dim = 1) | ||
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv) | ||
|
||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale | ||
|
||
attn = self.attend(dots) | ||
|
||
out = einsum('b h i j, b h j d -> b h i d', attn, v) | ||
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) | ||
return self.to_out(out) | ||
|
||
def Aggregate(dim, dim_out): | ||
return nn.Sequential( | ||
nn.Conv2d(dim, dim_out, 3, padding = 1), | ||
ChanNorm(dim_out), | ||
nn.MaxPool2d(2) | ||
) | ||
|
||
class Transformer(nn.Module): | ||
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.): | ||
super().__init__() | ||
self.layers = nn.ModuleList([]) | ||
self.pos_emb = nn.Parameter(torch.randn(seq_len)) | ||
|
||
for _ in range(depth): | ||
self.layers.append(nn.ModuleList([ | ||
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)), | ||
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)) | ||
])) | ||
def forward(self, x): | ||
*_, h, w = x.shape | ||
|
||
pos_emb = self.pos_emb[:(h * w)] | ||
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w) | ||
x = x + pos_emb | ||
|
||
for attn, ff in self.layers: | ||
x = attn(x) + x | ||
x = ff(x) + x | ||
return x | ||
|
||
class NesT(nn.Module): | ||
def __init__( | ||
self, | ||
*, | ||
image_size, | ||
patch_size, | ||
num_classes, | ||
dim, | ||
heads, | ||
num_heirarchies, | ||
block_repeats, | ||
mlp_mult = 4, | ||
channels = 3, | ||
dim_head = 64, | ||
dropout = 0. | ||
): | ||
super().__init__() | ||
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.' | ||
num_patches = (image_size // patch_size) ** 2 | ||
patch_dim = channels * patch_size ** 2 | ||
fmap_size = image_size // patch_size | ||
blocks = 2 ** (num_heirarchies - 1) | ||
|
||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy | ||
mults = [2 ** i for i in reversed(range(num_heirarchies))] | ||
|
||
layer_heads = list(map(lambda t: t * heads, mults)) | ||
layer_dims = list(map(lambda t: t * dim, mults)) | ||
|
||
layer_dims = [*layer_dims, layer_dims[-1]] | ||
dim_pairs = zip(layer_dims[:-1], layer_dims[1:]) | ||
|
||
self.to_patch_embedding = nn.Sequential( | ||
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size), | ||
nn.Conv2d(patch_dim, layer_dims[0], 1), | ||
) | ||
|
||
block_repeats = cast_tuple(block_repeats, num_heirarchies) | ||
|
||
self.layers = nn.ModuleList([]) | ||
|
||
for level, heads, (dim_in, dim_out), block_repeat in zip(reversed(range(num_heirarchies)), layer_heads, dim_pairs, block_repeats): | ||
is_last = level == 0 | ||
depth = block_repeat | ||
|
||
self.layers.append(nn.ModuleList([ | ||
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout), | ||
Aggregate(dim_in, dim_out) if not is_last else nn.Identity() | ||
])) | ||
|
||
self.mlp_head = nn.Sequential( | ||
ChanNorm(dim), | ||
Reduce('b c h w -> b c', 'mean'), | ||
nn.Linear(dim, num_classes) | ||
) | ||
|
||
def forward(self, img): | ||
x = self.to_patch_embedding(img) | ||
b, c, h, w = x.shape | ||
|
||
num_heirarchies = len(self.layers) | ||
|
||
for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers): | ||
block_size = 2 ** level | ||
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size) | ||
x = transformer(x) | ||
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size) | ||
x = aggregate(x) | ||
|
||
return self.mlp_head(x) |