Skip to content

Commit

Permalink
dynamic positional bias for crossformer the more efficient way as described in appendix of paper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 23, 2021
1 parent 36e32b7 commit b69b5af
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.24.0',
version = '0.24.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
28 changes: 20 additions & 8 deletions vit_pytorch/crossformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,25 @@ def __init__(
self.attn_type = attn_type
self.window_size = window_size

self.dpb = DynamicPositionBias(dim // 4)

self.norm = LayerNorm(dim)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)

# positions

self.dpb = DynamicPositionBias(dim // 4)

# calculate and store indices for retrieving bias

pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_pos = grid[:, None] - grid[None, :]
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

def forward(self, x):
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device

Expand Down Expand Up @@ -136,12 +149,11 @@ def forward(self, x):

# add dynamic positional bias

i_pos = torch.arange(wsz, device = device)
j_pos = torch.arange(wsz, device = device)
grid = torch.stack(torch.meshgrid(i_pos, j_pos))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_ij = grid[:, None] - grid[None, :]
rel_pos_bias = self.dpb(rel_ij.float())
pos = torch.arange(-wsz, wsz + 1, device = device)
rel_pos = torch.stack(torch.meshgrid(pos, pos))
rel_pos = rearrange(rel_pos, 'c i j -> (i j) c')
biases = self.dpb(rel_pos.float())
rel_pos_bias = biases[self.rel_pos_indices]

sim = sim + rel_pos_bias

Expand Down

0 comments on commit b69b5af

Please sign in to comment.