Skip to content

Commit

Permalink
make it compatible to pytorch0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
zhoubolei committed May 2, 2018
1 parent 6cf5ffb commit 18419ae
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pytorch_CAM.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.nn import functional as F
import numpy as np
import cv2
import pdb

# input image
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
Expand Down Expand Up @@ -58,7 +59,7 @@ def returnCAM(feature_conv, weight_softmax, class_idx):
std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
transforms.Scale((224,224)),
transforms.Resize((224,224)),
transforms.ToTensor(),
normalize
])
Expand All @@ -75,8 +76,10 @@ def returnCAM(feature_conv, weight_softmax, class_idx):
classes = {int(key):value for (key, value)
in requests.get(LABELS_URL).json().items()}

h_x = F.softmax(logit).data.squeeze()
h_x = F.softmax(logit, dim=1).data.squeeze()
probs, idx = h_x.sort(0, True)
probs = probs.numpy()
idx = idx.numpy()

# output the prediction
for i in range(0, 5):
Expand Down

0 comments on commit 18419ae

Please sign in to comment.