Add torchscript support for hub detection models (facebookresearch#51)
* Add torchscript support for hub detection models

* Dummy commit to test CI

* Try trigger CI

* Fix lint

* Update CI to use nightlies
fmassa authored Jun 4, 2020
1 parent b7b62c0 commit a5cd934
Showing 7 changed files with 95 additions and 43 deletions.
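Taken together, the changes below make the hub detection models scriptable end to end. A minimal usage sketch, mirroring the new test_model_script in test_all.py (assumes the repository root is on the path; detr_resnet50 comes from hubconf.py and pretrained=False avoids a checkpoint download):

```python
import torch

from hubconf import detr_resnet50
from util.misc import nested_tensor_from_tensor_list

# Build the hub model in eval mode and compile it with TorchScript.
model = detr_resnet50(pretrained=False).eval()
scripted = torch.jit.script(model)

# Pad two differently sized images into one NestedTensor batch and run both variants.
inputs = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
eager_out = model(inputs)
scripted_out = scripted(inputs)
print(torch.equal(eager_out["pred_logits"], scripted_out["pred_logits"]))  # expected: True
```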
3 changes: 2 additions & 1 deletion .circleci/config.yml
@@ -18,7 +18,8 @@ jobs:
- checkout
- run:
command: |
-pip install --user --progress-bar off torch torchvision scipy pytest
+pip install --user --progress-bar off scipy pytest
+pip install --user --progress-bar off --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pytest .
workflows:
15 changes: 9 additions & 6 deletions models/backbone.py
@@ -9,6 +9,7 @@
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List

from util.misc import NestedTensor

@@ -64,15 +65,17 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int,
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
-return_layers = {'layer4': 0}
+return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels

-def forward(self, tensor_list):
+def forward(self, tensor_list: NestedTensor):
xs = self.body(tensor_list.tensors)
-out = OrderedDict()
+out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
-mask = F.interpolate(tensor_list.mask[None].float(), size=x.shape[-2:]).bool()[0]
+m = tensor_list.mask
+assert m is not None
+mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
out[name] = NestedTensor(x, mask)
return out

@@ -94,9 +97,9 @@ class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)

-def forward(self, tensor_list):
+def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
-out = []
+out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
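The new Dict[str, NestedTensor] and List[NestedTensor] annotations above are what TorchScript needs to type the containers: an empty dict or list literal is assumed to hold Tensor values, so any other element type must be spelled out. A small illustrative sketch of that rule (not code from this commit):

```python
from typing import Dict, List

import torch
from torch import nn


class Collector(nn.Module):
    def forward(self, xs: List[torch.Tensor]) -> Dict[str, List[torch.Tensor]]:
        # Without this annotation TorchScript assumes Dict[str, Tensor] and
        # rejects the assignment of a list value below.
        out: Dict[str, List[torch.Tensor]] = {}
        out["features"] = xs
        return out


scripted = torch.jit.script(Collector())  # compiles because the container type is explicit
```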
19 changes: 14 additions & 5 deletions models/detr.py
@@ -7,7 +7,8 @@
from torch import nn

from util import box_ops
-from util.misc import (NestedTensor, accuracy, get_world_size, interpolate,
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+accuracy, get_world_size, interpolate,
is_dist_avail_and_initialized)

from .backbone import build_backbone
@@ -56,20 +57,28 @@ def forward(self, samples: NestedTensor):
dictionnaries containing the two above keys for each decoder layer.
"""
if not isinstance(samples, NestedTensor):
-samples = NestedTensor.from_tensor_list(samples)
+samples = nested_tensor_from_tensor_list(samples)
features, pos = self.backbone(samples)

src, mask = features[-1].decompose()
+assert mask is not None
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

outputs_class = self.class_embed(hs)
outputs_coord = self.bbox_embed(hs).sigmoid()
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.aux_loss:
-out['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
-for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
return out

+@torch.jit.unused
+def _set_aux_loss(self, outputs_class, outputs_coord):
+# this is a workaround to make torchscript happy, as torchscript
+# doesn't support dictionary with non-homogeneous values, such
+# as a dict having both a Tensor and a list.
+return [{'pred_logits': a, 'pred_boxes': b}
+for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
@@ -164,7 +173,7 @@ def loss_masks(self, outputs, targets, indices, num_boxes):
src_masks = outputs["pred_masks"]

# TODO use valid to mask invalid areas due to padding in loss
-target_masks, valid = NestedTensor.from_tensor_list([t["masks"] for t in targets]).decompose()
+target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose()
target_masks = target_masks.to(src_masks)

src_masks = src_masks[src_idx]
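The _set_aux_loss helper added above is decorated with @torch.jit.unused because, as its comment says, TorchScript cannot type a dict whose values mix a Tensor with a list of dicts. A hedged sketch of the same pattern with made-up names (only the structure follows the change above):

```python
import torch
from torch import nn


class WithAux(nn.Module):
    def __init__(self, aux_loss: bool = False):
        super().__init__()
        self.aux_loss = aux_loss

    @torch.jit.unused
    def _aux_outputs(self, logits: torch.Tensor):
        # Eager-only helper: a list of dicts cannot live in the same scripted
        # dict as a plain Tensor, so the compiler is told to skip this body.
        return [{"pred_logits": a} for a in logits[:-1]]

    def forward(self, logits: torch.Tensor):
        out = {"pred_logits": logits[-1]}
        if self.aux_loss:
            out["aux_outputs"] = self._aux_outputs(logits)
        return out


# Expected to compile, mirroring the change above; the skipped helper raises
# only if the scripted branch is actually taken at runtime.
scripted = torch.jit.script(WithAux())
```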
7 changes: 5 additions & 2 deletions models/position_encoding.py
@@ -6,6 +6,8 @@
import torch
from torch import nn

+from util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
"""
@@ -23,9 +25,10 @@ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=N
scale = 2 * math.pi
self.scale = scale

-def forward(self, tensor_list):
+def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
mask = tensor_list.mask
+assert mask is not None
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
@@ -59,7 +62,7 @@ def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)

-def forward(self, tensor_list):
+def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
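The added assert mask is not None lines (here and in backbone.py) are for the compiler rather than for safety: NestedTensor.mask is now typed Optional[Tensor], and TorchScript only narrows an Optional to a plain Tensor after an explicit None check. A minimal sketch of that refinement rule (illustrative only):

```python
from typing import Optional

import torch
from torch import nn


class MaskedMean(nn.Module):
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        # The assert refines Optional[Tensor] to Tensor; without it the
        # compiler rejects the arithmetic below.
        assert mask is not None
        return (x * mask).sum() / mask.sum()


scripted = torch.jit.script(MaskedMean())
print(scripted(torch.rand(2, 3), torch.ones(2, 3)))
```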
4 changes: 2 additions & 2 deletions models/segmentation.py
@@ -11,7 +11,7 @@
from PIL import Image

import util.box_ops as box_ops
-from util.misc import NestedTensor, interpolate
+from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list

try:
from panopticapi.utils import id2rgb, rgb2id
@@ -34,7 +34,7 @@ def __init__(self, detr, freeze_detr=False):

def forward(self, samples: NestedTensor):
if not isinstance(samples, NestedTensor):
-samples = NestedTensor.from_tensor_list(samples)
+samples = nested_tensor_from_tensor_list(samples)
features, pos = self.detr.backbone(samples)

bs = features[-1].tensors.shape[0]
21 changes: 21 additions & 0 deletions test_all.py
@@ -4,7 +4,11 @@
import torch

from models.matcher import HungarianMatcher
+from models.position_encoding import PositionEmbeddingSine, PositionEmbeddingLearned
+from models.backbone import Backbone, Joiner, BackboneBase
from util import box_ops
+from util.misc import nested_tensor_from_tensor_list
+from hubconf import detr_resnet50


class Tester(unittest.TestCase):
@@ -47,6 +51,23 @@ def test_hungarian(self):
'pred_boxes': boxes.repeat(2, 1, 1)}, targets_empty * 2)
self.assertEqual(len(indices[0][0]), 0)

+def test_position_encoding_script(self):
+m1, m2 = PositionEmbeddingSine(), PositionEmbeddingLearned()
+mm1, mm2 = torch.jit.script(m1), torch.jit.script(m2) # noqa
+
+def test_backbone_script(self):
+backbone = Backbone('resnet50', True, False, False)
+torch.jit.script(backbone) # noqa
+
+def test_model_script(self):
+model = detr_resnet50(pretrained=False).eval()
+scripted_model = torch.jit.script(model)
+x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
+out = model(x)
+out_script = scripted_model(x)
+self.assertTrue(out["pred_logits"].equal(out_script["pred_logits"]))
+self.assertTrue(out["pred_boxes"].equal(out_script["pred_boxes"]))


if __name__ == '__main__':
unittest.main()
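Beyond the equality check in test_model_script, the practical payoff of scripting is serialization: the compiled module can be saved and reloaded without the DETR Python sources. A follow-on sketch (the file name is made up):

```python
import torch

from hubconf import detr_resnet50

# Compile the hub model and serialize it as a standalone TorchScript archive.
scripted = torch.jit.script(detr_resnet50(pretrained=False).eval())
torch.jit.save(scripted, "detr_resnet50_scripted.pt")

# The archive can later be loaded without importing the model code
# (or from C++ via torch::jit::load).
reloaded = torch.jit.load("detr_resnet50_scripted.pt")
```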
69 changes: 42 additions & 27 deletions util/misc.py
@@ -267,45 +267,60 @@ def _run(command):

def collate_fn(batch):
batch = list(zip(*batch))
-batch[0] = NestedTensor.from_tensor_list(batch[0])
+batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)


+def _max_by_axis(the_list):
+# type: (List[List[int]]) -> List[int]
+maxes = the_list[0]
+for sublist in the_list[1:]:
+for index, item in enumerate(sublist):
+maxes[index] = max(maxes[index], item)
+return maxes
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+# TODO make this more general
+if tensor_list[0].ndim == 3:
+# TODO make it support different-sized images
+max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+batch_shape = [len(tensor_list)] + max_size
+b, c, h, w = batch_shape
+dtype = tensor_list[0].dtype
+device = tensor_list[0].device
+tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+for img, pad_img, m in zip(tensor_list, tensor, mask):
+pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+m[: img.shape[1], :img.shape[2]] = False
+else:
+raise ValueError('not supported')
+return NestedTensor(tensor, mask)


class NestedTensor(object):
-def __init__(self, tensors, mask):
+def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask

-def to(self, *args, **kwargs):
-cast_tensor = self.tensors.to(*args, **kwargs)
-cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None
-return type(self)(cast_tensor, cast_mask)
+def to(self, device):
+# type: (Device) -> NestedTensor # noqa
+cast_tensor = self.tensors.to(device)
+mask = self.mask
+if mask is not None:
+assert mask is not None
+cast_mask = mask.to(device)
+else:
+cast_mask = None
+return NestedTensor(cast_tensor, cast_mask)

def decompose(self):
return self.tensors, self.mask

-@classmethod
-def from_tensor_list(cls, tensor_list):
-# TODO make this more general
-if tensor_list[0].ndim == 3:
-# TODO make it support different-sized images
-max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list]))
-# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
-batch_shape = (len(tensor_list),) + max_size
-b, c, h, w = batch_shape
-dtype = tensor_list[0].dtype
-device = tensor_list[0].device
-tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
-mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
-for img, pad_img, m in zip(tensor_list, tensor, mask):
-pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
-m[: img.shape[1], :img.shape[2]] = False
-else:
-raise ValueError('not supported')
-return cls(tensor, mask)

def __repr__(self):
-return repr(self.tensors)
+return str(self.tensors)


def setup_for_distributed(is_master):
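For reference, a quick sketch of what the relocated nested_tensor_from_tensor_list returns (illustrative, assuming the repository root is on the path): images are zero-padded to the largest height and width in the batch, and the boolean mask is True exactly on the padded pixels.

```python
import torch

from util.misc import nested_tensor_from_tensor_list

batch = [torch.rand(3, 200, 200), torch.rand(3, 180, 240)]
nt = nested_tensor_from_tensor_list(batch)

print(nt.tensors.shape)                    # torch.Size([2, 3, 200, 240]) -- padded to max H and W
print(nt.mask.shape)                       # torch.Size([2, 200, 240])
print(bool(nt.mask[1, :180, :240].any()))  # False: real pixels of the second image are unmasked
print(bool(nt.mask[1, 180:, :].all()))     # True: rows below its height are padding
```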
