diff --git a/README.md b/README.md
index 2784326..daa7660 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ preds = v(img) # (1, 1000)
 ```
 
 ## Parameters
+
 - `image_size`: int.
 Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
 - `patch_size`: int.
@@ -583,6 +584,35 @@ img = torch.randn(1, 3, 224, 224)
 v(img) # (1, 1000)
 ```
 
+## FAQ
+
+- How do I pass in non-square images?
+
+You can already pass in non-square images - you just have to make sure your height and width are less than or equal to the `image_size`, and both are divisible by the `patch_size`
+
+ex.
+
+```python
+import torch
+from vit_pytorch import ViT
+
+v = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+img = torch.randn(1, 3, 256, 128) # <-- not a square
+
+preds = v(img) # (1, 1000)
+```
+
 ## Resources
 
 Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
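
A note on the FAQ entry this diff adds: the constraint is that both the height and the width must be at most `image_size` and evenly divisible by `patch_size`. Below is a minimal sketch of that check in plain Python; it is illustrative only and not code from this diff or from `vit_pytorch` itself (the function name `is_valid_input` is made up for this sketch).

```python
# Minimal sketch of the FAQ's constraint (illustrative only, not library code).
image_size = 256  # the size the ViT in the FAQ example was configured with
patch_size = 32

def is_valid_input(height, width):
    # both dimensions must fit within image_size ...
    within_bounds = height <= image_size and width <= image_size
    # ... and tile evenly into patch_size x patch_size patches
    divisible = height % patch_size == 0 and width % patch_size == 0
    return within_bounds and divisible

assert is_valid_input(256, 128)      # the FAQ's non-square example: 8 x 4 = 32 patches
assert not is_valid_input(256, 120)  # 120 is not divisible by 32
assert not is_valid_input(512, 256)  # 512 exceeds image_size
```

Inputs smaller than `image_size` work because, as of this diff, `ViT`'s forward pass slices its learned positional embedding down to the actual number of patches, so any valid size up to the configured `image_size` has embeddings available.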