From a5cd934faa64a4977d10ab86010d4ad9094c0a8f Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Thu, 4 Jun 2020 21:56:20 +0200
Subject: [PATCH] Add torchscript support for hub detection models (#51)

* Add torchscript support for hub detection models

* Dummy commit to test CI

* Try trigger CI

* Fix lint

* Update CI to use nightlies
---
 .circleci/config.yml        |  3 +-
 models/backbone.py          | 15 ++++----
 models/detr.py              | 19 +++++++---
 models/position_encoding.py |  7 ++--
 models/segmentation.py      |  4 +--
 test_all.py                 | 21 +++++++++++
 util/misc.py                | 69 ++++++++++++++++++++++---------------
 7 files changed, 95 insertions(+), 43 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index edfc1d8d3..96857a478 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -18,7 +18,8 @@ jobs:
       - checkout
       - run:
           command: |
-            pip install --user --progress-bar off torch torchvision scipy pytest
+            pip install --user --progress-bar off scipy pytest
+            pip install --user --progress-bar off --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
             pytest .
 
 workflows:
diff --git a/models/backbone.py b/models/backbone.py
index fcba172b0..52e32964d 100644
--- a/models/backbone.py
+++ b/models/backbone.py
@@ -9,6 +9,7 @@
 import torchvision
 from torch import nn
 from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
 
 from util.misc import NestedTensor
 
@@ -64,15 +65,17 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int,
         if return_interm_layers:
             return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
         else:
-            return_layers = {'layer4': 0}
+            return_layers = {'layer4': "0"}
         self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
         self.num_channels = num_channels
 
-    def forward(self, tensor_list):
+    def forward(self, tensor_list: NestedTensor):
         xs = self.body(tensor_list.tensors)
-        out = OrderedDict()
+        out: Dict[str, NestedTensor] = {}
         for name, x in xs.items():
-            mask = F.interpolate(tensor_list.mask[None].float(), size=x.shape[-2:]).bool()[0]
+            m = tensor_list.mask
+            assert m is not None
+            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
             out[name] = NestedTensor(x, mask)
         return out
 
@@ -94,9 +97,9 @@ class Joiner(nn.Sequential):
     def __init__(self, backbone, position_embedding):
         super().__init__(backbone, position_embedding)
 
-    def forward(self, tensor_list):
+    def forward(self, tensor_list: NestedTensor):
         xs = self[0](tensor_list)
-        out = []
+        out: List[NestedTensor] = []
         pos = []
         for name, x in xs.items():
             out.append(x)
diff --git a/models/detr.py b/models/detr.py
index 2471da646..c63f5d739 100644
--- a/models/detr.py
+++ b/models/detr.py
@@ -7,7 +7,8 @@
 from torch import nn
 
 from util import box_ops
-from util.misc import (NestedTensor, accuracy, get_world_size, interpolate,
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+                       accuracy, get_world_size, interpolate,
                        is_dist_avail_and_initialized)
 
 from .backbone import build_backbone
@@ -56,20 +57,28 @@ def forward(self, samples: NestedTensor):
                                 dictionnaries containing the two above keys for each decoder layer.
""" if not isinstance(samples, NestedTensor): - samples = NestedTensor.from_tensor_list(samples) + samples = nested_tensor_from_tensor_list(samples) features, pos = self.backbone(samples) src, mask = features[-1].decompose() + assert mask is not None hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] outputs_class = self.class_embed(hs) outputs_coord = self.bbox_embed(hs).sigmoid() out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} if self.aux_loss: - out['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b} - for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) return out + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + class SetCriterion(nn.Module): """ This class computes the loss for DETR. @@ -164,7 +173,7 @@ def loss_masks(self, outputs, targets, indices, num_boxes): src_masks = outputs["pred_masks"] # TODO use valid to mask invalid areas due to padding in loss - target_masks, valid = NestedTensor.from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() target_masks = target_masks.to(src_masks) src_masks = src_masks[src_idx] diff --git a/models/position_encoding.py b/models/position_encoding.py index e2997af54..73ae39edf 100644 --- a/models/position_encoding.py +++ b/models/position_encoding.py @@ -6,6 +6,8 @@ import torch from torch import nn +from util.misc import NestedTensor + class PositionEmbeddingSine(nn.Module): """ @@ -23,9 +25,10 @@ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=N scale = 2 * math.pi self.scale = scale - def forward(self, tensor_list): + def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors mask = tensor_list.mask + assert mask is not None not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) @@ -59,7 +62,7 @@ def reset_parameters(self): nn.init.uniform_(self.row_embed.weight) nn.init.uniform_(self.col_embed.weight) - def forward(self, tensor_list): + def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors h, w = x.shape[-2:] i = torch.arange(w, device=x.device) diff --git a/models/segmentation.py b/models/segmentation.py index 589a58ea4..970b1fa48 100644 --- a/models/segmentation.py +++ b/models/segmentation.py @@ -11,7 +11,7 @@ from PIL import Image import util.box_ops as box_ops -from util.misc import NestedTensor, interpolate +from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list try: from panopticapi.utils import id2rgb, rgb2id @@ -34,7 +34,7 @@ def __init__(self, detr, freeze_detr=False): def forward(self, samples: NestedTensor): if not isinstance(samples, NestedTensor): - samples = NestedTensor.from_tensor_list(samples) + samples = nested_tensor_from_tensor_list(samples) features, pos = self.detr.backbone(samples) bs = features[-1].tensors.shape[0] diff --git a/test_all.py b/test_all.py index d6135a0b6..77c2a72cc 100644 --- a/test_all.py +++ b/test_all.py @@ -4,7 +4,11 @@ import torch from models.matcher import HungarianMatcher +from models.position_encoding 
+from models.backbone import Backbone, Joiner, BackboneBase
 from util import box_ops
+from util.misc import nested_tensor_from_tensor_list
+from hubconf import detr_resnet50
 
 
 class Tester(unittest.TestCase):
@@ -47,6 +51,23 @@ def test_hungarian(self):
                                  'pred_boxes': boxes.repeat(2, 1, 1)}, targets_empty * 2)
         self.assertEqual(len(indices[0][0]), 0)
 
+    def test_position_encoding_script(self):
+        m1, m2 = PositionEmbeddingSine(), PositionEmbeddingLearned()
+        mm1, mm2 = torch.jit.script(m1), torch.jit.script(m2)  # noqa
+
+    def test_backbone_script(self):
+        backbone = Backbone('resnet50', True, False, False)
+        torch.jit.script(backbone)  # noqa
+
+    def test_model_script(self):
+        model = detr_resnet50(pretrained=False).eval()
+        scripted_model = torch.jit.script(model)
+        x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])
+        out = model(x)
+        out_script = scripted_model(x)
+        self.assertTrue(out["pred_logits"].equal(out_script["pred_logits"]))
+        self.assertTrue(out["pred_boxes"].equal(out_script["pred_boxes"]))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/util/misc.py b/util/misc.py
index 45d055d91..46437d697 100644
--- a/util/misc.py
+++ b/util/misc.py
@@ -267,45 +267,60 @@ def _run(command):
 
 def collate_fn(batch):
     batch = list(zip(*batch))
-    batch[0] = NestedTensor.from_tensor_list(batch[0])
+    batch[0] = nested_tensor_from_tensor_list(batch[0])
     return tuple(batch)
 
 
+def _max_by_axis(the_list):
+    # type: (List[List[int]]) -> List[int]
+    maxes = the_list[0]
+    for sublist in the_list[1:]:
+        for index, item in enumerate(sublist):
+            maxes[index] = max(maxes[index], item)
+    return maxes
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    # TODO make this more general
+    if tensor_list[0].ndim == 3:
+        # TODO make it support different-sized images
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], :img.shape[2]] = False
+    else:
+        raise ValueError('not supported')
+    return NestedTensor(tensor, mask)
+
+
 class NestedTensor(object):
-    def __init__(self, tensors, mask):
+    def __init__(self, tensors, mask: Optional[Tensor]):
         self.tensors = tensors
         self.mask = mask
 
-    def to(self, *args, **kwargs):
-        cast_tensor = self.tensors.to(*args, **kwargs)
-        cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None
-        return type(self)(cast_tensor, cast_mask)
+    def to(self, device):
+        # type: (Device) -> NestedTensor # noqa
+        cast_tensor = self.tensors.to(device)
+        mask = self.mask
+        if mask is not None:
+            assert mask is not None
+            cast_mask = mask.to(device)
+        else:
+            cast_mask = None
+        return NestedTensor(cast_tensor, cast_mask)
 
     def decompose(self):
         return self.tensors, self.mask
 
-    @classmethod
-    def from_tensor_list(cls, tensor_list):
-        # TODO make this more general
-        if tensor_list[0].ndim == 3:
-            # TODO make it support different-sized images
-            max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list]))
-            # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
-            batch_shape = (len(tensor_list),) + max_size
-            b, c, h, w = batch_shape
-            dtype = tensor_list[0].dtype
-            device = tensor_list[0].device
-            tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
-            mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
-            for img, pad_img, m in zip(tensor_list, tensor, mask):
-                pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
-                m[: img.shape[1], :img.shape[2]] = False
-        else:
-            raise ValueError('not supported')
-        return cls(tensor, mask)
-
     def __repr__(self):
-        return repr(self.tensors)
+        return str(self.tensors)
 
 
 def setup_for_distributed(is_master):
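
-- 

A few illustrative notes follow; they are review commentary, not part of the diff.

The end-to-end workflow this patch enables can be sketched as below. It mirrors test_model_script above and assumes only what the new test itself uses: the detr_resnet50 entry point from hubconf.py and nested_tensor_from_tensor_list from util.misc.

    import torch

    from hubconf import detr_resnet50
    from util.misc import nested_tensor_from_tensor_list

    # Build the hub model and compile it with torchscript; before this
    # patch, torch.jit.script(model) fails on the untyped NestedTensor
    # plumbing and on the heterogeneous aux_outputs entry.
    model = detr_resnet50(pretrained=False).eval()
    scripted = torch.jit.script(model)

    # A batch of two differently sized images, padded to a common shape
    # and tracked with a boolean padding mask.
    x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200),
                                        torch.rand(3, 200, 250)])
    out = scripted(x)
    print(out["pred_logits"].shape, out["pred_boxes"].shape)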
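
nested_tensor_from_tensor_list, now a free function rather than a classmethod so torchscript can compile it, pads every image to the per-axis maximum of the batch and records validity in an inverted mask: False marks real pixels, True marks padding. A minimal illustration of that contract, reusing the shapes from the test above:

    import torch
    from util.misc import nested_tensor_from_tensor_list

    imgs = [torch.rand(3, 200, 200), torch.rand(3, 200, 250)]
    tensor, mask = nested_tensor_from_tensor_list(imgs).decompose()

    assert tensor.shape == (2, 3, 200, 250)  # padded to max H and max W
    assert mask.shape == (2, 200, 250)
    assert not mask[0, :200, :200].any()     # image content -> False
    assert mask[0, :, 200:].all()            # padded columns of img 0 -> True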
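
The @torch.jit.unused decorator on _set_aux_loss is the key workaround: the compiler replaces the decorated method with a raising stub instead of trying to type its list-of-dicts return value, so scripting succeeds whenever the gating flag is False in the scripted model (as it appears to be for the hub models the test exercises, since DETR's aux_loss defaults to False). A distilled sketch of the pattern, with hypothetical names (Toy, _extras):

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self, extras: bool = False):
            super().__init__()
            self.extras = extras

        def forward(self, x: torch.Tensor):
            out = {'y': x.sum()}
            if self.extras:
                # In scripted code this call compiles to a stub that raises
                # if it is ever reached; compilation itself succeeds.
                out['extras_list'] = self._extras(x)
            return out

        @torch.jit.unused
        def _extras(self, x):
            # A value torchscript cannot fit into Dict[str, Tensor]:
            # a list of dictionaries.
            return [{'y': x.mean()}]

    scripted = torch.jit.script(Toy(extras=False))  # compiles cleanly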