Commit: fix typo
lucidrains committed Jun 1, 2021
1 parent 26df10c commit a254a02
Showing 3 changed files with 11 additions and 11 deletions.
README.md (2 additions & 2 deletions)
@@ -382,7 +382,7 @@ pred = model(img) # (1, 1000)
 
 <img src="./images/nest.png" width="400px"></img>
 
-This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in heirarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution to allow it to pass information across the boundary.
+This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
 
 You can use it with the following code (ex. NesT-T)
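
As an aside on the paragraph above: the described aggregation (a convolution followed by a maxpool, applied in the image plane) can be sketched as a small PyTorch module. This is an illustration with assumed names and kernel sizes, not necessarily the exact block in vit_pytorch/nest.py:

import torch
from torch import nn

# Sketch of a block-boundary aggregation step: a 3x3 convolution lets
# features cross block boundaries, then a stride-2 maxpool halves the
# feature map as neighboring blocks merge. Names and kernel sizes are
# illustrative assumptions.
class AggregateSketch(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv = nn.Conv2d(dim_in, dim_out, 3, padding = 1)
        self.pool = nn.MaxPool2d(3, stride = 2, padding = 1)

    def forward(self, x):
        return self.pool(self.conv(x))

x = torch.randn(1, 96, 56, 56)
y = AggregateSketch(96, 192)(x)  # (1, 192, 28, 28)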

@@ -395,7 +395,7 @@ nest = NesT(
     patch_size = 4,
     dim = 96,
     heads = 3,
-    num_heirarchies = 3, # number of heirarchies
+    num_hierarchies = 3, # number of hierarchies
     block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
     num_classes = 1000
 )
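
For convenience, here is a self-contained version of the snippet in the hunk above; image_size = 224 is an assumption, since the line defining it sits above the visible hunk:

import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,          # assumed; not shown in the hunk above
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,       # number of hierarchies
    block_repeats = (8, 4, 1), # transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img)  # (1, 1000)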
setup.py (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.2',
+  version = '0.19.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/nest.py (8 additions & 8 deletions)
@@ -114,7 +114,7 @@ def __init__(
         num_classes,
         dim,
         heads,
-        num_heirarchies,
+        num_hierarchies,
         block_repeats,
         mlp_mult = 4,
         channels = 3,
@@ -126,11 +126,11 @@
         num_patches = (image_size // patch_size) ** 2
         patch_dim = channels * patch_size ** 2
         fmap_size = image_size // patch_size
-        blocks = 2 ** (num_heirarchies - 1)
+        blocks = 2 ** (num_hierarchies - 1)
 
         seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
-        heirarchies = list(reversed(range(num_heirarchies)))
-        mults = [2 ** i for i in heirarchies]
+        hierarchies = list(reversed(range(num_hierarchies)))
+        mults = [2 ** i for i in hierarchies]
 
         layer_heads = list(map(lambda t: t * heads, mults))
         layer_dims = list(map(lambda t: t * dim, mults))
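
To make the arithmetic in this hunk concrete, here is what the renamed lines evaluate to for the NesT-T settings from the README above (num_hierarchies = 3, dim = 96, heads = 3); the values follow directly from the code shown:

num_hierarchies, dim, heads = 3, 96, 3

blocks = 2 ** (num_hierarchies - 1)                   # 4 local blocks per side
hierarchies = list(reversed(range(num_hierarchies)))  # [2, 1, 0]
mults = [2 ** i for i in hierarchies]                 # [4, 2, 1]

layer_heads = list(map(lambda t: t * heads, mults))   # [12, 6, 3]
layer_dims = list(map(lambda t: t * dim, mults))      # [384, 192, 96]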
@@ -143,11 +143,11 @@
             nn.Conv2d(patch_dim, layer_dims[0], 1),
         )
 
-        block_repeats = cast_tuple(block_repeats, num_heirarchies)
+        block_repeats = cast_tuple(block_repeats, num_hierarchies)
 
         self.layers = nn.ModuleList([])
 
-        for level, heads, (dim_in, dim_out), block_repeat in zip(heirarchies, layer_heads, dim_pairs, block_repeats):
+        for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
             is_last = level == 0
             depth = block_repeat
@@ -166,9 +166,9 @@ def forward(self, img):
         x = self.to_patch_embedding(img)
         b, c, h, w = x.shape
 
-        num_heirarchies = len(self.layers)
+        num_hierarchies = len(self.layers)
 
-        for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers):
+        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
             block_size = 2 ** level
             x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
             x = transformer(x)
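
The rearrange in this hunk folds a grid of local blocks into the batch dimension so that attention runs within each block independently. A standalone round-trip sketch with a dummy tensor (the inverse pattern is an illustration, since the remainder of forward is truncated here):

import torch
from einops import rearrange

x = torch.randn(1, 96, 16, 16)  # (batch, channels, height, width)
block_size = 4                  # 2 ** level at the deepest level

# partition into a 4x4 grid of 4x4 blocks, folded into the batch dimension
blocks = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
assert blocks.shape == (16, 96, 4, 4)

# ... the transformer attends within each block ...

# stitch the blocks back into the image plane before aggregation
x = rearrange(blocks, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
assert x.shape == (1, 96, 16, 16)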
