Skip to content

Commit

Permalink
assert minimum number of patches
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 16, 2020
1 parent c7b74e0 commit f7c164d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.2.2',
version = '0.2.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/vit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from einops import rearrange
from torch import nn

MIN_NUM_PATCHES = 16

class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'

self.patch_size = patch_size

Expand Down

0 comments on commit f7c164d

Please sign in to comment.