Skip to content

Commit

Permalink
[ONNX] Clean up patch functions (pytorch#83136)
Browse files Browse the repository at this point in the history
Changes:

- Move namespace handling from `_new_node` to `_graph_op` for clarity
- Always require the `aten` namespace when creating aten ops. Remove the `aten` argument supplied in `_aten_op` for clarity
- Rename the `_ATTR_PATTERN` global
- Improve types
- Update `_add_attribute` to raise ValueErrors
Pull Request resolved: pytorch#83136
Approved by: https://github.com/BowenBao
  • Loading branch information
justinchuby authored and pytorchmergebot committed Aug 26, 2022
1 parent ec5b83f commit 681c387
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 44 deletions.
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ class Block:
def returnNode(self) -> Node: ...
def owningNode(self) -> Node: ...
def registerOutput(self, n: Value) -> _int: ...
def addNode(self, name: str, values: Sequence[Value]) -> Node: ...
def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ...
...

# Defined in torch/csrc/jit/ir/ir.h
Expand Down
102 changes: 59 additions & 43 deletions torch/onnx/_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch.onnx import _deprecation

# Import utils to get _params_dict because it is a global that is accessed by c++ code
from torch.onnx import _deprecation, utils
from torch.onnx._globals import GLOBALS

_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")


# TODO(#78694): Refactor the patching process to make it more transparent to users.
def _graph_op(
g: torch._C.Graph,
g: _C.Graph,
opname: str,
*raw_args: torch._C.Value,
*raw_args: _C.Value,
outputs: int = 1,
**kwargs,
) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
) -> Union[_C.Value, Tuple[_C.Value, ...]]:
r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
The set of operators and the inputs/attributes they take
Expand All @@ -27,7 +31,8 @@ def _graph_op(
Args:
g: The Torch graph.
opname: The ONNX operator name, e.g., `Abs` or `Add`. TODO(justinchu): Update examples to correct ones.
opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
with a namespace, e.g., `aten::add`.
raw_args: The inputs to the operator; usually provided
as arguments to the `symbolic` definition.
outputs: The number of outputs this operator returns.
Expand All @@ -51,22 +56,18 @@ def _graph_op(
# now they can pass through None attributes, and have them not show up
kwargs = {k: v for k, v in kwargs.items() if v is not None}

def const_if_tensor(arg):
if arg is None:
return arg
elif isinstance(arg, torch._C.Value):
return arg
else:
return g.op("Constant", value_z=arg) # type: ignore[attr-defined]
args = [_const_if_tensor(g, arg) for arg in raw_args]

args = [const_if_tensor(arg) for arg in raw_args]
n = g.insertNode(_new_node(g, opname, outputs, *args, **kwargs)) # type: ignore[attr-defined]
if "::" in opname:
namespace, op = opname.split("::")
else:
namespace = "onnx"
op = opname

# Import utils to get _params_dict because it is a global that is accessed by c++ code
from torch.onnx import utils
n = g.insertNode(_new_node(g, namespace, op, outputs, *args, **kwargs))

if GLOBALS.onnx_shape_inference:
torch._C._jit_pass_onnx_node_shape_type_inference(
_C._jit_pass_onnx_node_shape_type_inference(
n, utils._params_dict, GLOBALS.export_onnx_opset_version
)

Expand All @@ -75,11 +76,23 @@ def const_if_tensor(arg):
return tuple(n.outputs())


def _const_if_tensor(g: _C.Graph, arg):
if arg is None:
return arg
if isinstance(arg, _C.Value):
return arg
return _graph_op(g, "Constant", value_z=arg)


# Generate an ONNX ATen op node.
def _aten_op(g, operator: str, *args, overload_name: str = "", **kwargs):
kwargs["aten"] = True
return g.op(
"ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs
def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
    """Create an `aten::ATen` fallback node that invokes an ATen operator.

    The target operator and its overload are recorded on the node as the
    string attributes `operator_s` and `overload_name_s`.
    """
    attributes = dict(operator_s=operator, overload_name_s=overload_name, **kwargs)
    return _graph_op(g, "aten::ATen", *args, **attributes)


Expand All @@ -91,35 +104,38 @@ def _block_op(b: _C.Block, opname: str, *args, **kwargs):
aten = kwargs.pop("aten", False)
ns = "aten" if aten else "onnx"
ns_opname = ns + "::" + opname
n = b.addNode(ns_opname, list(args))
n = b.addNode(ns_opname, args)
for k, v in sorted(kwargs.items()):
# TODO: enable inplace in aten exporting mode.
if k == "inplace":
continue
_add_attribute(n, k, v, aten=aten)
if len(list(n.outputs())) == 1:
outputs = tuple(n.outputs())
if len(outputs) == 1:
return n.output()
return tuple(o for o in n.outputs())
return outputs


def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs):
if "::" in opname:
aten = False
ns_opname = opname
else:
aten = kwargs.pop("aten", False)
ns = "aten" if aten else "onnx"
ns_opname = ns + "::" + opname
n = g.create(ns_opname, args, outputs) # type: ignore[attr-defined]
def _new_node(
g: _C.Graph, namespace: str, op: str, outputs: int, *args, **kwargs
) -> _C.Node:
"""Creates a new node in the graph.
Args:
g: The graph to create the operator on.
namespace: The namespace of the operator. E.g., "aten", "onnx".
op: The name of the operator to create.
outputs: The number of the outputs of the node.
Returns:
The new node.
"""
aten = kwargs.pop("aten", False)
node = g.create(f"{namespace}::{op}", args, outputs)
for k, v in sorted(kwargs.items()):
# TODO: enable inplace in aten exporting mode.
if k == "inplace":
continue
_add_attribute(n, k, v, aten=aten)
return n


_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")
_add_attribute(node, k, v, aten=aten)
return node


def _is_onnx_list(value):
Expand All @@ -145,9 +161,9 @@ def _is_caffe2_aten_fallback():

def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
r"""Initializes the right attribute based on type of value."""
m = _attr_pattern.match(key)
m = _ATTR_PATTERN.match(key)
if m is None:
raise IndexError(
raise ValueError(
f"Invalid attribute specifier '{key}' names "
" must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
)
Expand All @@ -165,7 +181,7 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
kind = "f"
else:
kind = "i"
return getattr(node, kind + "_")(name, value)
return getattr(node, f"{kind}_")(name, value)


# TODO(#76254): Remove the deprecated function.
Expand Down

0 comments on commit 681c387

Please sign in to comment.