Commit

Fixes a missing attribute in function implementation of CastLike (onnx#5246)

### Description
The function implementation of CastLike does not propagate the
*saturate* attribute. This PR fixes it and changes the reference
implementation so that the function implementation is exercised and
verified. This PR fixes two issues:

1. the function implementation of CastLike (19)
2. the reference implementation of context dependent functions

By default, the reference implementation picks a kernel if one exists,
even if the operator also defines a function. The PR introduces a
mechanism to bypass the kernel and expand the operator into its
function. When the function is context dependent, the expansion is
delayed until the node is executed and its inputs are known.
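A minimal sketch of the new mechanism, mirroring the unit test added in
this PR (`model` stands for any ONNX model containing a CastLike node and
`data` for its input; both are placeholders):

```python
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRunExpand

# Subclassing OpRunExpand tells the evaluator to skip the registered
# CastLike kernel and expand the node into the function its schema defines.
class CastLike(OpRunExpand):
    op_domain = ""

ref = ReferenceEvaluator(model, new_ops=[CastLike])
got = ref.run(None, {"X": data})
```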

---------

Signed-off-by: xadupre <[email protected]>
xadupre authored Jun 19, 2023
1 parent 925840b commit e7e8aa7
Showing 23 changed files with 180 additions and 25 deletions.
18 binary files changed (not shown).
4 changes: 3 additions & 1 deletion onnx/defs/tensor/defs.cc
@@ -256,7 +256,9 @@ ONNX_OPERATOR_SET_SCHEMA(
}
auto target_elt_type = target_type->tensor_type().elem_type();
FunctionBuilder builder(functionProto);
builder.Add("output = Cast (input)", "to", (int64_t)(target_elt_type));
builder.Add(
MakeString("output = Cast <to= ", (int64_t)(target_elt_type), ", saturate: int = @saturate> (input)")
.c_str());
schema.BuildFunction(functionProto);
return true;
}));
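The generated body binds the inner Cast node's `saturate` attribute to the
calling CastLike node's attribute of the same name through an attribute
reference. A hedged Python equivalent of what the builder emits (the `to`
value shown is illustrative; the real one is the target element type):

```python
from onnx import AttributeProto, TensorProto, helper

# A single Cast node whose `saturate` attribute is a reference
# (`ref_attr_name`) to the `saturate` attribute of the calling function.
cast = helper.make_node("Cast", ["input"], ["output"], to=TensorProto.FLOAT8E4M3FN)
ref_attr = AttributeProto()
ref_attr.name = "saturate"
ref_attr.ref_attr_name = "saturate"  # resolved against the caller's attributes
ref_attr.type = AttributeProto.INT
cast.attribute.append(ref_attr)
```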
87 changes: 78 additions & 9 deletions onnx/reference/op_run.py
@@ -8,7 +8,7 @@

from onnx import TensorProto
from onnx.defs import get_all_schemas_with_history, get_schema, onnx_opset_version
from onnx.helper import make_node
from onnx.helper import make_node, make_tensor_type_proto, np_dtype_to_tensor_dtype
from onnx.numpy_helper import to_array
from onnx.onnx_pb import AttributeProto, GraphProto, NodeProto, TypeProto
from onnx.reference.custom_element_types import (
@@ -588,12 +588,36 @@ def eval(
return res


class OpRunExpand(OpRun):
"""
Class any operator to avoid must inherit from.
"""

def __init__(
self, onnx_node: NodeProto, log_function: Any, impl: Any = None
): # pylint: disable=super-init-not-called
raise RuntimeError(
f"The reference implementation must not use this node ({type(self)})."
)

def _run(self, *inputs, **kwargs):
raise RuntimeError(
f"The reference implementation must not use this node ({type(self)})."
)


class OpFunction(OpRun):
"""
Runs a custom function.
"""

def __init__(self, onnx_node: NodeProto, log_function: Any, impl: Any = None):
def __init__(
self,
onnx_node: NodeProto,
log_function: Any,
impl: Any = None,
attributes: Optional[Dict[str, Any]] = None,
):
if impl is None:
raise RuntimeError(
f"impl cannot be None for node type {onnx_node.op_type!r} "
@@ -604,24 +628,69 @@ def __init__(self, onnx_node: NodeProto, log_function: Any, impl: Any = None):
# The function implementation is the same whenever the function is called
# but the attributes may be different at every call.
self.attributes_ = {
name: getattr(self, name) for name in self.impl_.attributes_
name: getattr(self, name)
for name in getattr(self.impl_, "attributes_", attributes) # type: ignore[union-attr]
}

def _run(self, *inputs, **kwargs): # type: ignore # pylint: disable=W0221
if len(self.impl_.input_names) != len(inputs):
return self._run_impl(self.impl_, *inputs, **kwargs)

def _run_impl(self, impl, *inputs, **kwargs): # type: ignore # pylint: disable=W0221
if len(impl.input_names) != len(inputs):
raise RuntimeError(
f"Mismatch lengths between the number of inputs {len(inputs)} "
f"and the expected number of inputs {len(self.impl_.inputs)} "
f"and the expected number of inputs {len(impl.inputs)} "
f"for node {self.op_type!r} from domain {self.domain!r}."
)
feeds = dict(zip(self.impl_.input_names, inputs))
feeds = dict(zip(impl.input_names, inputs))
attributes = self.attributes_.copy()
attributes.update(kwargs)
results = self.impl_.run(None, feeds, attributes=attributes)
if len(self.impl_.output_names) != len(results):
results = impl.run(None, feeds, attributes=attributes)
if len(impl.output_names) != len(results):
raise RuntimeError(
f"Mismatch lengths between the number of outputs {len(results)} "
f"and the expected number of outputs {len(self.impl_.output_names)} "
f"and the expected number of outputs {len(impl.output_names)} "
f"for node {self.op_type!r} from domain {self.domain!r}."
)
return tuple(results)


class OpFunctionContextDependant(OpFunction):
"""
The function can be instantiated but only at execution time.
An instance of OpFunction is created everytime to node is executed.
This is needed when the schema of an operator defines a context dependant function.
"""

def __init__(self, onnx_node: NodeProto, log_function: Any, parent: Any = None):
OpFunction.__init__(self, onnx_node, log_function, impl=self, attributes={})
self.parent = parent
version = parent.opsets[onnx_node.domain]
self.schema_ = get_schema(onnx_node.op_type, version, onnx_node.domain)

def _run(self, *inputs, **kwargs):
# Input types are known. They are used to properly
# create the body for this operator.
types = []
for t in inputs:
try:
ttype = np_dtype_to_tensor_dtype(t.dtype)
except KeyError as e:
if t.dtype == float8e4m3fn:
ttype = TensorProto.FLOAT8E4M3FN # type: ignore[attr-defined]
elif t.dtype == float8e4m3fnuz:
ttype = TensorProto.FLOAT8E4M3FNUZ # type: ignore[attr-defined]
elif t.dtype == float8e5m2:
ttype = TensorProto.FLOAT8E5M2 # type: ignore[attr-defined]
elif t.dtype == float8e5m2fnuz:
ttype = TensorProto.FLOAT8E5M2FNUZ # type: ignore[attr-defined]
elif t.dtype == bfloat16:
ttype = TensorProto.BFLOAT16  # type: ignore[attr-defined]
else:
raise e
types.append(make_tensor_type_proto(ttype, t.shape))
cl = self.parent._load_impl( # pylint: disable=protected-access
self.onnx_node, types
)
inst = cl(self.onnx_node, self.run_params)
return self._run_impl(inst.impl_, *inputs, **kwargs)
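For standard numpy dtypes the mapping above is direct; the `except` branch
only covers the custom float8/bfloat16 element types that numpy does not
know natively. A minimal sketch of the common path, using the helpers
imported at the top of the module:

```python
import numpy as np
from onnx.helper import make_tensor_type_proto, np_dtype_to_tensor_dtype

x = np.zeros((2, 3), dtype=np.float32)
elem_type = np_dtype_to_tensor_dtype(x.dtype)        # TensorProto.FLOAT
type_proto = make_tensor_type_proto(elem_type, x.shape)
```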
7 changes: 5 additions & 2 deletions onnx/reference/ops/_op_list.py
@@ -240,6 +240,7 @@ def load_op(
custom: Any = None,
node: Union[None, NodeProto] = None,
input_types: Union[None, List[TypeProto]] = None,
expand: bool = False,
) -> Any:
"""
Loads the implementation for a specified operator.
@@ -252,6 +253,8 @@
which is context dependent
:param input_types: used if no implementation was found and the operator defines a function
which is context dependent
:param expand: use the function implemented in the schema instead
of its reference implementation
:return: class
"""
global _registered_operators
@@ -264,7 +267,7 @@
version = onnx_opset_version()
if domain != "":
raise ValueError(f"Domain must be '' not {domain!r}.")
if op_type in _registered_operators: # type: ignore
if op_type in _registered_operators and not expand: # type: ignore
found = True
else:
# maybe the operator can be replaced by a function
@@ -286,7 +289,7 @@
raise RuntimeContextError(
f"No registered implementation for operator {op_type!r} "
f"and domain {domain!r}, the operator has a context dependent function. "
f"but argument node or input_types is not defined."
f"but argument node or input_types is not defined (input_types={input_types})."
)
from onnx.reference import ReferenceEvaluator

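A sketch of the new flag, assuming the `load_op(domain, op_type, version, ...)`
signature visible in the call sites of this PR:

```python
from onnx.reference.ops import load_op
from onnx.reference.op_run import RuntimeContextError

# Default behaviour: the registered kernel class is returned.
kernel_cls = load_op("", "CastLike", 19)

# expand=True requests the schema function instead. CastLike (19) is
# context dependent, so without `node` and `input_types` this raises.
try:
    load_op("", "CastLike", 19, expand=True)
except RuntimeContextError:
    pass
```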
68 changes: 56 additions & 12 deletions onnx/reference/reference_evaluator.py
@@ -1,7 +1,7 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=C0415,R0902,R0912,R0913,R0914,R0915
# pylint: disable=C3001,C0415,R0902,R0912,R0913,R0914,R0915
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -10,7 +10,13 @@
from onnx import load
from onnx.defs import onnx_opset_version
from onnx.onnx_pb import FunctionProto, GraphProto, ModelProto, NodeProto, TypeProto
from onnx.reference.op_run import OpRun, RuntimeContextError, to_array_extended
from onnx.reference.op_run import (
OpFunctionContextDependant,
OpRun,
OpRunExpand,
RuntimeContextError,
to_array_extended,
)
from onnx.reference.ops_optimized import optimized_operators


@@ -164,6 +170,24 @@ def _run(self, ...):
`Pad_18` is selected for any greater opset. Both classes must be
imported into file `_op_list.py` to register their existence to the
runtime.
An operator may have a reference implementation such as `CastLike`
and still be defined as a function. By default, the reference implementation
is used. This behaviour can be changed by adding to `new_ops` a class
inheriting from :class:`OpRunExpand`.
::
from onnx.reference.op_run import OpRunExpand
class CastLike(OpRunExpand):
op_domain = ""
ref = ReferenceEvaluator(model, new_ops=[CastLike])
# ...
This mechanism is used in unit tests to check the function
implementation a schema may define.
"""

def __init__( # type: ignore
@@ -257,12 +281,12 @@ def __init__( # type: ignore
self.new_ops_: Dict[Tuple[str, str], OpRun] = {}
if new_ops is not None:
for cl in new_ops:
if not issubclass(cl, OpRun): # type: ignore
raise TypeError(f"Class {cl} must inherit from OpRun (in new_ops).")
if not hasattr(cl, "op_domain"):
raise AttributeError(
f"Class {cl} must define attribute 'op_domain'."
)
if not issubclass(cl, OpRun): # type: ignore
raise TypeError(f"Class {cl} must inherit from OpRun (in new_ops).")
key = cl.op_domain, cl.__name__ # type: ignore
if key in self.new_ops_:
# Already an implementation, the first one is used.
@@ -326,15 +350,17 @@ def has_linked_attribute(self):
def __str__(self) -> str:
return f"{self.__class__.__name__}({', '.join(self.input_names)}) -> {', '.join(self.output_names)}"

def get_result_types(self, name: str) -> Any:
def get_result_types(self, name: str, exc: bool = True) -> Any:
if self.all_types_ is None:
raise RuntimeError(
f"Unable to return type for name {name!r}. Run shape_inference first."
)
if name not in self.all_types_:
raise RuntimeError(
f"Unable to return type for name {name!r}, it was not found in {sorted(self.all_types_)}."
)
if exc:
raise RuntimeError(
f"Unable to return type for name {name!r}, it was not found in {sorted(self.all_types_)}."
)
return None
return self.all_types_[name]

def _init(self) -> None:
@@ -367,8 +393,14 @@ def _init(self) -> None:
# A node has a context dependent implementation.
# Shape inference must be run to get the input types.
if self.all_types_:
it = [self.get_result_types(i) for i in node.input]
cl = self._load_impl(node, it) # type: ignore
it = [self.get_result_types(i, exc=False) for i in node.input]
if None in it:
# One input type is unknown. Instantiation must be delayed until the graph is executed.
cl = lambda *args, parent=self: OpFunctionContextDependant( # noqa: E731
*args, parent=parent
)
else:
cl = self._load_impl(node, it) # type: ignore
else:
raise RuntimeContextError(
f"No implementation was found for node type {node.op_type!r} from domain {node.domain!r}. "
@@ -397,17 +429,22 @@ def _load_impl(
)
version = self.opsets[node.domain]
key = node.domain, node.op_type
expand = False
if key in self.new_ops_:
# This operator has a custom implementation.
# This mechanism can be used to implement a custom onnx node
# or to overwrite an existing one.
return self.new_ops_[key]
cl = self.new_ops_[key]
if not issubclass(cl, OpRunExpand):
return cl
# It must be replaced by its implementation defined in its schema.
expand = True

if node.domain == "":
from onnx.reference.ops import load_op

try:
return load_op(node.domain, node.op_type, version)
return load_op(node.domain, node.op_type, version, expand=expand)
except RuntimeContextError:
if input_types is None:
raise
@@ -417,8 +454,15 @@
version,
node=node,
input_types=input_types, # type: ignore[arg-type]
expand=expand,
)

if expand:
raise NotImplementedError(
f"Expanding an operator with its function definition "
f"is only implemented for the main opset. Remove operator "
f"{node.domain},{node.op_type} from the list of inlined operator."
)
if node.domain == "ai.onnx.preview.training":
from onnx.reference.ops.aionnx_preview_training import load_op as load_op_pt

39 changes: 38 additions & 1 deletion onnx/test/reference_evaluator_test.py
@@ -47,7 +47,7 @@
)
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32, from_array
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from onnx.reference.op_run import OpRun, OpRunExpand
from onnx.reference.ops import load_op
from onnx.reference.ops._op_common_indices import _get_indices, _is_out
from onnx.reference.ops._op_list import Celu
@@ -3165,6 +3165,43 @@ def test_cast_float8(self):
assert_allclose(got[2], expected1)
assert_allclose(got[3], expected2)

def test_cast_like_float8(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None])
model = make_model(
make_graph(
[
make_node("Cast", ["X"], ["f8"], to=TensorProto.FLOAT8E4M3FNUZ),
make_node("CastLike", ["X", "f8"], ["f32"], saturate=0),
make_node("Cast", ["f32"], ["Y"], to=TensorProto.FLOAT),
],
"g",
[X],
[Y],
)
)
data = np.array([0, 1e7], dtype=np.float32)
expected = np.array(
[
float8e4m3_to_float32(
float32_to_float8e4m3(x, uz=True, saturate=False), uz=True
)
for x in data
]
)
ref = ReferenceEvaluator(model)
got = ref.run(None, {"X": data})
assert_allclose(got[0], expected)

# Force ReferenceEvaluator to bypass the registered kernel for CastLike
# and use its function definition instead.
class CastLike(OpRunExpand):
op_domain = ""

ref = ReferenceEvaluator(model, new_ops=[CastLike])
got = ref.run(None, {"X": data})
assert_allclose(got[0], expected)

def test_cast_float8_output(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None])
F1 = make_tensor_value_info("F1", TensorProto.FLOAT8E4M3FN, [None])
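The expected values in `test_cast_like_float8` rely on the saturate
semantics of float8 casts; a quick standalone check using the same helpers
the test imports (240 is the largest finite E4M3FNUZ value):

```python
import numpy as np
from onnx.helper import float32_to_float8e4m3
from onnx.numpy_helper import float8e4m3_to_float32

x = np.float32(1e7)  # far outside the E4M3FNUZ range
sat = float8e4m3_to_float32(float32_to_float8e4m3(x, uz=True, saturate=True), uz=True)
no_sat = float8e4m3_to_float32(float32_to_float8e4m3(x, uz=True, saturate=False), uz=True)
print(sat)     # 240.0: saturating casts clamp to the largest finite value
print(no_sat)  # nan: non-saturating casts map out-of-range values to NaN
```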
