Skip to content

Commit

Permalink
[ONNX] Use decorators for symbolic function registration (pytorch#84448)
Browse files Browse the repository at this point in the history
This is the 4th PR in the series of pytorch#83787. It enables the use of `@onnx_symbolic` across `torch.onnx`.

- **Backward breaking**: Removed some symbolic functions from `__all__` because of the use of  `@onnx_symbolic` for registering the same function on multiple aten names.
- Decorate all symbolic functions with `@onnx_symbolic`
- Move Quantized and Prim ops out from classes to functions defined in the modules. Eliminate the need for `isfunction` checking, speeding up the registration process by 60%.
    - Remove the outdated unit test `test_symbolic_opset9.py`
- Symbolic function registration moved from the first call to `_run_symbolic_function` to init time.
- Registration is fast:
  ![image](https://user-images.githubusercontent.com/11205048/189164959-f3fca173-19bc-4682-b150-f13a586387bf.png)

Pull Request resolved: pytorch#84448
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 22, 2022
1 parent c7c2578 commit 76d6077
Show file tree
Hide file tree
Showing 16 changed files with 1,533 additions and 969 deletions.
32 changes: 0 additions & 32 deletions test/onnx/symbolic_opsets/test_symbolic_opset9.py

This file was deleted.

18 changes: 18 additions & 0 deletions torch/onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,21 @@ Torch->ONNX converter / exporter.
[User-facing docs](https://pytorch.org/docs/master/onnx.html).

[Developer docs](https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter).

## Symbolic functions Opsets

Opset 9 is the base version. It is selected as the base version because

1. It is the first opset version supported by PyTorch export.
2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
we chose to handle them as special cases separately.

Backward support for opset versions beyond opset 7 is not in our roadmap.

For opset versions other than 9, by default they will inherit the symbolic functions defined in
symbolic_opset9.py.

To extend support for updated operators in different opset versions on top of opset 9,
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
6 changes: 2 additions & 4 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""ONNX exporter."""
import warnings

from torch import _C
from torch._C import _onnx as _C_onnx
Expand All @@ -26,6 +25,7 @@
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
utils,
)
from ._exporter_states import ExportTypes, SymbolicContext
Expand Down Expand Up @@ -60,6 +60,7 @@
"symbolic_opset14",
"symbolic_opset15",
"symbolic_opset16",
"symbolic_opset17",
# Enums
"ExportTypes",
"OperatorExportTypes",
Expand Down Expand Up @@ -133,6 +134,3 @@ def log(*args) -> None:
character appended to the end, and flushed to output stream.
"""
_C._jit_onnx_log(*args)


_registration.discover_and_register_all_symbolic_opsets()
45 changes: 0 additions & 45 deletions torch/onnx/_internal/registration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Module for handling symbolic function registration."""

import importlib
import inspect
import warnings
from typing import (
Callable,
Expand Down Expand Up @@ -265,49 +263,6 @@ def all_functions(self) -> Set[str]:
return set(self._registry)


def discover_and_register_all_symbolic_opsets() -> None:
"""Discover all symbolic functions.
Opset 9 is the base version. It is selected as the base version because
1. It is the first opset version supported by PyTorch export.
2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
we chose to handle them as special cases separately.
Backward support for opset versions beyond opset 7 is not in our roadmap.
For opset versions other than 9, by default they will inherit the symbolic functions defined in
symbolic_opset9.py.
To extend support for updated operators in different opset versions on top of opset 9,
simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
"""
for opset in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset}")
_register_module(module, opset)


def _register_module(module, opset: OpsetVersion) -> None:
"""Registers all functions in the given module.
Args:
module: The module to register.
opset: The opset version to register.
"""
global registry
members = inspect.getmembers(module)
for name, obj in members:
if isinstance(obj, type) and hasattr(obj, "domain"):
# Symbolic functions in domains other than aten
ops = inspect.getmembers(obj, predicate=inspect.isfunction)
for op in ops:
registry.register(f"{obj.domain}::{op[0]}", opset, op[1]) # type: ignore[attr-defined]

elif inspect.isfunction(obj):
if name in {"_len", "_list", "_any", "_all"}:
name = name[1:]
registry.register(f"aten::{name}", opset, obj)


@_beartype.beartype
def onnx_symbolic(
name: str,
Expand Down
6 changes: 3 additions & 3 deletions torch/onnx/symbolic_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def upsample_nearest2d(
g, input, output_size, align_corners=None, scales_h=None, scales_w=None
):
if input not in symbolic_helper._quantized_ops:
return opset9.upsample_nearest2d(g, input, output_size, align_corners)
return opset9.upsample_nearest2d(g, input, output_size, align_corners) # type: ignore[attr-defined]

output_size = symbolic_helper._parse_arg(output_size, "is")
kwargs = {
Expand All @@ -194,7 +194,7 @@ def upsample_nearest2d(
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if input not in symbolic_helper._quantized_ops:
return opset9.max_pool2d(
return opset9.max_pool2d( # type: ignore[attr-defined]
g, input, kernel_size, stride, padding, dilation, ceil_mode
)
kwargs = {
Expand Down Expand Up @@ -224,7 +224,7 @@ def avg_pool2d(
divisor_override=None,
):
if input not in symbolic_helper._quantized_ops:
return opset9.avg_pool2d(
return opset9.avg_pool2d( # type: ignore[attr-defined]
g,
input,
kernel_size,
Expand Down
Loading

0 comments on commit 76d6077

Please sign in to comment.