Skip to content

Commit

Permalink
Change map_location to CPU, and fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhang committed Aug 10, 2020
1 parent 9b83fbb commit f1a7624
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,18 @@ model.eval()
# Run inference on an image
img = Image.open('city_1.png')
labels = model.predict_one(img) # returns a NumPy array containing integer labels
assert labels.shape == (224, 224)
assert labels.shape == (1024, 2048)

# Run inference on a batch of images
img2 = Image.open('city_2.png')
batch_labels = model.predict([img, img2]) # returns a NumPy array containing integer labels
assert batch_labels.shape == (2, 224, 224)
assert batch_labels.shape == (2, 1024, 2048)

# Run inference directly
dummy_input = torch.randn(1, 3, 224, 224, device='cuda')
# Run forward pass directly
dummy_input = torch.randn(1, 3, 1024, 2048, device='cuda')
with torch.no_grad():
dummy_output = model(dummy_input)
assert dummy_output.shape == (1, 19, 224, 224)
assert dummy_output.shape == (1, 19, 1024, 2048)
```

In addition, you can generate colorized and composited versions of the label masks as human-interpretable images.
Expand Down
4 changes: 2 additions & 2 deletions fastseg/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def from_pretrained(cls, filename=None, **kwargs):
name = cls.model_name
if name in MODEL_WEIGHTS_URL:
weights_url = MODEL_WEIGHTS_URL[name]
checkpoint = torch.hub.load_state_dict_from_url(weights_url)
checkpoint = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
else:
raise ValueError(f'pretrained weights not found for model {name}, please specify a checkpoint')
else:
checkpoint = torch.load(filename)
checkpoint = torch.load(filename, map_location='cpu')
net = cls(checkpoint['num_classes'], **kwargs)
net.load_checkpoint(checkpoint)
return net
Expand Down

0 comments on commit f1a7624

Please sign in to comment.