Skip to content

Commit

Permalink
Update vit_pytorch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
minhlong94 authored Nov 21, 2020
1 parent 6c8dfc1 commit ee5e4e9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vit_pytorch/vit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def forward(self, x, mask = None):
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'

self.patch_size = patch_size

Expand Down

0 comments on commit ee5e4e9

Please sign in to comment.