diff --git a/test/onnx/symbolic_opsets/test_symbolic_opset9.py b/test/onnx/symbolic_opsets/test_symbolic_opset9.py deleted file mode 100644 index a8334838115ca..0000000000000 --- a/test/onnx/symbolic_opsets/test_symbolic_opset9.py +++ /dev/null @@ -1,32 +0,0 @@ -# Owner(s): ["module: onnx"] - -"""Tests for `torch.onnx.symbolic_opset9`.""" -import torch -from torch import _C -from torch.onnx import symbolic_opset9 as opset9 -from torch.testing._internal import common_utils - - -class TestPrim(common_utils.TestCase): - def setUp(self): - super().setUp() - self.graph = _C.Graph() - - def test_list_unpack_returns_all_list_elements_when_previous_node_is_list_construct( - self, - ): - # Build the graph - input_1 = self.graph.addInput() - input_1.setType(input_1.type().with_dtype(torch.float).with_sizes([2, 42])) - input_2 = self.graph.addInput() - input_2.setType(input_2.type().with_dtype(torch.float).with_sizes([3, 42])) - constructed_list = self.graph.op("prim::ListConstruct", input_1, input_2) - # Test the op - outputs = opset9.Prim.ListUnpack(self.graph, constructed_list) - self.assertNotEqual(outputs, None) - self.assertEqual(outputs[0], input_1) - self.assertEqual(outputs[1], input_2) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/torch/onnx/README.md b/torch/onnx/README.md index cb190ba1e496e..80a282c037de3 100644 --- a/torch/onnx/README.md +++ b/torch/onnx/README.md @@ -5,3 +5,21 @@ Torch->ONNX converter / exporter. [User-facing docs](https://pytorch.org/docs/master/onnx.html). [Developer docs](https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter). + +## Symbolic function opsets + +Opset 9 is the base version. It was selected as the base version because: + +1. It is the first opset version supported by PyTorch export. +2. Opset 9 is more robust than earlier opset versions. Opset versions such as 7 and 8 have limitations: + certain basic operators cannot be expressed in ONNX. Rather than building on those limitations, + we handle them separately as special cases. + +Adding support for opset versions older than opset 7 is not on our roadmap. + +Opset versions other than 9 inherit, by default, the symbolic functions defined in +symbolic_opset9.py. + +To extend support for operators that change in other opset versions (which build on top of opset 9), +add the updated symbolic functions to the corresponding symbolic_opset{version}.py file. +See `topk` in symbolic_opset10.py and `upsample_nearest2d` in symbolic_opset8.py for examples. diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 351fd342cfc9a..6e5424dcfd2ba 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,5 +1,4 @@ """ONNX exporter.""" -import warnings from torch import _C from torch._C import _onnx as _C_onnx @@ -26,6 +25,7 @@ symbolic_opset14, symbolic_opset15, symbolic_opset16, + symbolic_opset17, utils, ) from ._exporter_states import ExportTypes, SymbolicContext @@ -60,6 +60,7 @@ "symbolic_opset14", "symbolic_opset15", "symbolic_opset16", + "symbolic_opset17", # Enums "ExportTypes", "OperatorExportTypes", @@ -133,6 +134,3 @@ def log(*args) -> None: character appended to the end, and flushed to output stream.
""" _C._jit_onnx_log(*args) - - -_registration.discover_and_register_all_symbolic_opsets() diff --git a/torch/onnx/_internal/registration.py b/torch/onnx/_internal/registration.py index 03c0e192c1c25..e2c2cd3dcd749 100644 --- a/torch/onnx/_internal/registration.py +++ b/torch/onnx/_internal/registration.py @@ -1,7 +1,5 @@ """Module for handling symbolic function registration.""" -import importlib -import inspect import warnings from typing import ( Callable, @@ -265,49 +263,6 @@ def all_functions(self) -> Set[str]: return set(self._registry) -def discover_and_register_all_symbolic_opsets() -> None: - """Discover all symbolic functions. - Opset 9 is the base version. It is selected as the base version because - 1. It is the first opset version supported by PyTorch export. - 2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations - that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations, - we chose to handle them as special cases separately. - - Backward support for opset versions beyond opset 7 is not in our roadmap. - For opset versions other than 9, by default they will inherit the symbolic functions defined in - symbolic_opset9.py. - - To extend support for updated operators in different opset versions on top of opset 9, - simply add the updated symbolic functions in the respective symbolic_opset{version}.py file. - Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. - """ - for opset in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1): - module = importlib.import_module(f"torch.onnx.symbolic_opset{opset}") - _register_module(module, opset) - - -def _register_module(module, opset: OpsetVersion) -> None: - """Registers all functions in the given module. - - Args: - module: The module to register. - opset: The opset version to register. 
- """ - global registry - members = inspect.getmembers(module) - for name, obj in members: - if isinstance(obj, type) and hasattr(obj, "domain"): - # Symbolic functions in domains other than aten - ops = inspect.getmembers(obj, predicate=inspect.isfunction) - for op in ops: - registry.register(f"{obj.domain}::{op[0]}", opset, op[1]) # type: ignore[attr-defined] - - elif inspect.isfunction(obj): - if name in {"_len", "_list", "_any", "_all"}: - name = name[1:] - registry.register(f"aten::{name}", opset, obj) - - @_beartype.beartype def onnx_symbolic( name: str, diff --git a/torch/onnx/symbolic_caffe2.py b/torch/onnx/symbolic_caffe2.py index cb1c48a580e99..feeb566af132e 100644 --- a/torch/onnx/symbolic_caffe2.py +++ b/torch/onnx/symbolic_caffe2.py @@ -176,7 +176,7 @@ def upsample_nearest2d( g, input, output_size, align_corners=None, scales_h=None, scales_w=None ): if input not in symbolic_helper._quantized_ops: - return opset9.upsample_nearest2d(g, input, output_size, align_corners) + return opset9.upsample_nearest2d(g, input, output_size, align_corners) # type: ignore[attr-defined] output_size = symbolic_helper._parse_arg(output_size, "is") kwargs = { @@ -194,7 +194,7 @@ def upsample_nearest2d( @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode): if input not in symbolic_helper._quantized_ops: - return opset9.max_pool2d( + return opset9.max_pool2d( # type: ignore[attr-defined] g, input, kernel_size, stride, padding, dilation, ceil_mode ) kwargs = { @@ -224,7 +224,7 @@ def avg_pool2d( divisor_override=None, ): if input not in symbolic_helper._quantized_ops: - return opset9.avg_pool2d( + return opset9.avg_pool2d( # type: ignore[attr-defined] g, input, kernel_size, diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 46b6f089e13b3..e3086ca65e7e6 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,6 +1,7 @@ +import functools import sys import warnings -from typing import Sequence +from typing import Callable, Sequence import torch import torch._C._onnx as _C_onnx @@ -16,7 +17,7 @@ symbolic_opset9 as opset9, ) from torch.onnx._globals import GLOBALS -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration # EDITING THIS FILE? READ THIS FIRST! 
# see Note [Edit Symbolic Files] in symbolic_helper.py @@ -27,9 +28,6 @@ __all__ = [ - "avg_pool1d", - "avg_pool2d", - "avg_pool3d", "dequantize", "div", "embedding_bag", @@ -38,27 +36,40 @@ "fmod", "isfinite", "isinf", - "max_pool1d_with_indices", - "max_pool1d", - "max_pool2d_with_indices", - "max_pool2d", - "max_pool3d_with_indices", - "max_pool3d", "nan_to_num", "quantize_per_tensor", - "Quantized", + "quantized_add_relu", + "quantized_add", + "quantized_cat", + "quantized_conv2d_relu", + "quantized_conv2d", + "quantized_group_norm", + "quantized_hardswish", + "quantized_instance_norm", + "quantized_layer_norm", + "quantized_leaky_relu", + "quantized_linear", + "quantized_mul", + "quantized_sigmoid", "slice", "sort", "topk", - "upsample_bilinear2d", - "upsample_linear1d", - "upsample_nearest1d", - "upsample_nearest2d", - "upsample_nearest3d", - "upsample_trilinear3d", ] +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) + + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +@_onnx_symbolic("aten::div") @_beartype.beartype def div(g, self, other, *args): if len(args) == 0: @@ -76,6 +87,7 @@ def _div_rounding_mode(g, self, other, rounding_mode): return opset9._div_rounding_mode(g, self, other, rounding_mode) +@_onnx_symbolic("aten::_floor_divide") @_beartype.beartype def _floor_divide(g, self, other): if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): @@ -97,12 +109,14 @@ def _floor_divide(g, self, other): return g.op("Where", fixup_mask, fixup, div) +@_onnx_symbolic("aten::sort") @symbolic_helper.parse_args("v", "i", "i", "none") @_beartype.beartype def sort(g, self, dim, decending, out=None): return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) +@_onnx_symbolic("aten::topk") @symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") @_beartype.beartype def topk(g, self, k, dim, largest, sorted, out=None): @@ -111,11 +125,67 @@ def topk(g, self, k, dim, largest, sorted, out=None): ) +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[ + _apply_params( + "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False + ) + ], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[ + _apply_params( + "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False + ) + ], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[ + _apply_params( + "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False + ) + ], +) +@_onnx_symbolic( + "aten::max_pool1d_with_indices", + decorate=[ + _apply_params( + "max_pool1d_with_indices", + torch.nn.modules.utils._single, + 1, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool2d_with_indices", + decorate=[ + _apply_params( + "max_pool2d_with_indices", + torch.nn.modules.utils._pair, + 2, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool3d_with_indices", + decorate=[ + _apply_params( + "max_pool3d_with_indices", + torch.nn.modules.utils._triple, + 3, + return_indices=True, + ) + ], +) @_beartype.beartype -def _max_pool(name, tuple_fn, ndims, return_indices): +def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool): @symbolic_helper.quantized_args(True, False, False, False, False, False) @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") - @_beartype.beartype def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): 
if not stride: stride = kernel_size @@ -166,26 +236,18 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): return symbolic_fn -max_pool1d = _max_pool( - "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False -) -max_pool2d = _max_pool( - "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False -) -max_pool3d = _max_pool( - "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)], ) -max_pool1d_with_indices = _max_pool( - "max_pool1d_with_indices", torch.nn.modules.utils._single, 1, return_indices=True +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)], ) -max_pool2d_with_indices = _max_pool( - "max_pool2d_with_indices", torch.nn.modules.utils._pair, 2, return_indices=True +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)], ) -max_pool3d_with_indices = _max_pool( - "max_pool3d_with_indices", torch.nn.modules.utils._triple, 3, return_indices=True -) - - @_beartype.beartype def _avg_pool(name, tuple_fn): @symbolic_helper.quantized_args(True, False, False, False, False, False, False) @@ -208,7 +270,7 @@ def symbolic_fn( ) assert isinstance(padding, tuple) if count_include_pad: - input = opset9.op_with_optional_float_cast( + input = opset9._op_with_optional_float_cast( g, "Pad", input, @@ -231,11 +293,30 @@ def symbolic_fn( return symbolic_fn -avg_pool1d = _avg_pool("avg_pool1d", torch.nn.modules.utils._single) -avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair) -avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple) - - +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[_apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[_apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[_apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[_apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[_apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[_apply_params("upsample_trilinear3d", 5, "linear")], +) @_beartype.beartype def _interpolate(name, dim, interpolate_mode): @symbolic_helper.quantized_args(True, False, False) @@ -257,14 +338,7 @@ def symbolic_fn(g, input, output_size, *args): return symbolic_fn -upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest") -upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest") -upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest") -upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear") -upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear") -upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear") - - +@_onnx_symbolic("aten::__interpolate") @_beartype.beartype def __interpolate( g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias @@ -303,6 +377,7 @@ def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False): return g.op("Slice", input, starts, ends, axes, steps) +@_onnx_symbolic("aten::slice") @_beartype.beartype def slice(g, self, *args): if len(args) == 4: @@ -351,6 +426,7 @@ def slice(g, self, 
*args): ) +@_onnx_symbolic("aten::flip") @symbolic_helper.parse_args("v", "is") @_beartype.beartype def flip(g, input, dims): @@ -364,11 +440,13 @@ def flip(g, input, dims): ) +@_onnx_symbolic("aten::fmod") @_beartype.beartype def fmod(g, input, other): return g.op("Mod", input, other, fmod_i=1) +@_onnx_symbolic("aten::embedding_bag") @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") @_beartype.beartype def embedding_bag( @@ -455,6 +533,7 @@ def embedding_bag( ) +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") @symbolic_helper.parse_args("v", "v", "v", "i", "i") @_beartype.beartype def fake_quantize_per_tensor_affine( @@ -498,20 +577,21 @@ def fake_quantize_per_tensor_affine( ) +@_onnx_symbolic("aten::isinf") @_beartype.beartype def isinf(g, input): return g.op("IsInf", opset9._cast_Double(g, input, False)) # type: ignore[attr-defined] +@_onnx_symbolic("aten::isfinite") @_beartype.beartype def isfinite(g, input): - from torch.onnx.symbolic_opset9 import __not_, __or_ - inf_node = isinf(g, input) nan_node = opset9.isnan(g, input) - return __not_(g, __or_(g, inf_node, nan_node)) + return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) +@_onnx_symbolic("aten::quantize_per_tensor") @_beartype.beartype def quantize_per_tensor(g, input, scale, zero_point, dtype): dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -523,11 +603,13 @@ def quantize_per_tensor(g, input, scale, zero_point, dtype): return symbolic_helper.quantize_helper(g, input, scale, zero_point) +@_onnx_symbolic("aten::dequantize") @_beartype.beartype def dequantize(g, input): return symbolic_helper.dequantize_helper(g, input)[0] +@_onnx_symbolic("aten::nan_to_num") @symbolic_helper.parse_args("v", "f", "f", "f") @_beartype.beartype def nan_to_num(g, input, nan, posinf, neginf): @@ -580,194 +662,192 @@ def nan_to_num(g, input, nan, posinf, neginf): ) +# Quantized symbolics --------------------------------------------------------- # https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export -class Quantized: - """ - https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export +# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were +# introduced in opset version 10. +@_onnx_symbolic("quantized::linear") +@_beartype.beartype +def quantized_linear(g, q_input, q_weight, bias, op_scale, op_zero_point): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were introduced in opset version 10. 
- """ + output = opset9.linear(g, input, weight, bias) - domain = "quantized" + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def linear(g, q_input, q_weight, bias, op_scale, op_zero_point): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - output = opset9.linear(g, input, weight, bias) +@_onnx_symbolic("quantized::add") +@_beartype.beartype +def quantized_add(g, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + output = opset9.add(g, x, y) - @staticmethod - @_beartype.beartype - def add(g, x, y, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - output = opset9.add(g, x, y) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +@_onnx_symbolic("quantized::add_relu") +@_beartype.beartype +def quantized_add_relu(g, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) - @staticmethod - @_beartype.beartype - def add_relu(g, x, y, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + output = opset9.add(g, x, y) + output = opset9.relu(g, output) - output = opset9.add(g, x, y) - output = opset9.relu(g, output) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def mul(g, x, y, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - y, _, _, _ = symbolic_helper.dequantize_helper(g, y) +@_onnx_symbolic("quantized::mul") +@_beartype.beartype +def quantized_mul(g, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) - output = opset9.mul(g, x, y) + output = opset9.mul(g, x, y) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def hardswish(g, x, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - output = opset9.hardswish(g, x) +@_onnx_symbolic("quantized::hardswish") +@_beartype.beartype +def quantized_hardswish(g, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + output = opset9.hardswish(g, x) - @staticmethod - @_beartype.beartype - def sigmoid(g, x, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - output = opset9.sigmoid(g, x) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +@_onnx_symbolic("quantized::sigmoid") +@_beartype.beartype +def quantized_sigmoid(g, x, op_scale, 
op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - @staticmethod - @_beartype.beartype - def leaky_relu(g, x, negative_slope, inplace, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + output = opset9.sigmoid(g, x) - output = opset9.leaky_relu(g, x, negative_slope, inplace) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def layer_norm(g, x, normalized_shape, weight, bias, eps, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) +@_onnx_symbolic("quantized::leaky_relu") +@_beartype.beartype +def quantized_leaky_relu(g, x, negative_slope, inplace, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) + output = opset9.leaky_relu(g, x, negative_slope, inplace) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def group_norm(g, x, num_groups, weight, bias, eps, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) +@_onnx_symbolic("quantized::layer_norm") +@_beartype.beartype +def quantized_layer_norm( + g, x, normalized_shape, weight, bias, eps, op_scale, op_zero_point +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) - @staticmethod - @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") - @_beartype.beartype - def instance_norm( - g, - q_input, - weight, - bias, - eps, - op_scale, - op_zero_point, - ): - input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - output = opset9.instance_norm( - g, input, weight, bias, None, None, False, 0.0, eps, False - ) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +@_onnx_symbolic("quantized::group_norm") +@_beartype.beartype +def quantized_group_norm(g, x, num_groups, weight, bias, eps, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - @staticmethod - @_beartype.beartype - def conv2d_relu( - g, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, - ): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) - output = opset9.conv2d( - g, input, weight, bias, stride, padding, dilation, groups - ) - output = opset9.relu(g, output) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def conv2d( - g, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, - ): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, 
q_input) - weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) +@_onnx_symbolic("quantized::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") +@_beartype.beartype +def quantized_instance_norm( + g, + q_input, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) - output = opset9.conv2d( - g, input, weight, bias, stride, padding, dilation, groups - ) + output = opset9.instance_norm( + g, input, weight, bias, None, None, False, 0.0, eps, False + ) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @symbolic_helper.parse_args("v", "i", "v", "v") - @_beartype.beartype - def cat( - g, - q_inputs: _C.Value, - dim: int, - op_scale: _C.Value, - op_zero_point: _C.Value, - ) -> _C.Value: - unpacked_inputs = symbolic_helper._unpack_list(q_inputs) - dequantized = [ - symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs - ] - concatenated = g.op("Concat", *dequantized, axis_i=dim) - return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) + +@_onnx_symbolic("quantized::conv2d_relu") +@_beartype.beartype +def quantized_conv2d_relu( + g, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +@_beartype.beartype +def quantized_conv2d( + g, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::cat") +@symbolic_helper.parse_args("v", "i", "v", "v") +@_beartype.beartype +def quantized_cat( + g, + q_inputs: _C.Value, + dim: int, + op_scale: _C.Value, + op_zero_point: _C.Value, +) -> _C.Value: + unpacked_inputs = symbolic_helper._unpack_list(q_inputs) + dequantized = [ + symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs + ] + concatenated = g.op("Concat", *dequantized, axis_i=dim) + return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index cb43a54cb6e99..6ec60bdbf8de4 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -1,5 +1,6 @@ """This file exports ONNX ops for opset 11.""" +import 
functools import sys import warnings from typing import Optional, Sequence, Union @@ -16,7 +17,7 @@ utils, ) from torch.onnx._globals import GLOBALS -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py @@ -26,9 +27,6 @@ "append", "arange", "argsort", - "avg_pool1d", - "avg_pool2d", - "avg_pool3d", "cat", "chunk", "clamp_max", @@ -59,17 +57,11 @@ "pad", "pixel_shuffle", "pop", - "Prim", + "prim_constant_chunk", "reflection_pad", - "reflection_pad1d", - "reflection_pad2d", - "reflection_pad3d", "relu6", "remainder", "replication_pad", - "replication_pad1d", - "replication_pad2d", - "replication_pad3d", "round", "scatter", "select", @@ -83,16 +75,21 @@ "unbind", "unique_dim", "unsqueeze", - "upsample_bicubic2d", - "upsample_bilinear2d", - "upsample_linear1d", - "upsample_nearest1d", - "upsample_nearest2d", - "upsample_nearest3d", - "upsample_trilinear3d", ] +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +@_onnx_symbolic("aten::hardtanh") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "f", "f") @_beartype.beartype @@ -110,11 +107,12 @@ def hardtanh(g, self: _C.Value, min_val: float, max_val: float): "Constant", value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), ) - return opset9.op_with_optional_float_cast( + return opset9._op_with_optional_float_cast( g, "Clip", self, min_val, max_val, opset_before=12 ) +@_onnx_symbolic("aten::clamp") @_beartype.beartype def clamp(g, self, min, max): dtype = self.type().scalarType() @@ -143,13 +141,14 @@ def _cast_if_not_none(tensor, dtype): symbolic_helper._get_tensor_rank(min) == 0 and symbolic_helper._get_tensor_rank(max) == 0 ): - return opset9.op_with_optional_float_cast( + return opset9._op_with_optional_float_cast( g, "Clip", self, min, max, opset_before=12 ) else: return clamp_max(g, clamp_min(g, self, min), max) +@_onnx_symbolic("aten::clamp_min") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def clamp_min(g, self, min): @@ -157,13 +156,14 @@ def clamp_min(g, self, min): min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type()) if symbolic_helper._get_tensor_rank(min) == 0: max = opset9.unused(g) - return opset9.op_with_optional_float_cast( + return opset9._op_with_optional_float_cast( g, "Clip", self, min, max, opset_before=12 ) else: - return opset9.op_with_optional_float_cast(g, "Max", self, min, opset_before=12) + return opset9._op_with_optional_float_cast(g, "Max", self, min, opset_before=12) +@_onnx_symbolic("aten::clamp_max") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def clamp_max(g, self, max): @@ -171,16 +171,17 @@ def clamp_max(g, self, max): max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type()) if symbolic_helper._get_tensor_rank(max) == 0: min = opset9.unused(g) - return opset9.op_with_optional_float_cast( + return opset9._op_with_optional_float_cast( g, "Clip", self, min, max, opset_before=12 ) else: - return opset9.op_with_optional_float_cast(g, "Min", self, max, opset_before=12) + return opset9._op_with_optional_float_cast(g, "Min", self, max, opset_before=12) +@_onnx_symbolic("aten::relu6") @_beartype.beartype def relu6(g, input): - relu_ = 
opset9.op_with_optional_float_cast(g, "Relu", input, opset_before=14) + relu_ = opset9._op_with_optional_float_cast(g, "Relu", input, opset_before=14) dtype = input.type().scalarType() if dtype is None: scalar_type = _type_utils.JitScalarType.FLOAT @@ -197,6 +198,7 @@ def relu6(g, input): return clamp(g, relu_, min_val, max_val) +@_onnx_symbolic("aten::select") # Opset 11 gather accepts negative indices @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "i", "v") @@ -205,6 +207,7 @@ def select(g, self, dim, index): return g.op("Gather", self, index, axis_i=dim) +@_onnx_symbolic("aten::index_put") @_beartype.beartype def index_put(g, self, indices_list_value, values, accumulate=False): if symbolic_helper._is_packed_list(indices_list_value): @@ -315,6 +318,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False): return result +@_onnx_symbolic("aten::pixel_shuffle") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def pixel_shuffle(g, self, upscale_factor): @@ -324,28 +328,40 @@ def pixel_shuffle(g, self, upscale_factor): return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[_apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[_apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[_apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[_apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[_apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[_apply_params("upsample_trilinear3d", 5, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bicubic2d", + decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")], +) @_beartype.beartype -def _interpolate(name, dim, interpolate_mode): +def _interpolate(name: str, dim: int, interpolate_mode: str): return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) -upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest") -upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest") -upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest") -upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear") -upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear") -upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear") -upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic") - -upsample_nearest1d.__module__ = "torch.onnx.symbolic_opset11" -upsample_nearest2d.__module__ = "torch.onnx.symbolic_opset11" -upsample_nearest3d.__module__ = "torch.onnx.symbolic_opset11" -upsample_linear1d.__module__ = "torch.onnx.symbolic_opset11" -upsample_bilinear2d.__module__ = "torch.onnx.symbolic_opset11" -upsample_trilinear3d.__module__ = "torch.onnx.symbolic_opset11" -upsample_bicubic2d.__module__ = "torch.onnx.symbolic_opset11" - - +@_onnx_symbolic("aten::__interpolate") @symbolic_helper.quantized_args(True, False, False, False, False, False, False) @_beartype.beartype def __interpolate( @@ -356,6 +372,7 @@ def __interpolate( ) +@_onnx_symbolic("aten::gather") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def gather(g, self, dim, index, sparse_grad=False): @@ -366,6 +383,7 @@ def gather(g, self, dim, index, sparse_grad=False): return 
g.op("GatherElements", self, index, axis_i=dim) +@_onnx_symbolic("aten::scatter") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter(g, self, dim, index, src): @@ -391,6 +409,7 @@ def scatter(g, self, dim, index, src): ) +@_onnx_symbolic("aten::cumsum") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def cumsum(g, self, dim, dtype=None): @@ -406,12 +425,14 @@ def cumsum(g, self, dim, dtype=None): return csum +@_onnx_symbolic("aten::masked_select") @_beartype.beartype def masked_select(g, self, mask): index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) return g.op("GatherND", self, index) +@_onnx_symbolic("aten::masked_scatter") @_beartype.beartype def masked_scatter(g, self, mask, source): index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) @@ -430,6 +451,7 @@ def masked_scatter(g, self, mask, source): return g.op("ScatterND", self, index, source) +@_onnx_symbolic("aten::len") @_beartype.beartype def _len(g, self): if ( @@ -441,6 +463,7 @@ def _len(g, self): return symbolic_helper._squeeze_helper(g, sz_0, [0]) +@_onnx_symbolic("aten::__getitem_") @_beartype.beartype def __getitem_(g, self, i): if symbolic_helper._is_tensor_list(self): @@ -452,17 +475,20 @@ def __getitem_(g, self, i): return getitem(g, self, i) +@_onnx_symbolic("aten::_set_item") @_beartype.beartype def _set_item(g, tensor_list, i, v): tensor_list = g.op("SequenceErase", tensor_list, i) return g.op("SequenceInsert", tensor_list, v, i) +@_onnx_symbolic("aten::append") @_beartype.beartype def append(g, self, tensor): return g.op("SequenceInsert", self, tensor) +@_onnx_symbolic("aten::add") @_beartype.beartype def add(g, self, other, alpha=None): if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): @@ -480,21 +506,25 @@ def add(g, self, other, alpha=None): return opset9.add(g, self, other, alpha) +@_onnx_symbolic("aten::insert") @_beartype.beartype def insert(g, self, pos, tensor): return g.op("SequenceInsert", self, tensor, pos) +@_onnx_symbolic("aten::pop") @_beartype.beartype def pop(g, tensor_list, dim): return g.op("SequenceErase", tensor_list, dim) +@_onnx_symbolic("aten::Delete") @_beartype.beartype def Delete(g, tensor_list, dim): return g.op("SequenceErase", tensor_list, dim) +@_onnx_symbolic("aten::cat") @_beartype.beartype def cat(g, tensor_list, dim): if symbolic_helper._is_packed_list(tensor_list): @@ -504,6 +534,7 @@ def cat(g, tensor_list, dim): return g.op("ConcatFromSequence", tensor_list, axis_i=dim) +@_onnx_symbolic("aten::stack") @_beartype.beartype def stack(g, tensor_list, dim): if symbolic_helper._is_packed_list(tensor_list): @@ -513,6 +544,7 @@ def stack(g, tensor_list, dim): return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) +@_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def _unique2(g, self, sorted, return_inverse, return_counts): @@ -522,6 +554,18 @@ def _unique2(g, self, sorted, return_inverse, return_counts): return u, inverse_indices, counts +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)], +) @_beartype.beartype def _avg_pool(name, tuple_fn): @symbolic_helper.quantized_args(True, False, False, False, False, False, False) @@ -564,11 +608,7 @@ def 
symbolic_fn( return symbolic_fn -avg_pool1d = _avg_pool("avg_pool1d", torch.nn.modules.utils._single) -avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair) -avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple) - - +@_onnx_symbolic("aten::unique_dim") @symbolic_helper.parse_args("v", "i", "i", "i", "i") @_beartype.beartype def unique_dim(g, self, dim, sorted, return_inverse, return_counts): @@ -578,6 +618,7 @@ def unique_dim(g, self, dim, sorted, return_inverse, return_counts): return u, inverse_indices, counts +@_onnx_symbolic("aten::topk") @symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") @_beartype.beartype def topk(g, self, k, dim, largest, sorted, out=None): @@ -586,12 +627,14 @@ def topk(g, self, k, dim, largest, sorted, out=None): ) +@_onnx_symbolic("aten::sort") @symbolic_helper.parse_args("v", "i", "i", "none") @_beartype.beartype def sort(g, self, dim, decending, out=None): return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) +@_onnx_symbolic("aten::argsort") @symbolic_helper.parse_args("v", "i", "i", "none") @_beartype.beartype def argsort(g, self, dim, decending, out=None): @@ -601,11 +644,13 @@ def argsort(g, self, dim, decending, out=None): return indices +@_onnx_symbolic("aten::round") @_beartype.beartype def round(g, self): return g.op("Round", self) +@_onnx_symbolic("aten::remainder") @_beartype.beartype def remainder(g, input, other): if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): @@ -613,6 +658,7 @@ def remainder(g, input, other): return g.op("Mod", input, other, fmod_i=0) +@_onnx_symbolic("aten::split") @symbolic_helper.parse_args("v", "v", "i", "i") @_beartype.beartype def split(g, self, split_size_or_sizes, dim, _outputs=None): @@ -651,12 +697,14 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): return opset9.split(g, self, split_size_or_sizes, dim, _outputs) +@_onnx_symbolic("aten::split_with_sizes") @symbolic_helper.parse_args("v", "v", "i", "i") @_beartype.beartype def split_with_sizes(g, self, split_sizes, dim, _outputs=None): return split(g, self, split_sizes, dim, _outputs) +@_onnx_symbolic("aten::unbind") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def unbind(g, self, dim=0, _outputs=None): @@ -672,14 +720,16 @@ def unbind(g, self, dim=0, _outputs=None): return opset9.unbind(g, self, dim, _outputs) -# Generate paddings in ONNX order based on pad in pytorch. -# Args: -# input: the input tensor. -# pad: the paddings in pytorch. -# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, -# where m is in range [0, n]. @_beartype.beartype def _prepare_onnx_paddings(g, input, pad): + """Generate paddings in ONNX order based on pad in pytorch. + + Args: + input: the input tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, + where m is in range [0, n]. 
+ """ if ( not symbolic_helper._is_packed_list(pad) and symbolic_helper._is_list(pad) @@ -728,6 +778,7 @@ def _prepare_onnx_paddings(g, input, pad): return padding_c +@_onnx_symbolic("aten::constant_pad_nd") @_beartype.beartype def constant_pad_nd(g, input, padding, value=None): mode = "constant" @@ -737,6 +788,9 @@ def constant_pad_nd(g, input, padding, value=None): return g.op("Pad", input, pad, value, mode_s=mode) +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") @_beartype.beartype def reflection_pad(g, input, padding): mode = "reflect" @@ -744,6 +798,9 @@ def reflection_pad(g, input, padding): return g.op("Pad", input, paddings, mode_s=mode) +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") @_beartype.beartype def replication_pad(g, input, padding): mode = "edge" @@ -751,14 +808,7 @@ def replication_pad(g, input, padding): return g.op("Pad", input, paddings, mode_s=mode) -reflection_pad1d = reflection_pad -reflection_pad2d = reflection_pad -reflection_pad3d = reflection_pad -replication_pad1d = replication_pad -replication_pad2d = replication_pad -replication_pad3d = replication_pad - - +@_onnx_symbolic("aten::pad") @_beartype.beartype def pad(g, input, pad, mode, value): mode = symbolic_helper._parse_arg(mode, "s") @@ -774,19 +824,21 @@ def pad(g, input, pad, mode, value): raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) +@_onnx_symbolic("aten::linalg_det") @_beartype.beartype def linalg_det(g, self): return g.op("Det", self) +@_onnx_symbolic("aten::logdet") @_beartype.beartype def logdet(g, input): return opset9.log(g, linalg_det(g, input)) +@_onnx_symbolic("aten::arange") @_beartype.beartype def arange(g, *args): - @_beartype.beartype def _get_arange_dtype(dtype): dtype = symbolic_helper._maybe_get_const(dtype, "i") return dtype @@ -838,6 +890,7 @@ def _get_arange_dtype(dtype): ) +@_onnx_symbolic("aten::_dim_arange") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def _dim_arange(g, like, dim): @@ -850,6 +903,7 @@ def _dim_arange(g, like, dim): return arange(g, stop, 4, None, None, None) +@_onnx_symbolic("aten::size") @_beartype.beartype def size(g, self, dim=None): if dim is None: @@ -857,6 +911,7 @@ def size(g, self, dim=None): return symbolic_helper._size_helper(g, self, dim) +@_onnx_symbolic("aten::squeeze") @_beartype.beartype def squeeze(g, self, dim=None): if dim is None: @@ -909,6 +964,7 @@ def squeeze(g, self, dim=None): return symbolic_helper._squeeze_helper(g, self, [dim]) +@_onnx_symbolic("aten::unsqueeze") @_beartype.beartype def unsqueeze(g, self, dim): if symbolic_helper._is_constant(dim): @@ -917,11 +973,13 @@ def unsqueeze(g, self, dim): return symbolic_helper._unsqueeze_helper(g, self, [dim]) +@_onnx_symbolic("aten::mm") @_beartype.beartype def mm(g, self, other): return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) +@_onnx_symbolic("aten::index") @_beartype.beartype def index(g, self, index): if symbolic_helper.is_caffe2_aten_fallback(): @@ -943,6 +1001,7 @@ def index(g, self, index): return opset9.index(g, self, index) +@_onnx_symbolic("aten::index_fill") @_beartype.beartype def index_fill(g, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") @@ -965,6 +1024,7 @@ def index_fill(g, self, dim, index, value): return scatter(g, self, dim, expanded_index, expanded_value) +@_onnx_symbolic("aten::index_copy") @_beartype.beartype def index_copy(g, 
self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") @@ -976,6 +1036,7 @@ def index_copy(g, self, dim, index, source): return scatter(g, self, dim, expanded_index, source) +@_onnx_symbolic("aten::__rshift_") @_beartype.beartype def __rshift_(g, self, other): # make sure to cast other to self's type @@ -1006,6 +1067,7 @@ def __rshift_(g, self, other): return rshift +@_onnx_symbolic("aten::__lshift_") @_beartype.beartype def __lshift_(g, self, other): # make sure to cast other to self's type @@ -1106,6 +1168,7 @@ def _get_im2col_output_shape(g, input, kernel_h, kernel_w): ) +@_onnx_symbolic("aten::im2col") @symbolic_helper.parse_args("v", "is", "is", "is", "is") @_beartype.beartype def im2col(g, input, kernel_size, dilation, padding, stride): @@ -1159,6 +1222,7 @@ def im2col(g, input, kernel_size, dilation, padding, stride): return symbolic_helper._reshape_helper(g, output, output_shape) +@_onnx_symbolic("aten::narrow") @_beartype.beartype def narrow(g, input, dim, start, length): end = g.op("Add", start, length) @@ -1167,6 +1231,7 @@ def narrow(g, input, dim, start, length): ) +@_onnx_symbolic("aten::flatten") @symbolic_helper.quantized_args(True, False, False) @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype @@ -1194,6 +1259,7 @@ def flatten(g, input, start_dim, end_dim): return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) +@_onnx_symbolic("aten::linalg_vector_norm") @symbolic_helper.parse_args("v", "f", "is", "b", "v") @_beartype.beartype def linalg_vector_norm( @@ -1223,6 +1289,7 @@ def linalg_vector_norm( return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype) +@_onnx_symbolic("aten::embedding_bag") @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") @_beartype.beartype def embedding_bag( @@ -1311,6 +1378,7 @@ def embedding_bag( return loop.node().output(), None, None, None +@_onnx_symbolic("aten::embedding_renorm") @symbolic_helper.parse_args("v", "v", "f", "f") @_beartype.beartype def embedding_renorm(g, weight, indices, max_norm, norm_type): @@ -1350,6 +1418,7 @@ def embedding_renorm(g, weight, indices, max_norm, norm_type): ) +@_onnx_symbolic("aten::chunk") @_beartype.beartype def chunk(g, self, chunks, dim): # Calculate chunk size for dynamic chunk @@ -1367,6 +1436,7 @@ def chunk(g, self, chunks, dim): return split(g, self, chunk_vec, dim) +@_onnx_symbolic("aten::normal") @_beartype.beartype def normal( g, @@ -1390,26 +1460,23 @@ def normal( return add(g, result, mean) -class Prim: - domain = "prim" - - @staticmethod - @_beartype.beartype - def ConstantChunk(g, self, chunks, dim): - input_shape = g.op("Shape", self) - axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) - start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) - chunk_size_minus_1 = g.op( - "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) - ) - input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) - chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) - res = [] - for i in range(chunks): - index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) - end = g.op("Mul", chunk_dim, index) - res.append(g.op("Slice", self, start, end, axis)) - start = end - return res +@_onnx_symbolic("prim::ConstantChunk") +@_beartype.beartype +def prim_constant_chunk(g, self, chunks, dim): + input_shape = 
g.op("Shape", self) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) + chunk_size_minus_1 = g.op( + "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) + ) + input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) + chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) + res = [] + for i in range(chunks): + index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) + end = g.op("Mul", chunk_dim, index) + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index df1f59a107a21..7f6cada34bbb3 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,3 +1,4 @@ +import functools import sys from typing import Optional, Tuple @@ -10,7 +11,7 @@ symbolic_opset9 as opset9, utils, ) -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration # EDITING THIS FILE? READ THIS FIRST! @@ -38,6 +39,8 @@ "unfold", ] +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12) + @_beartype.beartype def _einsum_helper(g, equation, tensors): @@ -58,6 +61,7 @@ def _einsum_helper(g, equation, tensors): return g.op("Einsum", *tensors, equation_s=equation) +@_onnx_symbolic("aten::einsum") @symbolic_helper.parse_args("s", "v") @_beartype.beartype def einsum(g, equation, tensor_list): @@ -65,6 +69,7 @@ def einsum(g, equation, tensor_list): return _einsum_helper(g, equation, tensors) +@_onnx_symbolic("aten::outer") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def outer(g, input, other): @@ -95,6 +100,7 @@ def _dropout_returns_masked_input_and_mask( return r, mask +@_onnx_symbolic("aten::dropout") @symbolic_helper.parse_args("v", "f", "b") @_beartype.beartype def dropout(g, input, p, train): @@ -102,12 +108,14 @@ def dropout(g, input, p, train): return masked +@_onnx_symbolic("aten::native_dropout") @symbolic_helper.parse_args("v", "f", "b") @_beartype.beartype def native_dropout(g, input, p, train): return _dropout_returns_masked_input_and_mask(g, input, p, train) +@_onnx_symbolic("aten::nll_loss") @_beartype.beartype def nll_loss(g, self, target, weight, reduction, ignore_index): # none reduction : onnx::Constant[value={0}] @@ -141,16 +149,19 @@ def nll_loss(g, self, target, weight, reduction, ignore_index): return nllloss +@_onnx_symbolic("aten::nll_loss2d") @_beartype.beartype def nll_loss2d(g, self, target, weight, reduction, ignore_index): return nll_loss(g, self, target, weight, reduction, ignore_index) +@_onnx_symbolic("aten::nll_loss_nd") @_beartype.beartype def nll_loss_nd(g, self, target, weight, reduction, ignore_index): return nll_loss(g, self, target, weight, reduction, ignore_index) +@_onnx_symbolic("aten::cross_entropy_loss") @_beartype.beartype def cross_entropy_loss( g, self, target, weight, reduction, ignore_index, label_smoothing @@ -192,6 +203,7 @@ def cross_entropy_loss( return celoss +@_onnx_symbolic("aten::binary_cross_entropy_with_logits") @symbolic_helper.parse_args("v", "v", "v", "v", "i") @_beartype.beartype def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction): @@ -235,6 +247,7 @@ def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduc ) 
+@_onnx_symbolic("aten::celu") @_beartype.beartype def celu(g, self, alpha): alpha = symbolic_helper._maybe_get_const(alpha, "f") @@ -247,33 +260,39 @@ def celu(g, self, alpha): return g.op("Celu", self, alpha_f=alpha) +@_onnx_symbolic("aten::argmax") @symbolic_helper.parse_args("v", "v", "b") @_beartype.beartype def argmax(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool): return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") +@_onnx_symbolic("aten::argmin") @symbolic_helper.parse_args("v", "v", "b") @_beartype.beartype def argmin(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool): return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") +@_onnx_symbolic("aten::pow") @_beartype.beartype def pow(g, self, exponent): return g.op("Pow", self, exponent) +@_onnx_symbolic("aten::ge") @_beartype.beartype def ge(g, input, other): return g.op("GreaterOrEqual", input, other) +@_onnx_symbolic("aten::le") @_beartype.beartype def le(g, input, other): return g.op("LessOrEqual", input, other) +@_onnx_symbolic("aten::unfold") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def unfold(g, input, dimension, size, step): @@ -344,6 +363,7 @@ def unfold(g, input, dimension, size, step): return symbolic_helper._unimplemented("Unfold", "input size not accessible") +@_onnx_symbolic("aten::tensordot") @symbolic_helper.parse_args("v", "v", "is", "is", "v") @_beartype.beartype def tensordot(g, input_a, input_b, dims_a, dims_b, out=None): diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index da48195c31c6c..2a889e416ebdb 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -2,6 +2,8 @@ # see Note [Edit Symbolic Files] in symbolic_helper.py # This file exports ONNX ops for opset 13 +import functools + import torch import torch._C._onnx as _C_onnx from torch.onnx import ( @@ -12,9 +14,22 @@ symbolic_opset9 as opset9, utils, ) -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +@_onnx_symbolic("aten::softmax") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def softmax(g, input, dim, dtype=None): @@ -28,6 +43,7 @@ def softmax(g, input, dim, dtype=None): return softmax +@_onnx_symbolic("aten::log_softmax") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def log_softmax(g, input, dim, dtype=None): @@ -40,6 +56,7 @@ def log_softmax(g, input, dim, dtype=None): return return_op +@_onnx_symbolic("aten::frobenius_norm") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def frobenius_norm(g, self, dim=None, keepdim=False): @@ -51,6 +68,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False): return g.op("Sqrt", sumsqr) +@_onnx_symbolic("aten::split") @symbolic_helper.parse_args("v", "v", "i", "i") @_beartype.beartype def split(g, self, split_size_or_sizes, dim, _outputs=None): @@ -108,21 +126,25 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) +@_onnx_symbolic("aten::split_with_sizes") @_beartype.beartype def split_with_sizes(g, self, split_sizes, dim, _outputs=None): return split(g, self, split_sizes, dim, _outputs) 
+@_onnx_symbolic("aten::unsafe_split") @_beartype.beartype def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None): return split(g, self, split_size_or_sizes, dim, _outputs) +@_onnx_symbolic("aten::unsafe_split_with_sizes") @_beartype.beartype def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None): return split_with_sizes(g, self, split_sizes, dim, _outputs) +@_onnx_symbolic("aten::tensor_split") @symbolic_helper.parse_args("v", "v", "i", "i") @_beartype.beartype def tensor_split(g, self, indices_or_sections, dim, _outputs=None): @@ -255,6 +277,7 @@ def tensor_split(g, self, indices_or_sections, dim, _outputs=None): return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) +@_onnx_symbolic("aten::unbind") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def unbind(g, self, dim=0, _outputs=None): @@ -277,12 +300,14 @@ def unbind(g, self, dim=0, _outputs=None): return squeezed_outputs +@_onnx_symbolic("aten::nonzero_numpy") # Emitted from `torch.nonzero(x, as_tuple=True)` @_beartype.beartype def nonzero_numpy(g, input, _outputs=None): return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) +@_onnx_symbolic("aten::where") @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def where(g, condition, self=None, other=None, _outputs=None): @@ -297,6 +322,7 @@ def where(g, condition, self=None, other=None, _outputs=None): return g.op("Where", condition, self, other) +@_onnx_symbolic("aten::fake_quantize_per_channel_affine") @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") @_beartype.beartype def fake_quantize_per_channel_affine( @@ -326,6 +352,7 @@ def fake_quantize_per_channel_affine( return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") @symbolic_helper.parse_args("v", "v", "v", "i", "i") @_beartype.beartype def fake_quantize_per_tensor_affine( @@ -371,6 +398,10 @@ def symbolic(g, self, dim=None, keepdim=None): return symbolic +@_onnx_symbolic( + "aten::sum", + decorate=[_apply_params("ReduceSum", "sum")], +) @_beartype.beartype def _reduce_with_dtype(onnx_op, name): symbolic = _reduce_op_symbolic(onnx_op) @@ -407,10 +438,7 @@ def reduce_dim(g, self, dim, keepdim, dtype): return reduce -# TODO(justinchuby): Rename the op to avoid colliding with the builtin sum. 
-sum = _reduce_with_dtype("ReduceSum", "sum") - - +@_onnx_symbolic("aten::unsafe_chunk") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unsafe_chunk(g, self, chunks, dim, _outputs=None): @@ -439,6 +467,7 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None): return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) +@_onnx_symbolic("aten::repeat_interleave") @_beartype.beartype def repeat_interleave(g, self, repeats, dim=None, output_size=None): input = self @@ -572,6 +601,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): return loop_out +@_onnx_symbolic("aten::diagonal") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def diagonal(g, self, offset, dim1, dim2): @@ -689,78 +719,72 @@ def diagonal(g, self, offset, dim1, dim2): return if_op -class Quantized: - """ - https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export - """ +# Quantized ops - domain = "quantized" - @staticmethod - @_beartype.beartype - def linear(g, q_input, q_weight, bias, op_scale, op_zero_point): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) +@_onnx_symbolic("quantized::linear") +@_beartype.beartype +def quantized_linear(g, q_input, q_weight, bias, op_scale, op_zero_point): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - output = opset9.linear(g, input, weight, bias) + output = opset9.linear(g, input, weight, bias) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - @staticmethod - @_beartype.beartype - def conv2d( - g, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, - ): - input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - output = opset9.conv2d( - g, input, weight, bias, stride, padding, dilation, groups - ) +@_onnx_symbolic("quantized::conv2d") +@_beartype.beartype +def quantized_conv2d( + g, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) - @staticmethod - @_beartype.beartype - def conv2d_relu( - g, - q_input, - q_weight, - bias, - stride, - padding, - dilation, - groups, - op_scale, - op_zero_point, - ): - input, input_scale, _, _ = 
symbolic_helper.dequantize_helper(g, q_input) - weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) - q_bias = symbolic_helper.requantize_bias_helper( - g, bias, input_scale, weight_scale, axis - ) - bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) - output = opset9.conv2d( - g, input, weight, bias, stride, padding, dilation, groups - ) - output = opset9.relu(g, output) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +@_onnx_symbolic("quantized::conv2d_relu") +@_beartype.beartype +def quantized_conv2d_relu( + g, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 8849f80c20e9f..72e1f7b13af46 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -15,18 +15,24 @@ # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py +import functools + import torch from torch.onnx import symbolic_helper from torch.onnx._globals import GLOBALS -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) +@_onnx_symbolic("aten::hardswish") @symbolic_helper.parse_args("v") @_beartype.beartype def hardswish(g, self): return g.op("HardSwish", self) +@_onnx_symbolic("aten::tril") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def tril(g, self, diagonal, out=None): @@ -34,6 +40,7 @@ def tril(g, self, diagonal, out=None): return g.op("Trilu", self, k, upper_i=0) +@_onnx_symbolic("aten::triu") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def triu(g, self, diagonal, out=None): @@ -41,6 +48,7 @@ def triu(g, self, diagonal, out=None): return g.op("Trilu", self, k, upper_i=1) +@_onnx_symbolic("aten::reshape") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def reshape(g, self, shape): @@ -49,6 +57,7 @@ def reshape(g, self, shape): return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) +@_onnx_symbolic("aten::batch_norm") @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") @_beartype.beartype def batch_norm( @@ -105,18 +114,11 @@ def batch_norm( return res -class Quantized: - """ - https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export - """ - - domain = "quantized" - - @staticmethod - @_beartype.beartype - def hardswish(g, x, op_scale, op_zero_point): - x, _, _, _ = symbolic_helper.dequantize_helper(g, x) +@_onnx_symbolic("quantized::hardswish") +@_beartype.beartype +def quantized_hardswish(g, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) - output = hardswish(g, x) + output = hardswish(g, x) - return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + return symbolic_helper.quantize_helper(g, output, 
op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index aba1242acadf7..9ffc3e0fd6d67 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -25,14 +25,19 @@ # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py +import functools + import torch from torch import _C from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) +@_onnx_symbolic("aten::__is_") @_beartype.beartype -def __is_(g, self, other): +def aten__is_(g, self, other): if symbolic_helper._is_none(other): if isinstance(self.type(), _C.OptionalType): none = g.op("OptionalHasElement", self) @@ -42,22 +47,20 @@ def __is_(g, self, other): return opset9.eq(g, self, other) -@opset9.wrap_logical_op_with_negation +@_onnx_symbolic("aten::__isnot_") +@opset9.wrap_logical_op_with_negation # type: ignore[has-type] @_beartype.beartype -def __isnot_(g, self, other): - return __is_(g, self, other) - +def aten__isnot_(g, self, other): + return aten__is_(g, self, other) -class Prim: - domain = "prim" - @staticmethod - @_beartype.beartype - def unchecked_cast(g, self): - # exists to refine the type of the Value - # if x is Optional[Tensor], unchecked_cast will cast - # x to Tensor, so the rest of the graph knows that x is a Tensor. - if isinstance(self.type(), _C.OptionalType): - return g.op("OptionalGetElement", self) +@_onnx_symbolic("prim::unchecked_cast") +@_beartype.beartype +def prim_unchecked_cast(g, self): + # exists to refine the type of the Value + # if x is Optional[Tensor], unchecked_cast will cast + # x to Tensor, so the rest of the graph knows that x is a Tensor. + if isinstance(self.type(), _C.OptionalType): + return g.op("OptionalGetElement", self) - return self + return self diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index 601f0aab86e54..bb4f8a9b65db1 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -25,16 +25,21 @@ # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py +import functools + from torch.nn.functional import ( GRID_SAMPLE_INTERPOLATION_MODES, GRID_SAMPLE_PADDING_MODES, ) from torch.onnx import _type_utils, symbolic_helper -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) # note (mkozuki): Why `grid_sampler` instead of `grid_sample`? # Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. 
+@_onnx_symbolic("aten::grid_sampler") @symbolic_helper.parse_args("v", "v", "i", "i", "b") @_beartype.beartype def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners): @@ -50,6 +55,7 @@ def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners): ) +@_onnx_symbolic("aten::scatter_add") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter_add(g, self, dim, index, src): diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index 9ebeb58436ca0..44c866a0c0720 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -15,17 +15,22 @@ SequenceMap """ +import functools from typing import Sequence from torch import _C from torch.onnx import symbolic_helper +from torch.onnx._internal import registration # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in symbolic_helper.py __all__ = ["layer_norm"] +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) + +@_onnx_symbolic("aten::layer_norm") @symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") def layer_norm( g, diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index 7e64933147416..d74f8f1d5e832 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -10,11 +10,16 @@ Scan """ +import functools import warnings from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import registration -block_listed_operators = [ + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7) + +block_listed_operators = ( "scan", "expand", "expand_as", @@ -25,12 +30,13 @@ "max_pool1d_with_indices", "max_pool2d_with_indices", "max_pool3d_with_indices", -] +) # NOTE: max, min, sum, mean: broadcasting is not supported in opset 7. 
# torch.max (same for torch.min) actually has two interfaces smashed together: # torch.max(x, dim, keepdim) and torch.max(x, y) +@_onnx_symbolic("aten::max") def max(g, self, dim_or_y=None, keepdim=None): # torch.max(input, other) if keepdim is None and dim_or_y is not None: @@ -42,6 +48,7 @@ def max(g, self, dim_or_y=None, keepdim=None): return opset9.max(g, self, dim_or_y, keepdim) +@_onnx_symbolic("aten::min") def min(g, self, dim_or_y=None, keepdim=None): # torch.min(input, other) if keepdim is None and dim_or_y is not None: @@ -54,5 +61,6 @@ def min(g, self, dim_or_y=None, keepdim=None): for block_listed_op in block_listed_operators: - vars()[block_listed_op] = symbolic_helper._block_list_in_opset(block_listed_op) - vars()[block_listed_op].__module__ = "torch.onnx.symbolic_opset7" + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index bf480dc27d306..839981d8bc130 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -30,12 +30,16 @@ Scan """ +import functools import warnings import torch from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import registration -block_listed_operators = [ +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) + +block_listed_operators = ( "nonzero", "where", "scatter", @@ -49,16 +53,49 @@ "index_fill", "index_copy", "repeat_interleave", - "isnan", "any", "all", -] +) for block_listed_op in block_listed_operators: - vars()[block_listed_op] = symbolic_helper._block_list_in_opset(block_listed_op) - vars()[block_listed_op].__module__ = "torch.onnx.symbolic_opset8" + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[_apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[_apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[_apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[_apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[_apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[_apply_params("upsample_trilinear3d", 5, "linear")], +) def _interpolate(name, dim, interpolate_mode): def symbolic_fn(g, input, output_size, *args): scales, align_corners = symbolic_helper._get_interpolate_attributes( @@ -86,14 +123,7 @@ def symbolic_fn(g, input, output_size, *args): return symbolic_fn -upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest") -upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest") -upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest") -upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear") -upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear") -upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear") - - +@_onnx_symbolic("aten::__interpolate") def __interpolate( g, input, size, scale_factor, mode, 
align_corners, recompute_scale_factor, antialias ): @@ -121,7 +151,7 @@ def __interpolate( # issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which # is lost after casting. def _try_cast_integer_to_float(g, *args): - floating_scalar_types = ["Half", "Float", "Double"] + floating_scalar_types = {"Half", "Float", "Double"} old_type = None # Cast the input tensor to Float if its scalarType is known and is not floating number. # If casting is performed, return the old scalarType, otherwise return None. @@ -160,14 +190,17 @@ def _comparison_operator(g, input, other, op_name): # NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten}, # integer input type not supported in opset8. Cast to float if possible. +@_onnx_symbolic("aten::gt") def gt(g, input, other): return _comparison_operator(g, input, other, "Greater") +@_onnx_symbolic("aten::lt") def lt(g, input, other): return _comparison_operator(g, input, other, "Less") +@_onnx_symbolic("aten::bmm") def bmm(g, self, other): if symbolic_helper._try_get_scalar_type(self): old_type, self, other = _try_cast_integer_to_float(g, self, other) @@ -176,10 +209,12 @@ def bmm(g, self, other): return g.op("MatMul", self, other) +@_onnx_symbolic("aten::matmul") def matmul(g, self, other): return bmm(g, self, other) +@_onnx_symbolic("aten::prelu") def prelu(g, self, weight): self_rank = symbolic_helper._get_tensor_rank(self) weight_sizes = symbolic_helper._get_tensor_sizes(weight) @@ -195,6 +230,7 @@ def prelu(g, self, weight): return g.op("PRelu", self, weight) +@_onnx_symbolic("aten::mm") def mm(g, self, other): # Create a dummy C tensor. Only needed for API purposes, the value is # since beta = 0 @@ -222,6 +258,7 @@ def mm(g, self, other): return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0) +@_onnx_symbolic("aten::addmm") @symbolic_helper.parse_args("v", "v", "v", "t", "t") def addmm(g, self, mat1, mat2, beta, alpha): if symbolic_helper._try_get_scalar_type(self): @@ -249,6 +286,7 @@ def addmm(g, self, mat1, mat2, beta, alpha): ) +@_onnx_symbolic("aten::flatten") def flatten(g, input, start_dim, end_dim): start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim") end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim") @@ -301,39 +339,46 @@ def _constant_fill(g, sizes, dtype: int, const_value): ) +@_onnx_symbolic("aten::empty") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None): return zeros(g, sizes, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::empty_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") def empty_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): return zeros_like(g, input, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::zeros") @symbolic_helper.parse_args("v", "i", "v", "v", "v") def zeros(g, sizes, dtype, layout, device, pin_memory=False): # NOTE: no way to set device and layout in ONNX, so we ignore it return _constant_fill(g, sizes, dtype, 0) +@_onnx_symbolic("aten::zeros_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") def zeros_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) return _constant_fill(g, shape, dtype, 0) +@_onnx_symbolic("aten::ones") @symbolic_helper.parse_args("v", "i", "v", "v", "v") def ones(g, sizes, dtype, layout, device, pin_memory=False): return _constant_fill(g, sizes, dtype, 1) 
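# Illustrative sketch (not part of the diff): the opset-8 workaround used by
# _try_cast_integer_to_float above, in miniature. Ops such as Gemm/MatMul and the
# comparison operators only accept floating inputs in opset 8, so integer inputs
# are cast to float, the op runs, and the result is cast back to the original
# type. This standalone version works on plain Python numbers; all names here
# are hypothetical.
from typing import Optional, Tuple

FLOATING_SCALAR_TYPES = {"Half", "Float", "Double"}

def try_cast_integer_to_float(dtype: str, *values) -> Tuple[Optional[str], Tuple]:
    """Returns (old_dtype, floats) when a cast happened, else (None, values)."""
    if dtype not in FLOATING_SCALAR_TYPES:
        return dtype, tuple(float(v) for v in values)
    return None, values

def cast_back(old_dtype: Optional[str], value):
    return int(value) if old_dtype is not None else value

def matmul_like(dtype: str, a, b):
    old_dtype, (a, b) = try_cast_integer_to_float(dtype, a, b)
    result = a * b  # stand-in for g.op("MatMul", a, b)
    return cast_back(old_dtype, result)

print(matmul_like("Long", 3, 4))     # 12, computed through floats and cast back
print(matmul_like("Float", 1.5, 2))  # 3.0, no casting round-trip needed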
+@_onnx_symbolic("aten::ones_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") def ones_like(g, input, dtype, layout, device, pin_memory=False, memory_format=None): shape = g.op("Shape", input) return _constant_fill(g, shape, dtype, 1) +@_onnx_symbolic("aten::full") def full(g, sizes, value, dtype, layout, device, pin_memory=False): const_value = symbolic_helper._maybe_get_const(value, "t") if symbolic_helper._is_value(const_value): @@ -344,6 +389,7 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False): return _constant_fill(g, sizes, dtype, const_value) +@_onnx_symbolic("aten::full_like") @symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v") def full_like( g, input, fill_value, dtype, layout, device, pin_memory=False, memory_format=None @@ -352,6 +398,7 @@ def full_like( return _constant_fill(g, shape, dtype, fill_value) +@_onnx_symbolic("aten::repeat") def repeat(g, self, repeats): if not symbolic_helper._is_value(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index c240ffaba9c0a..6dc723799ddbc 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -8,7 +8,7 @@ import math import sys import warnings -from typing import List, Optional, Sequence, Tuple, Union +from typing import Callable, List, Optional, Sequence, Tuple, Union import torch import torch._C._onnx as _C_onnx @@ -28,7 +28,7 @@ SymbolicContext, # Special case class import for readability ) from torch.onnx._globals import GLOBALS -from torch.onnx._internal import _beartype +from torch.onnx._internal import _beartype, registration from torch.types import Number # EDITING THIS FILE? READ THIS FIRST! @@ -65,18 +65,10 @@ __all__ = [ "abs", "acos", - "adaptive_avg_pool1d", - "adaptive_avg_pool2d", - "adaptive_avg_pool3d", - "adaptive_max_pool1d", - "adaptive_max_pool2d", - "adaptive_max_pool3d", "add", "addcmul", "addmm", "alias", - "alpha_dropout_", - "alpha_dropout", "amax", "amin", "aminmax", @@ -87,9 +79,6 @@ "as_tensor", "asin", "atan", - "avg_pool1d", - "avg_pool2d", - "avg_pool3d", "baddbmm", "batch_norm", "bernoulli", @@ -106,7 +95,6 @@ "clone", "constant_pad_nd", "contiguous", - "convolution", "conv_tbc", "conv_transpose1d", "conv_transpose2d", @@ -114,6 +102,7 @@ "conv1d", "conv2d", "conv3d", + "convolution", "cos", "cosine_similarity", "cross", @@ -122,7 +111,6 @@ "dim", "div", "dot", - "dropout_", "dropout", "elu", "embedding_bag", @@ -135,10 +123,6 @@ "expand_as", "expand", "eye", - "feature_alpha_dropout_", - "feature_alpha_dropout", - "feature_dropout_", - "feature_dropout", "fill", "flatten", "floor_divide", @@ -153,8 +137,6 @@ "get_pool_ceil_padding", "glu", "group_norm", - "gru", - "gt_impl", "gt", "hann_window", "hardshrink", @@ -169,8 +151,8 @@ "index", "instance_norm", "is_floating_point", - "isnan", "is_pinned", + "isnan", "item", "kl_div", "layer_norm", @@ -196,19 +178,14 @@ "logsumexp", "lstm_cell", "lstm", - "lt_impl", "lt", "masked_fill", "matmul", "max_pool1d_with_indices", - "max_pool1d", "max_pool2d_with_indices", - "max_pool2d", "max_pool3d_with_indices", - "max_pool3d", "max", "maximum", - "mean", "meshgrid", "min", "minimum", @@ -234,8 +211,7 @@ "one_hot", "ones_like", "ones", - "Onnx", - "op_with_optional_float_cast", + "onnx_placeholder", "overload_by_arg_count", "pad", "pairwise_distance", @@ -244,30 +220,38 @@ "pixel_unshuffle", "pow", "prelu", - "Prim", - "prod", + "prim_constant_chunk", + "prim_constant_split", + "prim_constant", + 
"prim_data", + "prim_device", + "prim_dtype", + "prim_if", + "prim_layout", + "prim_list_construct", + "prim_list_unpack", + "prim_loop", + "prim_max", + "prim_min", + "prim_shape", + "prim_tolist", + "prim_tuple_construct", + "prim_unchecked_cast", + "prim_uninitialized", "rand_like", "rand", "randn_like", "randn", "reciprocal", "reflection_pad", - "reflection_pad1d", - "reflection_pad2d", - "reflection_pad3d", "relu", "relu6", "remainder", "repeat_interleave", "repeat", "replication_pad", - "replication_pad1d", - "replication_pad2d", - "replication_pad3d", "reshape_as", "reshape", - "rnn_relu", - "rnn_tanh", "roll", "rrelu", "rsqrt", @@ -296,7 +280,6 @@ "std_mean", "std", "sub", - "sum", "t", "take", "tan", @@ -316,12 +299,6 @@ "unsafe_split", "unsqueeze", "unused", - "upsample_bilinear2d", - "upsample_linear1d", - "upsample_nearest1d", - "upsample_nearest2d", - "upsample_nearest3d", - "upsample_trilinear3d", "var_mean", "var", "view_as", @@ -335,20 +312,44 @@ _INT64_MAX = 9223372036854775807 +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) + + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +def _export(name: str): + """Exports the function in the current global namespace.""" + + def wrapper(func): + globals()[name] = func + __all__.append(name) + return func + + return wrapper + -# used to represent "missing" optional inputs @_beartype.beartype def unused(g): + """Represents "missing" optional inputs.""" n = g.op("prim::Constant") n.setType(_C.OptionalType.ofTensor()) return n +@_onnx_symbolic("aten::_shape_as_tensor") @_beartype.beartype def _shape_as_tensor(g, input): return g.op("Shape", input) +@_onnx_symbolic("aten::_reshape_from_tensor") @_beartype.beartype def _reshape_from_tensor(g, input, shape): if isinstance(shape, list): @@ -356,12 +357,14 @@ def _reshape_from_tensor(g, input, shape): return reshape(g, input, shape) +@_onnx_symbolic("aten::reshape") @symbolic_helper.quantized_args(True) @_beartype.beartype def reshape(g, self, shape): return symbolic_helper._reshape_helper(g, self, shape) +@_onnx_symbolic("aten::reshape_as") @symbolic_helper.quantized_args(True) @_beartype.beartype def reshape_as(g, self, other): @@ -369,6 +372,7 @@ def reshape_as(g, self, other): return reshape(g, self, shape) +@_onnx_symbolic("aten::add") @_beartype.beartype def add(g, self, other, alpha=None): if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): @@ -380,6 +384,7 @@ def add(g, self, other, alpha=None): return g.op("Add", self, other) +@_onnx_symbolic("aten::sub") @_beartype.beartype def sub(g, self, other, alpha=None): if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: @@ -387,11 +392,13 @@ def sub(g, self, other, alpha=None): return g.op("Sub", self, other) +@_onnx_symbolic("aten::rsub") @_beartype.beartype def rsub(g, self, other, alpha=None): return sub(g, other, self, alpha=alpha) +@_onnx_symbolic("aten::mul") @_beartype.beartype def mul(g, self, other): if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): @@ -401,6 +408,7 @@ def mul(g, self, other): return g.op("Mul", self, other) +@_onnx_symbolic("aten::div") @_beartype.beartype def div(g, self, other, *args): if len(args) == 0: @@ -409,6 +417,7 @@ def div(g, self, other, *args): return _div_rounding_mode(g, self, other, *args) +@_onnx_symbolic("aten::addcmul") 
@symbolic_helper.parse_args("v", "v", "v", "f") @_beartype.beartype def addcmul(g, self, tensor1, tensor2, value=1.0): @@ -492,17 +501,20 @@ def _floor_divide(g, self, other): return g.op("Sub", div, fixup) +@_onnx_symbolic("aten::floor_divide") @_beartype.beartype def floor_divide(g, self, other): # Deprecated behavior, floor_divide actually truncates return _trunc_divide(g, self, other) +@_onnx_symbolic("aten::floordiv") @_beartype.beartype def floordiv(g, self, other): return floor_divide(g, self, other) +@_onnx_symbolic("aten::true_divide") @_beartype.beartype def true_divide(g, self, other): """Division where both inputs are cast to floating types @@ -531,6 +543,7 @@ def true_divide(g, self, other): return g.op("Div", self, other) +@_onnx_symbolic("aten::reciprocal") @_beartype.beartype def reciprocal(g, self): # torch.reciprocal implicitly casts to float, so we do the same. @@ -539,6 +552,7 @@ def reciprocal(g, self): return g.op("Reciprocal", self) +@_onnx_symbolic("aten::cat") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def cat(g, tensor_list, dim): @@ -546,6 +560,7 @@ def cat(g, tensor_list, dim): return g.op("Concat", *tensors, axis_i=dim) +@_onnx_symbolic("aten::stack") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def stack(g, tensor_list, dim): @@ -556,11 +571,13 @@ def stack(g, tensor_list, dim): return g.op("Concat", *unsqueezed, axis_i=dim) +@_onnx_symbolic("aten::list") @_beartype.beartype def _list(g, self): return self +@_onnx_symbolic("aten::mm") @_beartype.beartype def mm(g, self, other): # Create a dummy C tensor. Only needed for API purposes, the value is @@ -569,16 +586,19 @@ def mm(g, self, other): return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) +@_onnx_symbolic("aten::bmm") @_beartype.beartype def bmm(g, self, other): return g.op("MatMul", self, other) +@_onnx_symbolic("aten::matmul") @_beartype.beartype def matmul(g, self, other): return g.op("MatMul", self, other) +@_onnx_symbolic("aten::addmm") @symbolic_helper.parse_args("v", "v", "v", "t", "t") @_beartype.beartype def addmm(g, self, mat1, mat2, beta, alpha): @@ -596,11 +616,12 @@ def addmm(g, self, mat1, mat2, beta, alpha): mat1_rank = symbolic_helper._get_tensor_rank(mat1) mat2_rank = symbolic_helper._get_tensor_rank(mat2) - @_beartype.beartype - def isNotNoneAnd(v, u): + def is_not_none_and(v, u): return v is not None and v != u - if dtype is not None and (isNotNoneAnd(mat1_rank, 2) or isNotNoneAnd(mat2_rank, 2)): + if dtype is not None and ( + is_not_none_and(mat1_rank, 2) or is_not_none_and(mat2_rank, 2) + ): scalar_type = _type_utils.JitScalarType.from_name(dtype) res1 = g.op("MatMul", mat1, mat2) @@ -635,16 +656,19 @@ def isNotNoneAnd(v, u): ) +@_onnx_symbolic("aten::neg") @_beartype.beartype def neg(g, self): return g.op("Neg", self) +@_onnx_symbolic("aten::sqrt") @_beartype.beartype def sqrt(g, self): return g.op("Sqrt", self) +@_onnx_symbolic("aten::rsqrt") @_beartype.beartype def rsqrt(g, self): return g.op( @@ -652,6 +676,7 @@ def rsqrt(g, self): ) +@_onnx_symbolic("aten::tanh") # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp @symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) @_beartype.beartype @@ -659,36 +684,43 @@ def tanh(g, self): return g.op("Tanh", self) +@_onnx_symbolic("aten::sin") @_beartype.beartype def sin(g, self): return g.op("Sin", self) +@_onnx_symbolic("aten::cos") @_beartype.beartype def cos(g, self): return g.op("Cos", self) +@_onnx_symbolic("aten::tan") @_beartype.beartype def 
tan(g, self): return g.op("Tan", self) +@_onnx_symbolic("aten::asin") @_beartype.beartype def asin(g, self): return g.op("Asin", self) +@_onnx_symbolic("aten::acos") @_beartype.beartype def acos(g, self): return g.op("Acos", self) +@_onnx_symbolic("aten::atan") @_beartype.beartype def atan(g, self): return g.op("Atan", self) +@_onnx_symbolic("aten::sigmoid") # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) @_beartype.beartype @@ -696,6 +728,7 @@ def sigmoid(g, self): return g.op("Sigmoid", self) +@_onnx_symbolic("aten::sign") @_beartype.beartype def sign(g, self): return g.op("Sign", self) @@ -758,18 +791,23 @@ def wrapper(g, *args): return wrapper +@_onnx_symbolic("aten::sum", decorate=[_apply_params("ReduceSum", "sum")]) +@_onnx_symbolic("aten::mean", decorate=[_apply_params("ReduceMean", "mean")]) +# torch.prod does not support multidimensional "dim" +@_onnx_symbolic( + "aten::prod", + decorate=[_apply_params("ReduceProd", "prod", allow_multi_dim_support=False)], +) @_beartype.beartype -def _reduce_with_dtype(onnx_op, name, allow_multi_dim_support=True): +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): symbolic = _reduce_op_symbolic( onnx_op, allow_multi_dim_support=allow_multi_dim_support ) @overload_by_arg_count - @_beartype.beartype def reduce(g, *args, **kwargs): @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "none") - @_beartype.beartype def reduce_nodim(g, self, dtype): if dtype.node().kind() == "onnx::Constant": dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -784,7 +822,6 @@ def reduce_nodim(g, self, dtype): @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type] - @_beartype.beartype def reduce_dim(g, self, dim, keepdim, dtype): if dtype.node().kind() == "onnx::Constant": dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -800,12 +837,7 @@ def reduce_dim(g, self, dim, keepdim, dtype): return reduce -sum = _reduce_with_dtype("ReduceSum", "sum") -mean = _reduce_with_dtype("ReduceMean", "mean") -# torch.prod does not support multidimensional "dim" -prod = _reduce_with_dtype("ReduceProd", "prod", allow_multi_dim_support=False) - - +@_onnx_symbolic("aten::cumsum") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def cumsum(g, input, dim, dtype): @@ -817,6 +849,7 @@ def cumsum(g, input, dim, dtype): symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) +@_onnx_symbolic("aten::_sample_dirichlet") @_beartype.beartype def _sample_dirichlet(g, self, generator): if symbolic_helper.is_caffe2_aten_fallback(): @@ -828,6 +861,7 @@ def _sample_dirichlet(g, self, generator): return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) +@_onnx_symbolic("aten::_standard_gamma") @_beartype.beartype def _standard_gamma(g, self, generator): if symbolic_helper.is_caffe2_aten_fallback(): @@ -840,11 +874,13 @@ def _standard_gamma(g, self, generator): return symbolic_helper._onnx_unsupported("_standard_gamma", self) +@_onnx_symbolic("aten::t") @_beartype.beartype def t(g, self): return g.op("Transpose", self, perm_i=(1, 0)) +@_onnx_symbolic("aten::numpy_T") @symbolic_helper.quantized_args(True) @_beartype.beartype def numpy_T(g, input): @@ -854,6 +890,7 @@ def numpy_T(g, input): return g.op("Transpose", input, perm_i=perm) +@_onnx_symbolic("aten::expand") @symbolic_helper.quantized_args(True) @_beartype.beartype 
def expand(g, self, size, implicit): @@ -874,6 +911,7 @@ def expand(g, self, size, implicit): return g.op("Expand", self, size) +@_onnx_symbolic("aten::expand_as") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def expand_as(g, self, other): @@ -891,6 +929,7 @@ def expand_as(g, self, other): return g.op("Expand", self, shape) +@_onnx_symbolic("aten::embedding") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "i", "b", "v") @_beartype.beartype @@ -911,6 +950,7 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): return g.op("Gather", weight, indices) +@_onnx_symbolic("aten::embedding_bag") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") @_beartype.beartype @@ -947,6 +987,7 @@ def embedding_bag( return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) +@_onnx_symbolic("aten::size") @_beartype.beartype def size(g, self, dim=None): if dim is None: @@ -959,6 +1000,7 @@ def size(g, self, dim=None): return symbolic_helper._size_helper(g, self, dim) +@_onnx_symbolic("aten::transpose") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype @@ -983,6 +1025,7 @@ def transpose(g, self, dim0, dim1): ) +@_onnx_symbolic("aten::permute") @symbolic_helper.parse_args("v", "is") @_beartype.beartype def permute(g, self, dims): @@ -991,18 +1034,21 @@ def permute(g, self, dims): return g.op("Transpose", self, perm_i=dims) +@_onnx_symbolic("aten::view") @symbolic_helper.quantized_args(True) @_beartype.beartype def view(g, self, size): return reshape(g, self, size) +@_onnx_symbolic("aten::view_as") @_beartype.beartype def view_as(g, self, other): shape = g.op("Shape", other) return reshape(g, self, shape) +@_onnx_symbolic("aten::unsafe_chunk") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unsafe_chunk(g, self, chunks, dim, _outputs=None): @@ -1023,6 +1069,7 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None): return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) +@_onnx_symbolic("aten::split") @symbolic_helper.parse_args("v", "v", "i", "i") @_beartype.beartype def split(g, self, split_size_or_sizes, dim, _outputs=None): @@ -1050,11 +1097,13 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) +@_onnx_symbolic("aten::unsafe_split") @_beartype.beartype def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None): return split(g, self, split_size_or_sizes, dim, _outputs) +@_onnx_symbolic("aten::split_with_sizes") @symbolic_helper.parse_args("v", "is", "i", "i") @_beartype.beartype def split_with_sizes(g, self, split_sizes, dim, _outputs=None): @@ -1065,11 +1114,13 @@ def split_with_sizes(g, self, split_sizes, dim, _outputs=None): return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) +@_onnx_symbolic("aten::unsafe_split_with_sizes") @_beartype.beartype def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None): return split_with_sizes(g, self, split_sizes, dim, _outputs) +@_onnx_symbolic("aten::unbind") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def unbind(g, self, dim=0, _outputs=None): @@ -1086,6 +1137,7 @@ def unbind(g, self, dim=0, _outputs=None): return squeezed_outputs +@_onnx_symbolic("aten::select") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "i", "v") @_beartype.beartype @@ -1105,11 +1157,13 @@ 
def select(g, self, dim, index): return g.op("Gather", self, index, axis_i=dim) +@_onnx_symbolic("aten::square") @_beartype.beartype def square(g, self): return g.op("Mul", self, self) +@_onnx_symbolic("aten::squeeze") @_beartype.beartype def squeeze(g, self, dim=None): if dim is None: @@ -1173,6 +1227,7 @@ def squeeze(g, self, dim=None): return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) +@_onnx_symbolic("aten::prelu") @_beartype.beartype def prelu(g, self, weight): self_rank = symbolic_helper._get_tensor_rank(self) @@ -1196,18 +1251,20 @@ def prelu(g, self, weight): return g.op("PRelu", self, weight) +@_onnx_symbolic("aten::silu") @_beartype.beartype def silu(g, input): return g.op("Mul", input, g.op("Sigmoid", input)) +@_onnx_symbolic("aten::mish") @_beartype.beartype def mish(g, input): return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) @_beartype.beartype -def op_with_optional_float_cast(g, op_name, *args, **kwargs): +def _op_with_optional_float_cast(g, op_name, *args, **kwargs): """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types. This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch operator data type. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic @@ -1261,35 +1318,41 @@ def op_with_optional_float_cast(g, op_name, *args, **kwargs): return self +@_onnx_symbolic("aten::relu") @symbolic_helper.quantized_args(True) @_beartype.beartype def relu(g, input): - return op_with_optional_float_cast(g, "Relu", input, opset_before=14) + return _op_with_optional_float_cast(g, "Relu", input, opset_before=14) +@_onnx_symbolic("aten::relu6") @symbolic_helper.quantized_args(True) @_beartype.beartype def relu6(g, input): - relu = op_with_optional_float_cast(g, "Relu", input, opset_before=14) + relu = _op_with_optional_float_cast(g, "Relu", input, opset_before=14) return clamp_max(g, relu, 6) +@_onnx_symbolic("aten::ceil") @_beartype.beartype def ceil(g, input): return g.op("Ceil", input) +@_onnx_symbolic("aten::floor") @_beartype.beartype def floor(g, input): return g.op("Floor", input) +@_onnx_symbolic("aten::len") @_beartype.beartype def _len(g, self): sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) return symbolic_helper._squeeze_helper(g, sz_0, [0]) +@_onnx_symbolic("aten::threshold") @symbolic_helper.parse_args("v", "t", "t") @_beartype.beartype def threshold(g, self, threshold, value): @@ -1301,6 +1364,7 @@ def threshold(g, self, threshold, value): return g.op("Relu", self) +@_onnx_symbolic("aten::leaky_relu") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "f", "b") @_beartype.beartype @@ -1309,6 +1373,7 @@ def leaky_relu(g, input: _C.Value, negative_slope: float, inplace: bool = False) return g.op("LeakyRelu", input, alpha_f=negative_slope) +@_onnx_symbolic("aten::glu") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def glu(g, input, dim): @@ -1320,6 +1385,7 @@ def glu(g, input, dim): return g.op("Mul", first, g.op("Sigmoid", second)) +@_onnx_symbolic("aten::softmax") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def softmax(g, input, dim, dtype=None): @@ -1384,6 +1450,7 @@ def softmax(g, input, dim, dtype=None): return softmax +@_onnx_symbolic("aten::softplus") @_beartype.beartype def softplus(g, self, beta, threshold): beta_const = symbolic_helper._maybe_get_const(beta, "f") @@ -1392,8 +1459,10 @@ def softplus(g, self, beta, threshold): return g.op("Softplus", self) 
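# Illustrative sketch (not part of the diff): the idea behind
# _op_with_optional_float_cast, used by relu and relu6 above. Ops such as Relu
# and Clip only accept non-float dtypes from a later opset onward (opset_before),
# so for older target opsets the input is cast to float, the op runs, and the
# output is cast back. This standalone version models dtypes as strings and the
# op as a plain callable; names are hypothetical.
from typing import Callable

NON_FLOAT_TYPES = {"Byte", "Char", "Int", "Long", "Bool"}

def op_with_optional_float_cast(op: Callable, dtype: str, value, *,
                                export_opset: int, opset_before: int):
    needs_cast = export_opset < opset_before and dtype in NON_FLOAT_TYPES
    x = float(value) if needs_cast else value  # Cast(input) only when required
    y = op(x)                                  # the wrapped op, e.g. Relu
    return int(y) if needs_cast else y         # cast the result back to the original type

def relu(x):
    return max(x, 0.0)

# Exporting to opset 9: Relu accepts only floats before opset 14, so a cast round-trip happens.
print(op_with_optional_float_cast(relu, "Long", -3, export_opset=9, opset_before=14))   # 0
# Exporting to opset 14: the input dtype is handled natively, no casts inserted.
print(op_with_optional_float_cast(relu, "Long", 5, export_opset=14, opset_before=14))   # 5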
+@_onnx_symbolic("aten::get_pool_ceil_padding") @_beartype.beartype def get_pool_ceil_padding(input, kernel_size, stride, padding): + # TODO(justinchuby): Looks like this op is deprecated in torch sizes = symbolic_helper._get_tensor_sizes(input) dim = sizes[-len(padding) :] if sizes is not None else None if dim is None or any([i is None for i in dim]): @@ -1435,6 +1504,33 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding): return padding_ceil +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[ + _apply_params( + "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False + ), + _export("max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[ + _apply_params( + "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False + ), + _export("max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[ + _apply_params( + "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False + ), + _export("max_pool3d"), + ], +) @_beartype.beartype def _max_pool(name, tuple_fn, ndims, return_indices): @symbolic_helper.quantized_args(True, False, False, False, False, False) @@ -1495,35 +1591,53 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): return symbolic_fn -max_pool1d = _max_pool( - "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False +max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( + _max_pool( + "max_pool1d_with_indices", + torch.nn.modules.utils._single, + 1, + return_indices=True, + ) ) -max_pool2d = _max_pool( - "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False +max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( + _max_pool( + "max_pool2d_with_indices", + torch.nn.modules.utils._pair, + 2, + return_indices=True, + ) ) -max_pool3d = _max_pool( - "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False +max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( + _max_pool( + "max_pool3d_with_indices", + torch.nn.modules.utils._triple, + 3, + return_indices=True, + ) ) -max_pool1d_with_indices = _max_pool( - "max_pool1d_with_indices", - torch.nn.modules.utils._single, - 1, - return_indices=True, + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[ + _apply_params("avg_pool1d", torch.nn.modules.utils._single), + _export("avg_pool1d"), + ], ) -max_pool2d_with_indices = _max_pool( - "max_pool2d_with_indices", - torch.nn.modules.utils._pair, - 2, - return_indices=True, +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[ + _apply_params("avg_pool2d", torch.nn.modules.utils._pair), + _export("avg_pool2d"), + ], ) -max_pool3d_with_indices = _max_pool( - "max_pool3d_with_indices", - torch.nn.modules.utils._triple, - 3, - return_indices=True, +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[ + _apply_params("avg_pool3d", torch.nn.modules.utils._triple), + _export("avg_pool3d"), + ], ) - - @_beartype.beartype def _avg_pool(name, tuple_fn): @symbolic_helper.quantized_args(True) @@ -1574,11 +1688,69 @@ def symbolic_fn( return symbolic_fn -avg_pool1d = _avg_pool("avg_pool1d", torch.nn.modules.utils._single) -avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair) -avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple) - - +@_onnx_symbolic( + "aten::adaptive_avg_pool1d", + decorate=[ + _apply_params( + "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single + ), + _export("adaptive_avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool2d", 
+ decorate=[ + _apply_params( + "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair + ), + _export("adaptive_avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool3d", + decorate=[ + _apply_params( + "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple + ), + _export("adaptive_avg_pool3d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool1d", + decorate=[ + _apply_params( + "adaptive_max_pool1d", + "MaxPool", + torch.nn.modules.utils._single, + max_pool1d_with_indices, + ), + _export("adaptive_max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool2d", + decorate=[ + _apply_params( + "adaptive_max_pool2d", + "MaxPool", + torch.nn.modules.utils._pair, + max_pool2d_with_indices, + ), + _export("adaptive_max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool3d", + decorate=[ + _apply_params( + "adaptive_max_pool3d", + "MaxPool", + torch.nn.modules.utils._triple, + max_pool3d_with_indices, + ), + _export("adaptive_max_pool3d"), + ], +) @_beartype.beartype def _adaptive_pool(name, type, tuple_fn, fn=None): @symbolic_helper.quantized_args(True, False) @@ -1635,44 +1807,14 @@ def symbolic_fn(g, input, output_size): return symbolic_fn -adaptive_avg_pool1d = _adaptive_pool( - "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single -) -adaptive_avg_pool2d = _adaptive_pool( - "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair -) -adaptive_avg_pool3d = _adaptive_pool( - "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple -) - -adaptive_max_pool1d = _adaptive_pool( - "adaptive_max_pool1d", - "MaxPool", - torch.nn.modules.utils._single, - max_pool1d_with_indices, -) -adaptive_max_pool2d = _adaptive_pool( - "adaptive_max_pool2d", - "MaxPool", - torch.nn.modules.utils._pair, - max_pool2d_with_indices, -) -adaptive_max_pool3d = _adaptive_pool( - "adaptive_max_pool3d", - "MaxPool", - torch.nn.modules.utils._triple, - max_pool3d_with_indices, -) - - -# Generate paddings in ONNX order based on pad in pytorch. -# Args: -# dim: the dimension of the tensor. -# pad: the paddings in pytorch. -# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... @_beartype.beartype -def _prepare_onnx_paddings(dim, pad): - assert isinstance(dim, int) +def _prepare_onnx_paddings(dim: int, pad): + """Generate paddings in ONNX order based on pad in pytorch. + Args: + dim: the dimension of the tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... + """ # The desired order of paddings is # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. # n is the dimension of input. 
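# Illustrative sketch (not part of the diff): the padding re-ordering that the
# _prepare_onnx_paddings docstring above describes. PyTorch pads run from the
# last dimension inwards as (begin, end) pairs, while ONNX Pad wants all the
# begins for dims 0..n-1 first, then all the ends. A standalone re-implementation
# of that reshuffle, assuming only the documented ordering:
from typing import List, Sequence

def prepare_onnx_paddings(dim: int, pad: Sequence[int]) -> List[int]:
    # Extend with zeros for the leading dimensions that are not padded.
    paddings = list(pad) + [0] * (dim * 2 - len(pad))
    # Walk backwards in steps of two: first collect every "begin" entry
    # (dims 0..n-1), then every "end" entry.
    return paddings[-2::-2] + paddings[-1::-2]

# Pad only the last dim of a 4-D tensor by 1 before and 2 after:
print(prepare_onnx_paddings(4, (1, 2)))
# -> [0, 0, 0, 1, 0, 0, 0, 2]  (begins for dims 0-3, then ends for dims 0-3)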
@@ -1701,6 +1843,7 @@ def _convert_padding_node(input): return padding +@_onnx_symbolic("aten::constant_pad_nd") @_beartype.beartype def constant_pad_nd(g, input, padding, value): mode = "constant" @@ -1715,11 +1858,12 @@ def constant_pad_nd(g, input, padding, value): padding = _convert_padding_node(padding) paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 ) +@_onnx_symbolic("aten::_pad_circular") @_beartype.beartype def _pad_circular(g, input, pad): padding = _convert_padding_node(pad) @@ -1761,34 +1905,33 @@ def _pad_circular(g, input, pad): return cur +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") @_beartype.beartype def reflection_pad(g, input, padding): mode = "reflect" padding = _convert_padding_node(padding) paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 ) +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") @_beartype.beartype def replication_pad(g, input, padding): mode = "edge" padding = _convert_padding_node(padding) paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 ) -reflection_pad1d = reflection_pad -reflection_pad2d = reflection_pad -reflection_pad3d = reflection_pad -replication_pad1d = replication_pad -replication_pad2d = replication_pad -replication_pad3d = replication_pad - - +@_onnx_symbolic("aten::pad") @_beartype.beartype def pad(g, input, pad, mode, value): mode = symbolic_helper._parse_arg(mode, "s") @@ -1804,9 +1947,50 @@ def pad(g, input, pad, mode, value): raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[ + _apply_params("upsample_nearest1d", 3, "nearest"), + _export("upsample_nearest1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[ + _apply_params("upsample_nearest2d", 4, "nearest"), + _export("upsample_nearest2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[ + _apply_params("upsample_nearest3d", 5, "nearest"), + _export("upsample_nearest3d"), + ], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[ + _apply_params("upsample_linear1d", 3, "linear"), + _export("upsample_linear1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[ + _apply_params("upsample_bilinear2d", 4, "linear"), + _export("upsample_bilinear2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[ + _apply_params("upsample_trilinear3d", 5, "linear"), + _export("upsample_trilinear3d"), + ], +) @_beartype.beartype -def _interpolate(name, dim, interpolate_mode): - @_beartype.beartype +def _interpolate(name: str, dim: int, interpolate_mode: str): def symbolic_fn(g, input, output_size, *args): scales, align_corners = symbolic_helper._get_interpolate_attributes( g, interpolate_mode, args @@ -1824,14 +2008,7 @@ def symbolic_fn(g, input, output_size, *args): return symbolic_fn -upsample_nearest1d = 
_interpolate("upsample_nearest1d", 3, "nearest") -upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest") -upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest") -upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear") -upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear") -upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear") - - +@_onnx_symbolic("aten::__interpolate") @_beartype.beartype def __interpolate( g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias @@ -1842,6 +2019,7 @@ def __interpolate( return g.op("Upsample", input, scales, mode_s=mode) +@_onnx_symbolic("aten::bitwise_not") @_beartype.beartype def bitwise_not(g, input): if not symbolic_helper._is_bool(input): @@ -1855,9 +2033,8 @@ def bitwise_not(g, input): @_beartype.beartype def wrap_logical_op_with_cast_to(to_type): - @_beartype.beartype def decorator(fn): - @_beartype.beartype + @functools.wraps(fn) def wrap_with_cast(g, input, other): to_cast_func = globals()[f"_cast_{to_type}"] return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) @@ -1868,14 +2045,15 @@ def wrap_with_cast(g, input, other): @_beartype.beartype -def wrap_logical_op_with_negation(func): - @_beartype.beartype +def wrap_logical_op_with_negation(func: Callable) -> Callable: + @functools.wraps(func) def wrap_with_not(g, input, other): return g.op("Not", func(g, input, other)) return wrap_with_not +@_onnx_symbolic("aten::__not_") @_beartype.beartype def __not_(g, self): if not symbolic_helper._is_bool(self): @@ -1887,6 +2065,7 @@ def __not_(g, self): return g.op("Not", self) +@_onnx_symbolic("aten::eq") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def eq(g, self, other): @@ -1899,6 +2078,7 @@ def eq(g, self, other): return g.op("Equal", self, other) +@_onnx_symbolic("aten::ne") @symbolic_helper.quantized_args(True, True) @wrap_logical_op_with_negation @_beartype.beartype @@ -1906,14 +2086,15 @@ def ne(g, self, other): return eq(g, self, other) +@_onnx_symbolic("aten::gt") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def gt(g, input, other): - return gt_impl(g, input, other) + return _gt_impl(g, input, other) @_beartype.beartype -def gt_impl(g, input, other): +def _gt_impl(g, input, other): if ( input.type().scalarType() is not None and symbolic_helper._is_bool(input) @@ -1925,14 +2106,15 @@ def gt_impl(g, input, other): return g.op("Greater", input, other) +@_onnx_symbolic("aten::lt") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def lt(g, input, other): - return lt_impl(g, input, other) + return _lt_impl(g, input, other) @_beartype.beartype -def lt_impl(g, input, other): +def _lt_impl(g, input, other): if ( input.type().scalarType() is not None and symbolic_helper._is_bool(input) @@ -1944,20 +2126,23 @@ def lt_impl(g, input, other): return g.op("Less", input, other) +@_onnx_symbolic("aten::ge") @symbolic_helper.quantized_args(True, True) @wrap_logical_op_with_negation @_beartype.beartype def ge(g, input, other): - return lt_impl(g, input, other) + return _lt_impl(g, input, other) +@_onnx_symbolic("aten::le") @symbolic_helper.quantized_args(True, True) @wrap_logical_op_with_negation @_beartype.beartype def le(g, input, other): - return gt_impl(g, input, other) + return _gt_impl(g, input, other) +@_onnx_symbolic("aten::__and_") @_beartype.beartype def __and_(g, input, other): if not symbolic_helper._is_bool(input): @@ -1975,6 +2160,7 @@ def __and_(g, input, other): return 
g.op("And", input, other) +@_onnx_symbolic("aten::__or_") @_beartype.beartype def __or_(g, input, other): if not symbolic_helper._is_bool(input): @@ -1992,6 +2178,7 @@ def __or_(g, input, other): return g.op("Or", input, other) +@_onnx_symbolic("aten::__xor_") @_beartype.beartype def __xor_(g, input, other): if not symbolic_helper._is_bool(input): @@ -2009,24 +2196,28 @@ def __xor_(g, input, other): return g.op("Xor", input, other) +@_onnx_symbolic("aten::logical_and") @wrap_logical_op_with_cast_to("Bool") @_beartype.beartype def logical_and(g, input, other): return g.op("And", input, other) +@_onnx_symbolic("aten::logical_or") @wrap_logical_op_with_cast_to("Bool") @_beartype.beartype def logical_or(g, input, other): return g.op("Or", input, other) +@_onnx_symbolic("aten::logical_xor") @wrap_logical_op_with_cast_to("Bool") @_beartype.beartype def logical_xor(g, input, other): return g.op("Xor", input, other) +@_onnx_symbolic("aten::__rshift_") @_beartype.beartype def __rshift_(g, self, other): # make sure to cast other to self's type @@ -2054,6 +2245,7 @@ def __rshift_(g, self, other): return rshift +@_onnx_symbolic("aten::__lshift_") @_beartype.beartype def __lshift_(g, self, other): # make sure to cast other to self's type @@ -2081,6 +2273,7 @@ def __lshift_(g, self, other): return lshift +@_onnx_symbolic("aten::where") @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def where(g, condition, self=None, other=None, _outputs=None): @@ -2095,6 +2288,7 @@ def where(g, condition, self=None, other=None, _outputs=None): return g.op("Where", condition, self, other) +@_onnx_symbolic("aten::log_softmax") @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def log_softmax(g, input, dim, dtype=None): @@ -2128,6 +2322,7 @@ def log_softmax(g, input, dim, dtype=None): return return_op +@_onnx_symbolic("aten::_log_softmax") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def _log_softmax(g, input, dim, half_to_float): @@ -2136,6 +2331,7 @@ def _log_softmax(g, input, dim, half_to_float): return log_softmax(g, input, dim) +@_onnx_symbolic("aten::_convolution") @symbolic_helper.parse_args( "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" ) @@ -2207,6 +2403,7 @@ def _convolution( return n +@_onnx_symbolic("aten::convolution") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") @_beartype.beartype def convolution( @@ -2239,6 +2436,7 @@ def convolution( ) +@_onnx_symbolic("aten::conv1d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i") @_beartype.beartype def conv1d(g, input, weight, bias, stride, padding, dilation, groups): @@ -2260,6 +2458,7 @@ def conv1d(g, input, weight, bias, stride, padding, dilation, groups): ) +@_onnx_symbolic("aten::conv2d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i") @_beartype.beartype def conv2d(g, input, weight, bias, stride, padding, dilation, groups): @@ -2281,6 +2480,7 @@ def conv2d(g, input, weight, bias, stride, padding, dilation, groups): ) +@_onnx_symbolic("aten::conv3d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i") @_beartype.beartype def conv3d(g, input, weight, bias, stride, padding, dilation, groups): @@ -2302,6 +2502,7 @@ def conv3d(g, input, weight, bias, stride, padding, dilation, groups): ) +@_onnx_symbolic("aten::conv_transpose1d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") @_beartype.beartype def conv_transpose1d( @@ -2325,6 +2526,7 @@ def conv_transpose1d( ) 
+@_onnx_symbolic("aten::conv_transpose2d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") @_beartype.beartype def conv_transpose2d( @@ -2348,6 +2550,7 @@ def conv_transpose2d( ) +@_onnx_symbolic("aten::conv_transpose3d") @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") @_beartype.beartype def conv_transpose3d( @@ -2371,6 +2574,7 @@ def conv_transpose3d( ) +@_onnx_symbolic("aten::batch_norm") @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") @_beartype.beartype def batch_norm( @@ -2474,6 +2678,7 @@ def _layer_norm_returns_normalized_input_mean_rstd( return normalized, None, None +@_onnx_symbolic("aten::native_layer_norm") @symbolic_helper.quantized_args(True, False, False, False) @symbolic_helper.parse_args("v", "is", "v", "v", "f") @_beartype.beartype @@ -2483,6 +2688,7 @@ def native_layer_norm(g, input, normalized_shape, weight, bias, eps): ) +@_onnx_symbolic("aten::layer_norm") @symbolic_helper.quantized_args(True, False, False, False) @symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") @_beartype.beartype @@ -2493,6 +2699,7 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): return normalized +@_onnx_symbolic("aten::instance_norm") @symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") @_beartype.beartype def instance_norm( @@ -2594,6 +2801,7 @@ def instance_norm( return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) +@_onnx_symbolic("aten::unfold") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unfold(g, input, dimension, size, step): @@ -2632,6 +2840,7 @@ def unfold(g, input, dimension, size, step): ) +@_onnx_symbolic("aten::elu") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "t", "t", "t") @_beartype.beartype @@ -2648,12 +2857,14 @@ def elu(g, input, alpha, scale, input_scale): return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) +@_onnx_symbolic("aten::selu") @symbolic_helper.quantized_args(True) @_beartype.beartype def selu(g, input): return g.op("Selu", input) +@_onnx_symbolic("aten::index_select") @symbolic_helper.parse_args("v", "i", "v") @_beartype.beartype def index_select(g, self, dim, index): @@ -2663,6 +2874,7 @@ def index_select(g, self, dim, index): return symbolic_helper._select_helper(g, self, dim, index) +@_onnx_symbolic("aten::index_put") @_beartype.beartype def index_put(g, self, indices_list_value, values, accumulate): if symbolic_helper._is_packed_list(indices_list_value): @@ -2682,6 +2894,7 @@ def index_put(g, self, indices_list_value, values, accumulate): symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) +@_onnx_symbolic("aten::index_fill") @_beartype.beartype def index_fill(g, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") @@ -2705,6 +2918,7 @@ def index_fill(g, self, dim, index, value): return scatter(g, self, dim, expanded_index, expanded_value) +@_onnx_symbolic("aten::index_copy") @_beartype.beartype def index_copy(g, self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") @@ -2716,6 +2930,7 @@ def index_copy(g, self, dim, index, source): return scatter(g, self, dim, expanded_index, source) +@_onnx_symbolic("aten::bucketize") @symbolic_helper.parse_args("v", "v", "b", "b") @_beartype.beartype def bucketize(g, self, boundaries, out_int32=False, right=False): @@ -2751,6 +2966,7 @@ def bucketize(g, self, boundaries, out_int32=False, right=False): return symbolic_helper._reducesum_helper(g, 
cond_out, axes_i=[0], keepdims_i=0) +@_onnx_symbolic("aten::type_as") @_beartype.beartype def type_as(g, self, other): self_dtype = symbolic_helper._try_get_scalar_type(self) @@ -2776,6 +2992,7 @@ def type_as(g, self, other): ) +@_onnx_symbolic("aten::cosine_similarity") @symbolic_helper.parse_args("v", "v", "i", "f") @_beartype.beartype def cosine_similarity(g, x1, x2, dim, eps): @@ -2796,6 +3013,7 @@ def cosine_similarity(g, x1, x2, dim, eps): return div(g, cross, div_tens) +@_onnx_symbolic("aten::pairwise_distance") @_beartype.beartype def pairwise_distance(g, input1, input2, p, eps, keepdim): if not symbolic_helper._is_value(eps): @@ -2814,22 +3032,26 @@ def pairwise_distance(g, input1, input2, p, eps, keepdim): return pow(g, summation, inv_p) +@_onnx_symbolic("aten::clone") # ignore clone operators that are inserted by PyTorch autograd @_beartype.beartype def clone(g, input, unused_memory_format): return input +@_onnx_symbolic("aten::abs") @_beartype.beartype def abs(g, self): return g.op("Abs", self) +@_onnx_symbolic("aten::log") @_beartype.beartype def log(g, self): return g.op("Log", self) +@_onnx_symbolic("aten::log1p") @_beartype.beartype def log1p(g, self): return log( @@ -2837,12 +3059,14 @@ def log1p(g, self): ) +@_onnx_symbolic("aten::log10") @_beartype.beartype def log10(g, self): _ln10 = 2.30258509299404568401 return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) +@_onnx_symbolic("aten::pow") @_beartype.beartype def pow(g, self, exponent): f_dtype = self_dtype = self.type().scalarType() @@ -2861,6 +3085,7 @@ def pow(g, self, exponent): return pow +@_onnx_symbolic("aten::clamp") @_beartype.beartype def clamp(g, self, min, max): # min or max may be None that we need to dispatch to @@ -2871,7 +3096,7 @@ def clamp(g, self, min, max): return clamp_min(g, self, min) else: if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Clip", self, @@ -2883,11 +3108,12 @@ def clamp(g, self, min, max): return clamp_max(g, clamp_min(g, self, min), max) +@_onnx_symbolic("aten::clamp_min") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def clamp_min(g, self, min): if symbolic_helper._is_constant(min): - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 ) else: @@ -2895,14 +3121,15 @@ def clamp_min(g, self, min): min = g.op( "Cast", min, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type() ) - return op_with_optional_float_cast(g, "Max", self, min, opset_before=12) + return _op_with_optional_float_cast(g, "Max", self, min, opset_before=12) +@_onnx_symbolic("aten::clamp_max") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def clamp_max(g, self, max): if symbolic_helper._is_constant(max): - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 ) else: @@ -2910,9 +3137,10 @@ def clamp_max(g, self, max): max = g.op( "Cast", max, to_i=_type_utils.JitScalarType.from_name(dtype).onnx_type() ) - return op_with_optional_float_cast(g, "Min", self, max, opset_before=12) + return _op_with_optional_float_cast(g, "Min", self, max, opset_before=12) +@_onnx_symbolic("aten::max") # torch.max (same for torch.min) actually has two interfaces smashed together: # torch.max(x, dim, keepdim) and torch.max(x, y) # TODO(justinchuby): Support multiple quantized args 
in output @@ -2923,7 +3151,7 @@ def max(g, self, dim_or_y=None, keepdim=None): return g.op("ReduceMax", self, keepdims_i=0) # torch.max(input, other) if keepdim is None: - return op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) + return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) # torch.max(input, dim, keepdim) else: dim = symbolic_helper._get_const(dim_or_y, "i", "dim") @@ -2933,12 +3161,14 @@ def max(g, self, dim_or_y=None, keepdim=None): return max, indices +@_onnx_symbolic("aten::maximum") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def maximum(g, input, other): return max(g, input, dim_or_y=other) +@_onnx_symbolic("aten::min") # TODO(justinchuby): Support multiple quantized args in output @_beartype.beartype def min(g, self, dim_or_y=None, keepdim=None): @@ -2947,7 +3177,7 @@ def min(g, self, dim_or_y=None, keepdim=None): return g.op("ReduceMin", self, keepdims_i=0) # torch.min(input, other) if keepdim is None: - return op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) + return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) # torch.min(input, dim, keepdim) else: dim = symbolic_helper._get_const(dim_or_y, "i", "dim") @@ -2957,12 +3187,14 @@ def min(g, self, dim_or_y=None, keepdim=None): return min, indices +@_onnx_symbolic("aten::minimum") @symbolic_helper.quantized_args(True, True) @_beartype.beartype def minimum(g, input, other): return min(g, input, dim_or_y=other) +@_onnx_symbolic("aten::amax") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "i") @_beartype.beartype @@ -2970,6 +3202,7 @@ def amax(g, self, dim, keepdim): return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) +@_onnx_symbolic("aten::amin") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "is", "i") @_beartype.beartype @@ -2977,6 +3210,7 @@ def amin(g, self, dim, keepdim): return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) +@_onnx_symbolic("aten::aminmax") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype @@ -2991,11 +3225,14 @@ def aminmax(g, self, dim, keepdim): ) +@_onnx_symbolic("aten::exp") @_beartype.beartype def exp(g, self): return g.op("Exp", self) +@_onnx_symbolic("aten::dropout_") +@_onnx_symbolic("aten::dropout") @symbolic_helper.parse_args("v", "f", "i") @_beartype.beartype def dropout(g, input, p, train): @@ -3007,9 +3244,27 @@ def dropout(g, input, p, train): return r +@_onnx_symbolic( + "aten::alpha_dropout_", decorate=[_apply_params("aten::alpha_dropout_")] +) # See Note [Export inplace] +@_onnx_symbolic( + "aten::feature_alpha_dropout_", + decorate=[_apply_params("aten::feature_alpha_dropout_")], +) +@_onnx_symbolic( + "aten::feature_dropout_", decorate=[_apply_params("aten::feature_dropout_")] +) +@_onnx_symbolic( + "aten::feature_alpha_dropout", + decorate=[_apply_params("aten::feature_alpha_dropout")], +) +@_onnx_symbolic("aten::alpha_dropout", decorate=[_apply_params("aten::alpha_dropout")]) +@_onnx_symbolic( + "aten::feature_dropout", decorate=[_apply_params("aten::feature_dropout")] +) @_beartype.beartype -def _unsupported_dropout(name): - @symbolic_helper.parse_args("v", "f", "i") +def _unsupported_dropout(name: str): + @symbolic_helper.parse_args("v", "none", "b") @_beartype.beartype def feature_dropout(g, input, p, train): # NB: In inference mode, FeatureDropout is exported as an identity op. 
@@ -3020,17 +3275,7 @@ def feature_dropout(g, input, p, train): return feature_dropout -feature_dropout = _unsupported_dropout("feature_dropout") -alpha_dropout = _unsupported_dropout("alpha_dropout") -feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout") - -# See Note [Export inplace] -dropout_ = dropout -feature_dropout_ = feature_dropout -alpha_dropout_ = alpha_dropout -feature_alpha_dropout_ = feature_alpha_dropout - - +@_onnx_symbolic("aten::norm") @symbolic_helper.parse_args("v", "t", "is", "i") @_beartype.beartype def norm(g, self, p, dim, keepdim): @@ -3045,6 +3290,7 @@ def norm(g, self, p, dim, keepdim): return f(g, self, dim=dim, keepdim=keepdim) +@_onnx_symbolic("aten::conv_tbc") @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def conv_tbc(g, input, weight, bias, pad): @@ -3062,6 +3308,7 @@ def conv_tbc(g, input, weight, bias, pad): return g.op("Transpose", conv, perm_i=[2, 0, 1]) +@_onnx_symbolic("aten::_unique") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def _unique(g, input, sorted, return_inverse): @@ -3077,6 +3324,7 @@ def _unique(g, input, sorted, return_inverse): return symbolic_helper._onnx_unsupported("_unique", input) +@_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def _unique2(g, input, sorted, return_inverse, return_counts): @@ -3129,20 +3377,24 @@ def _cast_func_template(to_i, g, input, non_blocking): "BFloat16", ): func_name = f"_cast_{scalar_type}" - globals()[func_name] = symbolic_helper.parse_args("v", "i")( - functools.partial( - _cast_func_template, - _type_utils.JitScalarType.from_name(scalar_type).onnx_type(), + globals()[func_name] = _onnx_symbolic(f"aten::{func_name}")( + symbolic_helper.parse_args("v", "i")( + functools.partial( + _cast_func_template, + _type_utils.JitScalarType.from_name(scalar_type).onnx_type(), + ) ) ) +@_onnx_symbolic("aten::empty") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None): return zeros(g, sizes, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::empty_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def empty_like( @@ -3151,6 +3403,7 @@ def empty_like( return zeros_like(g, input, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::new_empty") @_beartype.beartype def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False): self_dtype = symbolic_helper._try_get_scalar_type(self) @@ -3160,6 +3413,7 @@ def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False): return empty(g, sizes, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::scalar_tensor") @_beartype.beartype def scalar_tensor(g, scalar, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -3169,6 +3423,7 @@ def scalar_tensor(g, scalar, dtype, *options): return scalar +@_onnx_symbolic("aten::tensor") @_beartype.beartype def tensor(g, data, dtype=None, device=None, requires_grad=False): dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -3195,11 +3450,13 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False): return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) +@_onnx_symbolic("aten::as_tensor") @_beartype.beartype def as_tensor(g, data, dtype=None, device=None): return tensor(g, data, dtype, device) +@_onnx_symbolic("aten::zeros") @symbolic_helper.parse_args("v", "i", "v", "v", "v") @_beartype.beartype 
def zeros(g, sizes, dtype, layout, device, pin_memory=False): @@ -3218,6 +3475,7 @@ def zeros(g, sizes, dtype, layout, device, pin_memory=False): ) +@_onnx_symbolic("aten::zeros_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def zeros_like( @@ -3235,6 +3493,7 @@ def zeros_like( ) +@_onnx_symbolic("aten::new_zeros") @_beartype.beartype def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False): self_dtype = symbolic_helper._try_get_scalar_type(self) @@ -3243,6 +3502,7 @@ def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False): return zeros(g, sizes, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::ones") @symbolic_helper.parse_args("v", "i", "v", "v", "v") @_beartype.beartype def ones(g, sizes, dtype, layout, device, pin_memory=False): @@ -3260,6 +3520,7 @@ def ones(g, sizes, dtype, layout, device, pin_memory=False): ) +@_onnx_symbolic("aten::ones_like") @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") @_beartype.beartype def ones_like( @@ -3277,6 +3538,7 @@ def ones_like( ) +@_onnx_symbolic("aten::new_ones") @_beartype.beartype def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False): self_dtype = symbolic_helper._try_get_scalar_type(self) @@ -3286,6 +3548,7 @@ def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False): return ones(g, sizes, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::full") @_beartype.beartype def full(g, sizes, value, dtype, layout, device, pin_memory=False): const_value = symbolic_helper._maybe_get_const(value, "t") @@ -3309,6 +3572,7 @@ def full(g, sizes, value, dtype, layout, device, pin_memory=False): ) +@_onnx_symbolic("aten::full_like") @_beartype.beartype def full_like( g, @@ -3339,6 +3603,7 @@ def full_like( ) +@_onnx_symbolic("aten::new_full") @_beartype.beartype def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False): self_dtype = symbolic_helper._try_get_scalar_type(self) @@ -3348,6 +3613,7 @@ def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False) return full(g, size, fill_value, dtype, layout, device, pin_memory) +@_onnx_symbolic("aten::eye") @_beartype.beartype def eye(g, *args): if len(args) == 5: @@ -3372,6 +3638,7 @@ def eye(g, *args): return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") +@_onnx_symbolic("aten::slice") @_beartype.beartype def slice(g, self, *args): if len(args) == 4: @@ -3437,15 +3704,17 @@ def slice(g, self, *args): return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") +@_onnx_symbolic("aten::hardtanh") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "f", "f") @_beartype.beartype def hardtanh(g, self: _C.Value, min_val: float, max_val: float): - return op_with_optional_float_cast( + return _op_with_optional_float_cast( g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 ) +@_onnx_symbolic("aten::hardswish") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v") @_beartype.beartype @@ -3454,6 +3723,7 @@ def hardswish(g, self): return g.op("Mul", self, hs) +@_onnx_symbolic("aten::hardsigmoid") # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) @symbolic_helper.parse_args("v") @@ -3464,12 +3734,14 @@ def hardsigmoid(g, self): return g.op("HardSigmoid", self, alpha_f=1 / 6) +@_onnx_symbolic("aten::tanhshrink") @symbolic_helper.parse_args("v") 
@_beartype.beartype def tanhshrink(g, self): return g.op("Sub", self, tanh(g, self)) +@_onnx_symbolic("aten::hardshrink") @symbolic_helper.parse_args("v", "f") @_beartype.beartype def hardshrink(g, self, lambd): @@ -3494,6 +3766,7 @@ def hardshrink(g, self, lambd): ) +@_onnx_symbolic("aten::softshrink") @symbolic_helper.parse_args("v", "f") @_beartype.beartype def softshrink(g, self, lambd): @@ -3529,11 +3802,13 @@ def softshrink(g, self, lambd): return add(g, gt_out, lt_out) +@_onnx_symbolic("aten::alias") @_beartype.beartype def alias(g, self): return self +@_onnx_symbolic("aten::unsqueeze") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def unsqueeze(g, self, dim): @@ -3560,6 +3835,7 @@ def unsqueeze(g, self, dim): return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) +@_onnx_symbolic("aten::sort") # TODO(justinchuby): Support multiple quantized args in output @symbolic_helper.parse_args("v", "i", "i", "none") @_beartype.beartype @@ -3582,12 +3858,14 @@ def sort(g, self, dim, decending, out=None): return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) +@_onnx_symbolic("aten::numel") @_beartype.beartype def numel(g, self): shape = g.op("Shape", self) return g.op("ReduceProd", shape, keepdims_i=0) +@_onnx_symbolic("aten::topk") # TODO(justinchuby): Support multiple quantized args in output @symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") @_beartype.beartype @@ -3602,6 +3880,7 @@ def topk(g, self, k, dim, largest, sorted, out=None): return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) +@_onnx_symbolic("aten::to") @_beartype.beartype def to(g, self, *args): @_beartype.beartype @@ -3678,6 +3957,7 @@ def is_aten_to_device_only(args): return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) +@_onnx_symbolic("aten::repeat") @_beartype.beartype def repeat(g, self, repeats): dtype = _type_utils.JitScalarType.INT64 @@ -3686,6 +3966,7 @@ def repeat(g, self, repeats): return g.op("Tile", self, repeats) +@_onnx_symbolic("aten::repeat_interleave") @_beartype.beartype def repeat_interleave(g, self, repeats, dim=None, output_size=None): input = self @@ -3789,6 +4070,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): return g.op("Concat", *final_splits, axis_i=dim) +@_onnx_symbolic("aten::pixel_shuffle") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def pixel_shuffle(g, self, upscale_factor): @@ -3861,6 +4143,7 @@ def pixel_shuffle(g, self, upscale_factor): ) +@_onnx_symbolic("aten::pixel_unshuffle") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def pixel_unshuffle(g, self, downscale_factor): @@ -4225,6 +4508,7 @@ def _lstm_packed( ) +@_onnx_symbolic("aten::lstm") @_beartype.beartype def lstm(g, *args): if symbolic_helper._is_tensor_list(args[3]): @@ -4233,6 +4517,7 @@ def lstm(g, *args): return _lstm_full(g, *args) +@_onnx_symbolic("aten::lstm_cell") @_beartype.beartype def lstm_cell(g, self, hidden, w_ih, w_hh, b_ih, b_hh): input = symbolic_helper._unsqueeze_helper(g, self, [0]) @@ -4260,8 +4545,14 @@ def lstm_cell(g, self, hidden, w_ih, w_hh, b_ih, b_hh): ), symbolic_helper._squeeze_helper(g, c_outs, [0]) -@_beartype.beartype -def _one_hidden_rnn(kind): +@_onnx_symbolic("aten::gru", decorate=[_apply_params("GRU"), _export("gru")]) +@_onnx_symbolic( + "aten::rnn_tanh", decorate=[_apply_params("RNN_TANH"), _export("rnn_tanh")] +) +@_onnx_symbolic( + "aten::rnn_relu", decorate=[_apply_params("RNN_RELU"), _export("rnn_relu")] +) +def _one_hidden_rnn(kind: str): @symbolic_helper.parse_args("v", 
"v", "v", "i", "i", "f", "i", "i", "i") @_beartype.beartype def _rnn_full( @@ -4292,7 +4583,6 @@ def _rnn_full( ) @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") - @_beartype.beartype def _rnn_packed( g, input, @@ -4320,7 +4610,6 @@ def _rnn_packed( batch_sizes=batch_sizes, ) - @_beartype.beartype def symbolic(g, *args): if symbolic_helper._is_tensor_list(args[3]): return _rnn_packed(g, *args) @@ -4330,11 +4619,7 @@ def symbolic(g, *args): return symbolic -gru = _one_hidden_rnn("GRU") -rnn_tanh = _one_hidden_rnn("RNN_TANH") -rnn_relu = _one_hidden_rnn("RNN_RELU") - - +@_onnx_symbolic("aten::_dim_arange") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def _dim_arange(g, like, dim): @@ -4349,12 +4634,14 @@ def _dim_arange(g, like, dim): return arange(g, stop, 4, None, None, None) +@_onnx_symbolic("aten::detach") @_beartype.beartype def detach(g, input): # Erase aten::detach nodes because ONNX is inference only return input +@_onnx_symbolic("aten::contiguous") @symbolic_helper.parse_args("v", "i") @_beartype.beartype def contiguous(g, input, memory_format): @@ -4365,6 +4652,7 @@ def contiguous(g, input, memory_format): return input +@_onnx_symbolic("aten::_pack_padded_sequence") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def _pack_padded_sequence(g, input, lengths, batch_first): @@ -4385,6 +4673,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first): return g.op("prim::PackPadded", input, lengths, outputs=2) +@_onnx_symbolic("aten::_pad_packed_sequence") @symbolic_helper.parse_args("v", "v", "i", "t", "v") @_beartype.beartype def _pad_packed_sequence( @@ -4399,6 +4688,7 @@ def _pad_packed_sequence( return data, lengths +@_onnx_symbolic("aten::randn") @_beartype.beartype def randn(g, shapes, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -4425,6 +4715,7 @@ def randn(g, shapes, dtype, *options): ) +@_onnx_symbolic("aten::rand") @_beartype.beartype def rand(g, shapes, dtype, *options): dtype = symbolic_helper._get_const(dtype, "i", "dtype") @@ -4451,6 +4742,7 @@ def rand(g, shapes, dtype, *options): ) +@_onnx_symbolic("aten::randn_like") @_beartype.beartype def randn_like( g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None @@ -4463,6 +4755,7 @@ def randn_like( return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) +@_onnx_symbolic("aten::rand_like") @_beartype.beartype def rand_like( g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None @@ -4475,6 +4768,7 @@ def rand_like( ) +@_onnx_symbolic("aten::rrelu") @symbolic_helper.parse_args("v", "f", "f", "i", "none") @_beartype.beartype def rrelu(g, input, lower, upper, training, generator): @@ -4485,6 +4779,7 @@ def rrelu(g, input, lower, upper, training, generator): return g.op("PRelu", input, p) +@_onnx_symbolic("aten::bernoulli") @_beartype.beartype def bernoulli(g, input, generator=None, out=None): if out is not None: @@ -4514,6 +4809,7 @@ def bernoulli(g, input, generator=None, out=None): ) +@_onnx_symbolic("aten::log_sigmoid") @symbolic_helper.parse_args("v") @_beartype.beartype def log_sigmoid(g, input): @@ -4521,12 +4817,14 @@ def log_sigmoid(g, input): return g.op("Log", p) +@_onnx_symbolic("aten::erf") @symbolic_helper.parse_args("v") @_beartype.beartype def erf(g, input): return g.op("Erf", input) +@_onnx_symbolic("aten::flatten") @symbolic_helper.quantized_args(True, False, False) @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype @@ -4552,6 +4850,7 @@ def flatten(g, 
input, start_dim, end_dim): return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) +@_onnx_symbolic("aten::nonzero") @symbolic_helper.parse_args("v") @_beartype.beartype def nonzero(g, input): @@ -4559,12 +4858,14 @@ def nonzero(g, input): return t(g, g.op("NonZero", input)) +@_onnx_symbolic("aten::nonzero_numpy") # Emitted from `torch.nonzero(x, as_tuple=True)` @_beartype.beartype def nonzero_numpy(g, input, _outputs=None): return unbind(g, nonzero(g, input), 1, _outputs=_outputs) +@_onnx_symbolic("aten::isnan") @symbolic_helper.parse_args("v") @_beartype.beartype def isnan(g, input): @@ -4572,6 +4873,7 @@ def isnan(g, input): return output +@_onnx_symbolic("aten::any") @_beartype.beartype def _any(g, *args): # aten::any(Tensor self) @@ -4590,6 +4892,7 @@ def _any(g, *args): return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) +@_onnx_symbolic("aten::all") @_beartype.beartype def _all(g, *args): input = g.op("Not", args[0]) @@ -4601,6 +4904,7 @@ def _all(g, *args): return g.op("Not", _any(g, input, args[1], args[2])) +@_onnx_symbolic("aten::narrow") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def narrow(g, input, dim, start, length): @@ -4609,18 +4913,21 @@ def narrow(g, input, dim, start, length): ) +@_onnx_symbolic("aten::argmax") @symbolic_helper.parse_args("v", "v", "b") @_beartype.beartype def argmax(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool): return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") +@_onnx_symbolic("aten::argmin") @symbolic_helper.parse_args("v", "v", "b") @_beartype.beartype def argmin(g, input: torch._C.Value, dim: torch._C.Value, keepdim: bool): return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") +@_onnx_symbolic("aten::scatter") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter(g, self, dim, index, src): @@ -4642,6 +4949,7 @@ def scatter(g, self, dim, index, src): return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) +@_onnx_symbolic("aten::scatter_add") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter_add(g, self, dim, index, src): @@ -4660,12 +4968,14 @@ def scatter_add(g, self, dim, index, src): return add(g, self, to_add) +@_onnx_symbolic("aten::log2") @_beartype.beartype def log2(g, self): _ln2 = 0.693147180559945309 return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) +@_onnx_symbolic("aten::is_floating_point") @_beartype.beartype def is_floating_point(g, self): if symbolic_helper._is_fp(self): @@ -4673,6 +4983,7 @@ def is_floating_point(g, self): return g.op("Constant", value_t=torch.BoolTensor([0])) +@_onnx_symbolic("aten::__is_") @_beartype.beartype def __is_(g, self, other): if symbolic_helper._is_none(other): @@ -4682,12 +4993,14 @@ def __is_(g, self, other): return eq(g, self, other) +@_onnx_symbolic("aten::__isnot_") @wrap_logical_op_with_negation @_beartype.beartype def __isnot_(g, self, other): return __is_(g, self, other) +@_onnx_symbolic("aten::one_hot") @_beartype.beartype def one_hot(g, self, num_classes): values = g.op("Constant", value_t=torch.LongTensor([0, 1])) @@ -4697,6 +5010,7 @@ def one_hot(g, self, num_classes): return g.op("OneHot", self, num_classes, values, axis_i=-1) +@_onnx_symbolic("aten::gather") @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def gather(g, self, dim, index, sparse_grad=False): @@ -4752,43 +5066,49 @@ def _var_mean(g, input, dim, correction, 
keepdim): return var, mean +@_onnx_symbolic("aten::std") @_beartype.beartype def std(g, input, *args): var, _ = var_mean(g, input, *args) return g.op("Sqrt", var) +@_onnx_symbolic("aten::var") @_beartype.beartype def var(g, input, *args): var, _ = var_mean(g, input, *args) return var -# var_mean (and all variance-related functions) has multiple signatures, so need to manually figure -# out the correct arguments: -# aten::var_mean(Tensor self, bool unbiased) -# aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False) -# aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) +@_onnx_symbolic("aten::var_mean") @_beartype.beartype def var_mean(g, input, *args): + # var_mean (and all variance-related functions) has multiple signatures, so need to manually figure + # out the correct arguments: + # aten::var_mean(Tensor self, bool unbiased) + # aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False) + # aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) if len(args) == 1: return _var_mean(g, input, None, args[0], None) else: return _var_mean(g, input, *args) +@_onnx_symbolic("aten::std_mean") @_beartype.beartype def std_mean(g, input, *args): var, mean = var_mean(g, input, *args) return g.op("Sqrt", var), mean +@_onnx_symbolic("aten::logsumexp") @symbolic_helper.parse_args("v", "is", "i") @_beartype.beartype def logsumexp(g, input, dim, keepdim): return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) +@_onnx_symbolic("aten::arange") @_beartype.beartype def arange(g, *args): if symbolic_helper.is_caffe2_aten_fallback(): @@ -4871,6 +5191,7 @@ def _float_step_convert(range_tensor): return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") +@_onnx_symbolic("aten::linspace") @_beartype.beartype def linspace(g, start, end, steps, dtype, layout, device, pin_memory): range_tensor = symbolic_helper._arange_helper(g, steps, None) @@ -4882,12 +5203,14 @@ def linspace(g, start, end, steps, dtype, layout, device, pin_memory): return add(g, mul(g, range_tensor, step), start) +@_onnx_symbolic("aten::lift") @_beartype.beartype def lift(g, self): # at::lift() is a no-op from the perspective of tracing for onnx return self +@_onnx_symbolic("aten::masked_fill") @_beartype.beartype def masked_fill(g, self, mask, value): mask = _cast_Bool(g, mask, False) # type: ignore[name-defined] @@ -4895,6 +5218,7 @@ def masked_fill(g, self, mask, value): return g.op("Where", mask, symbolic_helper._if_scalar_type_as(g, value, self), self) +@_onnx_symbolic("aten::index") @_beartype.beartype def index(g, self, index): if symbolic_helper.is_caffe2_aten_fallback(): @@ -5061,6 +5385,7 @@ def try_mask_to_index(index): return symbolic_helper._reshape_helper(g, self, final_shape) +@_onnx_symbolic("aten::linalg_norm") @symbolic_helper.parse_args("v", "v", "is", "b", "v") @_beartype.beartype def linalg_norm( @@ -5096,6 +5421,7 @@ def linalg_norm( return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) +@_onnx_symbolic("aten::linalg_vector_norm") @symbolic_helper.parse_args("v", "f", "is", "b", "v") @_beartype.beartype def linalg_vector_norm( @@ -5136,6 +5462,7 @@ def linalg_vector_norm( return result +@_onnx_symbolic("aten::linalg_matrix_norm") @symbolic_helper.parse_args("v", "v", "is", "b", "v") @_beartype.beartype def linalg_matrix_norm( @@ -5197,12 +5524,14 @@ def linalg_matrix_norm( return result +@_onnx_symbolic("aten::linalg_cross") @symbolic_helper.parse_args("v", "v", "i") 
@_beartype.beartype def linalg_cross(g, input, other, dim=-1): return cross(g, input, other, dim) +@_onnx_symbolic("aten::frobenius_norm") @symbolic_helper.parse_args("v", "is", "b") @_beartype.beartype def frobenius_norm(g, self, dim=None, keepdim=False): @@ -5211,6 +5540,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False): return g.op("Sqrt", sumsqr) +@_onnx_symbolic("aten::multinomial") @symbolic_helper.parse_args("v", "i", "b", "v") @_beartype.beartype def multinomial(g, input, num_samples, replacement=False, generator=None): @@ -5234,6 +5564,7 @@ def multinomial(g, input, num_samples, replacement=False, generator=None): ) +@_onnx_symbolic("aten::baddbmm") @_beartype.beartype def baddbmm(g, self, batch1, batch2, beta, alpha): dtype = self.type().scalarType() @@ -5253,6 +5584,7 @@ def baddbmm(g, self, batch1, batch2, beta, alpha): return add(g, mul_a, mul_b) +@_onnx_symbolic("aten::meshgrid") @symbolic_helper.parse_args("v", "s") @_beartype.beartype def meshgrid(g, tensor_list, indexing: Optional[str] = None): @@ -5285,6 +5617,7 @@ def meshgrid(g, tensor_list, indexing: Optional[str] = None): return g.op("prim::ListConstruct", *out) +@_onnx_symbolic("aten::remainder") @_beartype.beartype def remainder(g, input, other): div = _floor_divide(g, input, other) @@ -5292,6 +5625,7 @@ def remainder(g, input, other): return g.op("Sub", input, quo) +@_onnx_symbolic("aten::gelu") @symbolic_helper.parse_args("v", "s") @_beartype.beartype def gelu(g, self: torch._C.Value, approximate: str = "none"): @@ -5320,6 +5654,7 @@ def gelu(g, self: torch._C.Value, approximate: str = "none"): ) +@_onnx_symbolic("aten::group_norm") @symbolic_helper.quantized_args(True, False, False, False) @symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") @_beartype.beartype @@ -5400,6 +5735,7 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): ) +@_onnx_symbolic("aten::_weight_norm") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def _weight_norm(g, weight_v, weight_g, dim): @@ -5428,6 +5764,7 @@ def _weight_norm(g, weight_v, weight_g, dim): ) +@_onnx_symbolic("aten::dim") @_beartype.beartype def dim(g, self): """Implement the dim functionality available for a pytorch tensor in ONNX""" @@ -5436,16 +5773,19 @@ def dim(g, self): return g.op("Size", shape) +@_onnx_symbolic("aten::__getitem_") @_beartype.beartype def __getitem_(g, self, i): return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) +@_onnx_symbolic("aten::item") @_beartype.beartype def item(g, self): return self +@_onnx_symbolic("aten::take") @_beartype.beartype def take(g, self, index): self_flattened = symbolic_helper._reshape_helper( @@ -5475,6 +5815,7 @@ def _kl_div_non_log_target_impl(g, input, target): return output +@_onnx_symbolic("aten::kl_div") @symbolic_helper.parse_args("v", "v", "i", "b") @_beartype.beartype def kl_div(g, input, target, reduction, log_target): @@ -5495,6 +5836,7 @@ def kl_div(g, input, target, reduction, log_target): ) +@_onnx_symbolic("aten::as_strided") @symbolic_helper.quantized_args(True) @symbolic_helper.parse_args("v", "v", "is", "i") @_beartype.beartype @@ -5542,11 +5884,13 @@ def as_strided(g, self, sizes, strides, offset=None): return g.op("Gather", self_1d, ind) +@_onnx_symbolic("aten::__derive_index") @_beartype.beartype def __derive_index(g, index, start, step): return g.op("Add", start, g.op("Mul", index, step)) +@_onnx_symbolic("aten::__range_length") # Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp # if 
(step > 0 && lo < hi) { # push(stack, 1 + (hi - 1 - lo) / step); @@ -5562,6 +5906,7 @@ def __range_length(g, lo, hi, step): return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) +@_onnx_symbolic("aten::linear") @_beartype.beartype def linear(g, input, weight, bias): rank = symbolic_helper._get_tensor_rank(input) @@ -5578,6 +5923,7 @@ def linear(g, input, weight, bias): return output +@_onnx_symbolic("aten::hann_window") @symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") @_beartype.beartype def hann_window( @@ -5618,16 +5964,19 @@ def hann_window( return output +@_onnx_symbolic("aten::mv") @_beartype.beartype def mv(g, self, vec): return matmul(g, self, vec) +@_onnx_symbolic("aten::dot") @_beartype.beartype def dot(g, self, other): return matmul(g, self, other) +@_onnx_symbolic("aten::movedim") @symbolic_helper.parse_args("v", "t", "t") @_beartype.beartype def movedim(g, self, source, destination): @@ -5662,6 +6011,7 @@ def movedim(g, self, source, destination): return g.op("Transpose", self, perm_i=perm) +@_onnx_symbolic("aten::fill") @symbolic_helper.parse_args("v", "v") @_beartype.beartype def fill(g, self, value): @@ -5674,6 +6024,7 @@ def fill(g, self, value): return full_like(g, self, value, dtype) +@_onnx_symbolic("aten::index_add") @_beartype.beartype def index_add(g, self, dim, index, other, alpha=None): warnings.warn( @@ -5744,6 +6095,7 @@ def index_add(g, self, dim, index, other, alpha=None): return scatter_add(g, self, dim, expand_as(g, index, other), other) +@_onnx_symbolic("aten::roll") @symbolic_helper.parse_args("v", "is", "is") @_beartype.beartype def roll(g, self, shifts, dims): @@ -5765,6 +6117,7 @@ def roll(g, self, shifts, dims): return result +@_onnx_symbolic("aten::cross") @symbolic_helper.parse_args("v", "v", "i") @_beartype.beartype def cross(g, input, other, dim=None): @@ -5784,6 +6137,7 @@ def cross(g, input, other, dim=None): return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) +@_onnx_symbolic("aten::cdist") @_beartype.beartype def cdist(g, x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): # X1.shape = (B * P * D), X2.shape = (B * R * D) @@ -5801,6 +6155,7 @@ def cdist(g, x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): ) +@_onnx_symbolic("aten::lerp") @_beartype.beartype def lerp(g, self, end, weight): # Conditional for better numeric. This has been discussed in @@ -5822,6 +6177,7 @@ def lerp(g, self, end, weight): ) +@_onnx_symbolic("aten::broadcast_tensors") @_beartype.beartype def broadcast_tensors(g, self): all_tensors = symbolic_helper._unpack_list(self) @@ -5836,323 +6192,330 @@ def broadcast_tensors(g, self): return g.op("prim::ListConstruct", *t_list) +@_onnx_symbolic("aten::is_pinned") def is_pinned(g, self, device=None): # Unused by ONNX. 
return None -class Prim: - domain = "prim" +@_onnx_symbolic("prim::ConstantSplit") +@_beartype.beartype +def prim_constant_split(g, self, split_size, dim): + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "prim::ConstantSplit", "unknown dimension size", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) - @staticmethod - @_beartype.beartype - def ConstantSplit(g, self, split_size, dim): - size = symbolic_helper._get_tensor_dim_size(self, dim) - if size is None: - return symbolic_helper._unimplemented( - "prim::ConstantSplit", "unknown dimension size", self - ) - splits = [split_size] * (size // split_size) - leftover = size % split_size - if leftover: - splits.append(leftover) - return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) - - # TODO: It would be better to export this as a chunk directly, as this is - # less sensitive to changes in input size. - # TODO: Once we have proper scoping, stop reimplementing chunk, delete this - # method, and use the desugared version - @staticmethod - @_beartype.beartype - def ConstantChunk(g, self, chunks, dim): - dim_size = symbolic_helper._get_tensor_dim_size(self, dim) - if dim_size is None: - return symbolic_helper._unimplemented( - "prim::ConstantChunk", "unknown dimension size", self - ) - split_size = (dim_size + chunks - 1) // chunks - return Prim.ConstantSplit(g, self, split_size, dim) - @staticmethod - @_beartype.beartype - def shape(g, self): - return g.op("Shape", self) +# TODO: It would be better to export this as a chunk directly, as this is +# less sensitive to changes in input size. +# TODO: Once we have proper scoping, stop reimplementing chunk, delete this +# method, and use the desugared version +@_onnx_symbolic("prim::ConstantChunk") +@_beartype.beartype +def prim_constant_chunk(g, self, chunks, dim): + dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + if dim_size is None: + return symbolic_helper._unimplemented( + "prim::ConstantChunk", "unknown dimension size", self + ) + split_size = (dim_size + chunks - 1) // chunks + return prim_constant_split(g, self, split_size, dim) - @staticmethod - @_beartype.beartype - def max(g, self, other): - return op_with_optional_float_cast(g, "Max", self, other, opset_before=12) - @staticmethod - @_beartype.beartype - def min(g, self, other=None): - if not other: - if symbolic_helper._is_packed_list(self): - self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) - return min(g, self) - return min(g, self, other) - - @staticmethod - @_beartype.beartype - def data(g, self): - return self +@_onnx_symbolic("prim::shape") +@_beartype.beartype +def prim_shape(g, self): + return g.op("Shape", self) - @staticmethod - def layout(g, self): - # Unused by ONNX. 
- return None - @staticmethod - @_beartype.beartype - def ListConstruct(g, *inputs, **kwargs): - return None +@_onnx_symbolic("prim::max") +@_beartype.beartype +def prim_max(g, self, other): + return _op_with_optional_float_cast(g, "Max", self, other, opset_before=12) - @staticmethod - @_beartype.beartype - def ListUnpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]: - if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": - # Cancel the previous node if it is ListConstruct by returning its inputs - # TODO(justinchuby): Use a public method in the helper module - return symbolic_helper._unpack_list(inputs[0]) - return None +@_onnx_symbolic("prim::min") +@_beartype.beartype +def prim_min(g, self, other=None): + if not other: + if symbolic_helper._is_packed_list(self): + self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) + return min(g, self) + return min(g, self, other) - @staticmethod - @_beartype.beartype - def TupleConstruct(g, *inputs, **kwargs): - return None - @staticmethod - @_beartype.beartype - def Uninitialized(g, *inputs, **kwargs): - return None +@_onnx_symbolic("prim::data") +@_beartype.beartype +def prim_data(g, self): + return self - # exists to refine the type of the Value - # if x is an optional Tensor, unchecked_cast will cast - # x to Tensor, so the rest of the graph knows that x is a Tensor - # this doesn't do anything in runtime and is a noop in ONNX - @staticmethod - @_beartype.beartype - def unchecked_cast(g, self): - return self - @staticmethod - @_beartype.beartype - def dtype(g, self): - scalar_name = symbolic_helper._try_get_scalar_type(self) - if scalar_name is None: - scalar_name = "Float" - scalar_type = _type_utils.JitScalarType.from_name(scalar_name) - # This node records a torch dtype as int - return g.op("Constant", value_t=torch.tensor(scalar_type)) - - @staticmethod - @_beartype.beartype - def tolist(g, input, dim_val, elem_ty_val): - """tolist is currently supported only for 1D input tensors. - - dim_val and elem_ty_val represent dimension and type annotations - that need to match dimension and type of the input tensor. - """ - dim = symbolic_helper._maybe_get_const(dim_val, "i") - if dim > 1: - return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) - return input +@_onnx_symbolic("prim::layout") +def prim_layout(g, self): + # Unused by ONNX. 
+ return None - # ----------------------------------------------------------------------------- - # Symbolic functions that need extra context - # ----------------------------------------------------------------------------- - @staticmethod - @_beartype.beartype - def device(ctx: SymbolicContext, g: _C.Graph, *inputs, **kwargs) -> None: - output_type = ctx.cur_node.output().type() - if isinstance(output_type, _C.DeviceObjType): - return None - return symbolic_helper._unimplemented( - "prim::device", - f"output type should be 'DeviceObjType', not '{output_type.kind()}'", - ctx.cur_node.output(), - ) +@_onnx_symbolic("prim::ListConstruct") +@_beartype.beartype +def prim_list_construct(g, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::ListUnpack") +@_beartype.beartype +def prim_list_unpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]: + if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": + # Cancel the previous node if it is ListConstruct by returning its inputs + # TODO(justinchuby): Use a public method in the helper module + return symbolic_helper._unpack_list(inputs[0]) + + return None + + +@_onnx_symbolic("prim::TupleConstruct") +@_beartype.beartype +def prim_tuple_construct(g, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::Uninitialized") +@_beartype.beartype +def prim_uninitialized(g, *inputs, **kwargs): + return None - @staticmethod - @_beartype.beartype - def Loop(ctx: SymbolicContext, g, *inputs, **attrs): - n = ctx.cur_node - env = ctx.env - params_dict = ctx.params_dict - operator_export_type = GLOBALS.operator_export_type - opset_version = GLOBALS.export_onnx_opset_version +# exists to refine the type of the Value +# if x is an optional Tensor, unchecked_cast will cast +# x to Tensor, so the rest of the graph knows that x is a Tensor +# this doesn't do anything in runtime and is a noop in ONNX +@_onnx_symbolic("prim::unchecked_cast") +@_beartype.beartype +def prim_unchecked_cast(g, self): + return self + + +@_onnx_symbolic("prim::dtype") +@_beartype.beartype +def prim_dtype(g, self): + scalar_name = symbolic_helper._try_get_scalar_type(self) + if scalar_name is None: + scalar_name = "Float" + scalar_type = _type_utils.JitScalarType.from_name(scalar_name) + # This node records a torch dtype as int + return g.op("Constant", value_t=torch.tensor(scalar_type)) + + +@_onnx_symbolic("prim::tolist") +@_beartype.beartype +def prim_tolist(g, input, dim_val, elem_ty_val): + """tolist is currently supported only for 1D input tensors. + + dim_val and elem_ty_val represent dimension and type annotations + that need to match dimension and type of the input tensor. 
+ """ + dim = symbolic_helper._maybe_get_const(dim_val, "i") + if dim > 1: + return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) + return input + + +# ----------------------------------------------------------------------------- +# Symbolic functions that need extra context +# ----------------------------------------------------------------------------- +@_onnx_symbolic("prim::device") +@_beartype.beartype +def prim_device(ctx: SymbolicContext, g: _C.Graph, *inputs, **kwargs) -> None: + output_type = ctx.cur_node.output().type() + if isinstance(output_type, _C.DeviceObjType): + return None + + return symbolic_helper._unimplemented( + "prim::device", + f"output type should be 'DeviceObjType', not '{output_type.kind()}'", + ctx.cur_node.output(), + ) - new_op_outputs = g.op("Loop", *inputs, outputs=n.outputsSize()) + +@_onnx_symbolic("prim::Loop") +@_beartype.beartype +def prim_loop(ctx: SymbolicContext, g, *inputs, **attrs): + n = ctx.cur_node + env = ctx.env + params_dict = ctx.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + new_op_outputs = g.op("Loop", *inputs, outputs=n.outputsSize()) + new_node = ( + new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node() + ) + for b in n.blocks(): + new_block = new_node.addBlock() + # Copy input metadata to subblock + # + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. + for i, b_in in enumerate(b.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) + # For optional block inputs, they may switch between None not-None inside + # the loop body, so if the loop input is not optional, the block input may + # still need to be optional. + if ( + i > 0 + and (i + 1) < len(inputs) + and not isinstance(b_in.type(), _C.OptionalType) + ): + b_in.setType(inputs[i + 1].type()) + torch._C._jit_pass_onnx_block( + b, new_block, operator_export_type, env, False # type:ignore[arg-type] + ) + new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for Loop after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return new_op_outputs + + +@_onnx_symbolic("prim::If") +@_beartype.beartype +def prim_if(ctx: SymbolicContext, g, *inputs, **attrs): + n = ctx.cur_node + block = ctx.onnx_block + env = ctx.env + params_dict = ctx.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + static_if = inputs[0].node().kind() == "onnx::Constant" + if static_if: + # Fold static if + # + # The torch IR + # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), + # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... 
+ # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %21 : Long(device=cpu) = aten::eq(%20, %64) + # %22 : Long(device=cpu) = prim::If(%21) + # block0(): + # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) + # -> (%23) + # block1(): + # -> (%65) + # %input.53 : Tensor, %weight : Tensor = prim::If(%22) + # block0(): + # -> (%embedding_matrix.1, %input.1) + # block1(): + # -> (%input.1, %embedding_matrix.1) + # %26 : int[] = aten::size(%input.53) + # + # The converted ONNX graph + # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() + # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) + # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() + # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) + input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() + const_value = ( + all(input_flag) if isinstance(input_flag, list) else bool(input_flag) + ) + block_idx = 0 if const_value else 1 + current_b = list(n.blocks())[block_idx] + env = torch._C._jit_pass_onnx_block( + current_b, + block, + operator_export_type, # type:ignore[arg-type] + env, # type:ignore[arg-type] + True, + ) + if_output_list = list(n.outputs()) + current_b_list = list(current_b.outputs()) + + final_b_list = [] + for idx in range(len(if_output_list)): + if current_b_list[idx] not in env: + raise errors.SymbolicValueError( + f"The sub block ATen output {current_b_list[idx]} is not in env.", + current_b_list[idx], + ) # type:ignore[operator] + onnx_b = env[current_b_list[idx]] + final_b_list.append(onnx_b) + return final_b_list + else: + new_op_outputs = g.op("If", *inputs, outputs=n.outputsSize()) new_node = ( new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node() ) for b in n.blocks(): new_block = new_node.addBlock() - # Copy input metadata to subblock - # - # prim::Loop(iter, cond, input_1, ..., input_n) - # block0(iter, input_1, ..., input_n) - # - # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. - for i, b_in in enumerate(b.inputs()): - if i == 0 and i < len(inputs): - b_in.setType(inputs[i].type()) - # For optional block inputs, they may switch between None not-None inside - # the loop body, so if the loop input is not optional, the block input may - # still need to be optional. - if ( - i > 0 - and (i + 1) < len(inputs) - and not isinstance(b_in.type(), _C.OptionalType) - ): - b_in.setType(inputs[i + 1].type()) torch._C._jit_pass_onnx_block( - b, new_block, operator_export_type, env, False # type:ignore[arg-type] + b, + new_block, + operator_export_type, # type:ignore[arg-type] + env, + False, ) new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( new_node, opset_version ) - # Run shape type inference for Loop after subblock is converted. + # Run shape type inference for If after subblock is converted. 
if GLOBALS.onnx_shape_inference: torch._C._jit_pass_onnx_node_shape_type_inference( new_node, params_dict, opset_version ) return new_op_outputs - @staticmethod - @_beartype.beartype - def If(ctx: SymbolicContext, g, *inputs, **attrs): - n = ctx.cur_node - block = ctx.onnx_block - env = ctx.env - params_dict = ctx.params_dict - - operator_export_type = GLOBALS.operator_export_type - opset_version = GLOBALS.export_onnx_opset_version - - static_if = inputs[0].node().kind() == "onnx::Constant" - if static_if: - # Fold static if - # - # The torch IR - # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), - # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... - # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() - # %21 : Long(device=cpu) = aten::eq(%20, %64) - # %22 : Long(device=cpu) = prim::If(%21) - # block0(): - # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) - # -> (%23) - # block1(): - # -> (%65) - # %input.53 : Tensor, %weight : Tensor = prim::If(%22) - # block0(): - # -> (%embedding_matrix.1, %input.1) - # block1(): - # -> (%input.1, %embedding_matrix.1) - # %26 : int[] = aten::size(%input.53) - # - # The converted ONNX graph - # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() - # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) - # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() - # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) - input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() - const_value = ( - all(input_flag) if isinstance(input_flag, list) else bool(input_flag) - ) - block_idx = 0 if const_value else 1 - current_b = list(n.blocks())[block_idx] - env = torch._C._jit_pass_onnx_block( - current_b, - block, - operator_export_type, # type:ignore[arg-type] - env, # type:ignore[arg-type] - True, - ) - if_output_list = list(n.outputs()) - current_b_list = list(current_b.outputs()) - - final_b_list = [] - for idx in range(len(if_output_list)): - if current_b_list[idx] not in env: - raise errors.SymbolicValueError( - f"The sub block ATen output {current_b_list[idx]} is not in env.", - current_b_list[idx], - ) # type:ignore[operator] - onnx_b = env[current_b_list[idx]] - final_b_list.append(onnx_b) - return final_b_list - else: - new_op_outputs = g.op("If", *inputs, outputs=n.outputsSize()) - new_node = ( - new_op_outputs[0].node() - if n.outputsSize() > 1 - else new_op_outputs.node() - ) - for b in n.blocks(): - new_block = new_node.addBlock() - torch._C._jit_pass_onnx_block( - b, - new_block, - operator_export_type, # type:ignore[arg-type] - env, - False, - ) - new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( - new_node, opset_version - ) - # Run shape type inference for If after subblock is converted. - if GLOBALS.onnx_shape_inference: - torch._C._jit_pass_onnx_node_shape_type_inference( - new_node, params_dict, opset_version - ) - return new_op_outputs - @staticmethod - @_beartype.beartype - def Constant(ctx: SymbolicContext, g, *inputs, **attrs): - n = ctx.cur_node - - if n.mustBeNone(): - return None - # This must go before checking for string values, because some device constants - # have string values, but we want to keep them as unconverted Device types so - # that eq() can work on them. 
- if isinstance(n.output().type(), _C.DeviceObjType): - return None - if n.kindOf("value") == "t": - return g.op("Constant", value_t=symbolic_helper._node_get(n, "value")) - if n.kindOf("value") == "s": - return g.op("Constant", value_s=symbolic_helper._node_get(n, "value")) - if n.output().type().isSubtypeOf( - _C.ListType.ofInts() - ) or n.output().type().isSubtypeOf(_C.ListType.ofFloats()): - return g.op( - "Constant", value_t=torch.tensor(symbolic_helper._node_get(n, "value")) - ) +@_onnx_symbolic("prim::Constant") +@_beartype.beartype +def prim_constant(ctx: SymbolicContext, g, *inputs, **attrs): + n = ctx.cur_node - raise errors.SymbolicValueError( - f"Unsupported prim::Constant kind: `{n.kindOf('value')}`. " - f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", - n.output(), + if n.mustBeNone(): + return None + # This must go before checking for string values, because some device constants + # have string values, but we want to keep them as unconverted Device types so + # that eq() can work on them. + if isinstance(n.output().type(), _C.DeviceObjType): + return None + if n.kindOf("value") == "t": + return g.op("Constant", value_t=symbolic_helper._node_get(n, "value")) + if n.kindOf("value") == "s": + return g.op("Constant", value_s=symbolic_helper._node_get(n, "value")) + if n.output().type().isSubtypeOf( + _C.ListType.ofInts() + ) or n.output().type().isSubtypeOf(_C.ListType.ofFloats()): + return g.op( + "Constant", value_t=torch.tensor(symbolic_helper._node_get(n, "value")) ) + raise errors.SymbolicValueError( + f"Unsupported prim::Constant kind: `{n.kindOf('value')}`. " + f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + n.output(), + ) -class Onnx: - domain = "onnx" - # ----------------------------------------------------------------------------- - # Symbolic functions that need extra context - # ----------------------------------------------------------------------------- - @staticmethod - @_beartype.beartype - def Placeholder(ctx: SymbolicContext, g, *inputs, **attrs): - n = ctx.cur_node - block = ctx.onnx_block - env = ctx.env +@_onnx_symbolic("onnx::Placeholder") +@_beartype.beartype +def onnx_placeholder(ctx: SymbolicContext, g, *inputs, **attrs): + n = ctx.cur_node + block = ctx.onnx_block + env = ctx.env - return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env) + return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env)
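
The patch above replaces import-time opset discovery and the `Prim`/`Onnx` classes with decorator-based registration (`@_onnx_symbolic(...)`, optionally with `decorate=[_apply_params(...)]` to register the output of a factory such as `_one_hidden_rnn`). The following is a minimal, self-contained sketch of that pattern, not the actual `torch.onnx._internal.registration` implementation: `onnx_symbolic`, `apply_params`, `_REGISTRY`, and `_one_hidden_rnn_sketch` are hypothetical stand-ins for the diff's `_onnx_symbolic`, `_apply_params`, the internal registry, and `_one_hidden_rnn`, and their internals are assumptions for illustration only.

# Illustrative sketch (assumed names/internals); mirrors the decorator-based
# registration pattern used in the diff, not PyTorch's real registry code.
from typing import Callable, Dict, Sequence

_REGISTRY: Dict[str, Callable] = {}


def onnx_symbolic(name: str, decorate: Sequence[Callable] = ()) -> Callable:
    """Register a symbolic function (optionally produced via `decorate`) under `name`."""

    def wrapper(func: Callable) -> Callable:
        decorated = func
        for deco in decorate:
            decorated = deco(decorated)
        _REGISTRY[name] = decorated
        # Return the original function so stacked @onnx_symbolic decorators
        # each see the undecorated factory, as the diff's pattern requires.
        return func

    return wrapper


def apply_params(*args, **kwargs) -> Callable:
    """Call a factory with fixed arguments, analogous to `_apply_params` in the diff."""

    def wrapper(factory: Callable) -> Callable:
        return factory(*args, **kwargs)

    return wrapper


# A factory producing a symbolic function, analogous to `_one_hidden_rnn(kind)`:
# each stacked decorator registers factory(kind) under a different aten name.
@onnx_symbolic("aten::rnn_tanh", decorate=[apply_params("RNN_TANH")])
@onnx_symbolic("aten::rnn_relu", decorate=[apply_params("RNN_RELU")])
def _one_hidden_rnn_sketch(kind: str) -> Callable:
    def symbolic(g, *args):
        return f"{kind} symbolic called with {len(args)} args"

    return symbolic


if __name__ == "__main__":
    # The registry holds the *produced* symbolic functions, not the factory itself.
    print(sorted(_REGISTRY))                        # ['aten::rnn_relu', 'aten::rnn_tanh']
    print(_REGISTRY["aten::rnn_tanh"](None, 1, 2))  # RNN_TANH symbolic called with 2 args

The key design point the sketch tries to capture is that registration happens when the module defining the symbolic functions is imported and decorated, so the removed `discover_and_register_all_symbolic_opsets()` scan is no longer needed, and parameterized variants (GRU/RNN_TANH/RNN_RELU, the dropout family, the `_cast_*` helpers) can be registered from a single factory instead of module-level assignments.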