Skip to content

Commit

Permalink
compatibility for torchvision 0.4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-haochen committed Oct 8, 2019
1 parent bb6e97b commit 0ea6ea0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
6 changes: 3 additions & 3 deletions fcos_core/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
v: k for k, v in self.json_category_id_to_contiguous_id.items()
}
self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
self.transforms = transforms
self._transforms = transforms

def __getitem__(self, idx):
img, anno = super(COCODataset, self).__getitem__(idx)
Expand Down Expand Up @@ -90,8 +90,8 @@ def __getitem__(self, idx):

target = target.clip_to_image(remove_empty=True)

if self.transforms is not None:
img, target = self.transforms(img, target)
if self._transforms is not None:
img, target = self._transforms(img, target)

return img, target, idx

Expand Down
7 changes: 5 additions & 2 deletions fcos_core/data/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ def get_size(self, image_size):
def __call__(self, image, target=None):
size = self.get_size(image.size)
image = F.resize(image, size)
if target is None:
if isinstance(target, list):
target = [t.resize(image.size) for t in target]
elif target is None:
return image
target = target.resize(image.size)
else:
target = target.resize(image.size)
return image, target


Expand Down
3 changes: 2 additions & 1 deletion fcos_core/structures/segmentation_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ def __getitem__(self, item):
else:
# advanced indexing on a single dimension
selected_polygons = []
if isinstance(item, torch.Tensor) and item.dtype == torch.uint8:
if isinstance(item, torch.Tensor) and \
item.dtype == torch.uint8 or item.dtype == torch.bool:
item = item.nonzero()
item = item.squeeze(1) if item.numel() > 0 else item
item = item.tolist()
Expand Down

0 comments on commit 0ea6ea0

Please sign in to comment.