forked from spmallick/learnopencv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ResNet18.py
38 lines (29 loc) · 972 Bytes
/
ResNet18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torchvision import models
from PIL import Image
import cv2
import torch
from torchsummary import summary
from torchvision import transforms
# Preprocessing applied to the input PIL image before it reaches the network:
#   1. ToTensor  -> CHW float tensor (8-bit pixel values scaled to [0, 1]).
#   2. Normalize -> subtract a per-channel mean and divide by a per-channel std
#      (the RGB statistics torchvision documents for its ImageNet-pretrained
#      models).
# The usual Resize(256) + CenterCrop(224) steps are intentionally not part of
# this pipeline, so the image is fed to the model at its native resolution.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Class names, one per line of imagenet_classes.txt; entry i is the name
# printed for model output index i below. The encoding is given explicitly so
# decoding does not depend on the platform's locale default.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
# Load the test image. convert("RGB") guarantees exactly three channels, so
# grayscale, palette or RGBA files still match the 3-channel normalization.
# The context manager closes the underlying file once the pixels are loaded.
with Image.open("camel.jpg") as raw_img:
    img = raw_img.convert("RGB")
img_t = transform(img)
# Prepend a batch dimension of size 1: the model consumes batches of images.
batch_t = torch.unsqueeze(img_t, 0)

# First, load the model.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=`; kept here for compatibility with older torchvision releases.
resnet = models.resnet18(pretrained=True)
# Second, put the network in eval mode *before* any forward pass runs.
# torchsummary's summary() pushes a random batch through the model; in train
# mode that pass would blend random-noise statistics into the pretrained
# BatchNorm running mean/var and degrade the prediction below.
resnet.eval()
# The model lives on the CPU, so tell torchsummary to build its dummy input on
# the CPU as well (its default device is "cuda", which mismatches a CPU model
# whenever CUDA is available).
summary(resnet, (3, 224, 224), device="cpu")

# Third, carry out model inference. no_grad() skips building the autograd
# graph, which is not needed for a pure forward pass.
with torch.no_grad():
    preds = resnet(batch_t)
# Highest score and its index along the class dimension (batch size is 1).
pred, class_idx = torch.max(preds, dim=1)
print(labels[class_idx.item()])