give a learned bias to and from registers for maxvit + register token variant
lucidrains committed Oct 6, 2023
1 parent df8733d commit bbb24e3
Showing 2 changed files with 14 additions and 13 deletions.
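
In short, the diff below reserves one extra slot in the relative positional bias embedding and points every register row and column of the attention matrix at that slot, so attention to and from register tokens gets a learned per-head bias instead of the zero padding used before. A minimal sketch of the index construction, not part of the commit, re-deriving the window grid indices that the hidden context lines compute, and assuming the defaults window_size = 7 and num_registers = 1 (heads = 8 is illustrative):

import torch
import torch.nn.functional as F

window_size = 7
num_registers = 1
heads = 8  # illustrative head count

num_rel_pos_bias = (2 * window_size - 1) ** 2   # 169 distinct relative offsets in a 7x7 window

# relative position index for each pair of the 49 window tokens, MaxViT-style
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')).reshape(2, -1)   # (2, 49)
rel_pos = grid[:, :, None] - grid[:, None, :] + (window_size - 1)              # (2, 49, 49)
rel_pos_indices = rel_pos[0] * (2 * window_size - 1) + rel_pos[1]              # (49, 49), values in [0, 168]

# the commit's change: pad one row and one column per register, all pointing at the
# new embedding slot (index num_rel_pos_bias), rather than zero-padding the bias later
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)

rel_pos_bias = torch.nn.Embedding(num_rel_pos_bias + 1, heads)
bias = rel_pos_bias(rel_pos_indices)   # (50, 50, 8): all register rows/columns share one learned bias per head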
setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '1.5.2',
+ version = '1.5.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
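
The version bump above implies a new PyPI release; once published, it can be picked up the usual way (assuming the package's standard pip distribution):

pip install vit-pytorch==1.5.3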
vit_pytorch/max_vit_with_registers.py: 13 additions & 12 deletions
@@ -119,9 +119,11 @@ def __init__(
dim,
dim_head = 32,
dropout = 0.,
- window_size = 7
+ window_size = 7,
+ num_registers = 1
):
super().__init__()
+ assert num_registers > 0
assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

self.heads = dim // dim_head
@@ -142,7 +144,9 @@ def __init__(

# relative positional bias

- self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
+ num_rel_pos_bias = (2 * window_size - 1) ** 2
+
+ self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
@@ -151,10 +155,11 @@
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

+ rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

def forward(self, x):
- device, h = x.device, self.heads
+ device, h, bias_indices = x.device, self.heads, self.rel_pos_indices

x = self.norm(x)

@@ -176,13 +181,8 @@ def forward(self, x):

# add positional bias

- bias = self.rel_pos_bias(self.rel_pos_indices)
- bias = rearrange(bias, 'i j h -> h i j')
-
- num_registers = sim.shape[-1] - bias.shape[-1]
- bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)
-
- sim = sim + bias
+ bias = self.rel_pos_bias(bias_indices)
+ sim = sim + rearrange(bias, 'i j h -> h i j')

# attention

@@ -215,6 +215,7 @@ def __init__(
):
super().__init__()
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
+ assert num_register_tokens > 0

# convolutional stem

@@ -256,10 +257,10 @@ def __init__(
shrinkage_rate = mbconv_shrinkage_rate
)

- block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+ block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
block_ff = FeedForward(dim = layer_dim, dropout = dropout)

- grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+ grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
grid_ff = FeedForward(dim = layer_dim, dropout = dropout)

register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
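
For context, a usage sketch of the registers variant with the new arguments threaded through. The export name MaxViT and the constructor arguments not visible in this diff (such as num_classes and dim_conv_stem) are assumptions based on the repository's README rather than on the commit itself:

import torch
from vit_pytorch.max_vit_with_registers import MaxViT   # assumed export name

model = MaxViT(
    num_classes = 1000,            # assumed, not shown in this diff
    dim_conv_stem = 64,            # assumed, not shown in this diff
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),          # must be a tuple, per the assert in the diff
    window_size = 7,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1,
    num_register_tokens = 4        # must be > 0; registers now receive the learned attention bias
)

img = torch.randn(1, 3, 224, 224)
logits = model(img)                # assumed output shape: (1, 1000)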
