Skip to content

Commit

Permalink
Bug fixes in ONNX model tracing (facebookresearch#173)
Browse files Browse the repository at this point in the history
* Bug fixes in ONNX export

* Add onnx onnxruntime as requirements

* Add unittest of detr panoptic model onnx export

* Update onnx opset version to 12 to support operator einsum in detr panoptic models

* Fix lint

* The precision of panoptic onnx exported model exceeded the margin of error
  • Loading branch information
zhiqwang authored Aug 3, 2020
1 parent 9db8be1 commit f4cdc54
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 0 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
command: |
pip install --user --progress-bar off scipy pytest
pip install --user --progress-bar off --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off onnx onnxruntime
pytest .
workflows:
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ torch>=1.5.0
torchvision>=0.6.0
git+https://github.com/cocodataset/panopticapi.git#egg=panopticapi
scipy
onnx
onnxruntime
74 changes: 74 additions & 0 deletions test_all.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import io
import unittest

import torch
Expand All @@ -10,6 +11,12 @@
from util.misc import nested_tensor_from_tensor_list
from hubconf import detr_resnet50, detr_resnet50_panoptic

# onnxruntime requires python 3.5 or above
try:
import onnxruntime
except ImportError:
onnxruntime = None


class Tester(unittest.TestCase):

Expand Down Expand Up @@ -94,5 +101,72 @@ def test_model_detection_different_inputs(self):
self.assertIn('pred_logits', out)


@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
class ONNXExporterTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
torch.manual_seed(123)

def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
output_names=None, input_names=None):
model.eval()

onnx_io = io.BytesIO()
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=do_constant_folding, opset_version=12,
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
with torch.no_grad():
if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list):
test_inputs = (nested_tensor_from_tensor_list(test_inputs),)
test_ouputs = model(*test_inputs)
if isinstance(test_ouputs, torch.Tensor):
test_ouputs = (test_ouputs,)
self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)

def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):

inputs, _ = torch.jit._flatten(inputs)
outputs, _ = torch.jit._flatten(outputs)

def to_numpy(tensor):
if tensor.requires_grad:
return tensor.detach().cpu().numpy()
else:
return tensor.cpu().numpy()

inputs = list(map(to_numpy, inputs))
outputs = list(map(to_numpy, outputs))

ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
# compute onnxruntime output prediction
ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
ort_outs = ort_session.run(None, ort_inputs)
for i in range(0, len(outputs)):
try:
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
except AssertionError as error:
if tolerate_small_mismatch:
self.assertIn("(0.00%)", str(error), str(error))
else:
raise

def test_model_onnx_detection(self):
model = detr_resnet50(pretrained=False).eval()
dummy_image = torch.ones(1, 3, 800, 800) * 0.3
model(dummy_image)

# Test exported model on images of different size, or dummy input
self.run_model(
model,
[(torch.rand(1, 3, 750, 800),)],
input_names=["inputs"],
output_names=["pred_logits", "pred_boxes"],
tolerate_small_mismatch=True,
)


if __name__ == '__main__':
unittest.main()
36 changes: 36 additions & 0 deletions util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,11 @@ def _max_by_axis(the_list):
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)

# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
Expand All @@ -300,6 +305,37 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
return NestedTensor(tensor, mask)


# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list):
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)

# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)

m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))

tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)

return NestedTensor(tensor, mask=mask)


class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
Expand Down

0 comments on commit f4cdc54

Please sign in to comment.