This repository provides a quick way to load Swin Transformers for image classification from PyTorch Hub, in one line of code.
The official Swin Transformer repository can be found here:
https://github.com/microsoft/Swin-Transformer
Requirements:
- torch - PyTorch
- timm - PyTorch Image Models
```python
import torch

HUB_URL = "sayhi123/swin-transformer-hub"
MODEL_NAME = "swin_tiny_patch4_window7_224"  # see hubconf.py for the other available models

model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True)  # load from PyTorch Hub
model.eval()  # switch to inference mode
```
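The model name follows the Swin naming convention: `swin_tiny_patch4_window7_224` is the Tiny variant with 4×4 patches, 7×7 attention windows and a 224×224 input. The helper `parse_swin_name` below is hypothetical and not part of this repository; it is a minimal sketch that splits such a name into its parts, assuming the other entries in hubconf.py follow the same pattern.

```python
import re

# Hypothetical pattern, assumed from names such as "swin_tiny_patch4_window7_224";
# it is not defined by the hub repository itself.
_SWIN_NAME = re.compile(
    r"swin_(?P<variant>\w+?)_patch(?P<patch>\d+)_window(?P<window>\d+)_(?P<resolution>\d+)"
)


def parse_swin_name(name):
    """Split a Swin model name into variant, patch size, window size and input resolution."""
    match = _SWIN_NAME.match(name)
    if match is None:
        raise ValueError(f"unrecognised Swin model name: {name!r}")
    info = {
        "variant": match["variant"],
        "patch_size": int(match["patch"]),
        "window_size": int(match["window"]),
        "resolution": int(match["resolution"]),
    }
    # Number of patches along each side of the image after patch embedding.
    info["patches_per_side"] = info["resolution"] // info["patch_size"]
    return info


print(parse_swin_name("swin_tiny_patch4_window7_224"))
```

For the default model this gives 224 // 4 = 56 patches per side, which is the token grid the first stage of the network works on.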
Define the transforms used to preprocess PIL images for inference:
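```python
from torchvision import transforms as T
from PIL import Image
import timm

transform = T.Compose([
    T.Resize(224),      # scale the shorter edge to 224 pixels, keeping the aspect ratio
    T.CenterCrop(224),  # keep the central 224x224 region
    T.ToTensor(),       # PIL image -> float tensor in [0, 1], shape (C, H, W)
    T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD),  # ImageNet statistics
])
```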
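To make the preprocessing concrete, here is a pure-Python sketch of the arithmetic the pipeline above performs. The helpers `resized_size`, `center_crop_box` and `normalize_pixel` are hypothetical illustrations, not torchvision calls; they assume the integer-size behaviour of `T.Resize` (the shorter edge is scaled to 224 and the aspect ratio is kept) and hard-code the ImageNet mean and standard deviation that `timm.data` exposes.

```python
# ImageNet per-channel RGB statistics (same values as timm.data.IMAGENET_DEFAULT_MEAN/STD).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def resized_size(width, height, size=224):
    """Sketch of the (width, height) produced by T.Resize(size): shorter edge becomes `size`."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centred crop x crop window; right/bottom are exclusive."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb_uint8):
    """ToTensor scales 0-255 to [0, 1]; Normalize then applies (x - mean) / std per channel."""
    return tuple(
        (v / 255.0 - m) / s
        for v, m, s in zip(rgb_uint8, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    )


w, h = resized_size(640, 480)   # (298, 224)
print((w, h), center_crop_box(w, h))  # crop box (37, 0, 261, 224)
print(normalize_pixel((255, 255, 255)))
```

For a 640×480 photo, the image is first resized to 298×224 and the crop then keeps columns 37 to 260, so only the horizontal edges are discarded.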
Get the list of ImageNet class labels:
```python
import json
from urllib.request import urlopen

URL = "https://raw.githubusercontent.com/sayhi123/swin-transformer-hub/main/imagenet_labels.json"
response = urlopen(URL)
classes = json.loads(response.read())
print(len(classes))  # should print 1000
```
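Once the model produces logits for an image, `classes` can turn them into readable predictions. The `top_k_labels` helper below is a hypothetical plain-Python sketch (a softmax followed by a top-k sort); it assumes `classes` is a list indexed by class id, which is an assumption about the layout of imagenet_labels.json. With the real model you would pass it something like `model(transform(img).unsqueeze(0))[0].tolist()` inside `torch.no_grad()`, where `img` is a PIL image, assuming the model returns a `(1, 1000)` tensor of logits.

```python
import math


def top_k_labels(logits, classes, k=5):
    """Return the k most likely (label, probability) pairs from raw logits.

    Assumes `classes` is indexable by class id (e.g. a list of 1000 ImageNet names).
    """
    # Numerically stable softmax: subtract the largest logit before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Class ids of the k largest probabilities, highest first.
    best = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    return [(classes[i], probs[i]) for i in best]


toy_classes = ["cat", "dog", "bird"]  # stand-in for the real ImageNet labels
print(top_k_labels([2.0, 1.0, 0.1], toy_classes, k=2))
```

In PyTorch the same step is usually written as `logits.softmax(dim=-1).topk(5)`; the plain-Python version only makes each step explicit.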
TODO:
- Add support for more model weights