From 5e66b4cd15b2b182da347103dd16578d28b49d69 Mon Sep 17 00:00:00 2001 From: Paul Willot Date: Wed, 19 Aug 2020 18:46:29 +0900 Subject: [PATCH] Fix ONNX export for panotpic model (#180) * Fix ONNX export for panotpic model The operation `.view_as(an_other_tensor)` is not supported by ONNX, `view(an_other_tensor.size())` is identical and supported. This fix allow to export panoptic models. * Update test_all.py Add onnx export test for panoptic model * fix lint * fix lint * Skip test on OOM CI error --- models/segmentation.py | 2 +- test_all.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/models/segmentation.py b/models/segmentation.py index 7d4a0c6fd..a38b7ac33 100644 --- a/models/segmentation.py +++ b/models/segmentation.py @@ -164,7 +164,7 @@ def forward(self, q, k, mask: Optional[Tensor] = None): if mask is not None: weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) - weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size()) weights = self.dropout(weights) return weights diff --git a/test_all.py b/test_all.py index d2847b180..a6ae2a61a 100644 --- a/test_all.py +++ b/test_all.py @@ -167,6 +167,21 @@ def test_model_onnx_detection(self): tolerate_small_mismatch=True, ) + @unittest.skip("CI doesn't have enough memory") + def test_model_onnx_detection_panoptic(self): + model = detr_resnet50_panoptic(pretrained=False).eval() + dummy_image = torch.ones(1, 3, 800, 800) * 0.3 + model(dummy_image) + + # Test exported model on images of different size, or dummy input + self.run_model( + model, + [(torch.rand(1, 3, 750, 800),)], + input_names=["inputs"], + output_names=["pred_logits", "pred_boxes", "pred_masks"], + tolerate_small_mismatch=True, + ) + if __name__ == '__main__': unittest.main()