Skip to content

Commit

Permalink
[ONNX] Fix the warnings of aten overload fallback to default in onn…
Browse files Browse the repository at this point in the history
…x dispatcher (pytorch#105972)

Without this PR, the warning message is misleading as it says the default is found before the error message popped.
Next PR will start refactoring aten overload fallback with adding overloads supported by torchlib into OpSchema matching.
Pull Request resolved: pytorch#105972
Approved by: https://github.com/BowenBao
  • Loading branch information
titaiwangms authored and pytorchmergebot committed Jul 26, 2023
1 parent 8d9c889 commit 1544291
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions torch/onnx/_internal/fx/onnxfunction_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,25 +304,26 @@ def get_function_overloads(
)

# NOTE: Fall back to default overload if the ONNX registry doesn't have the overload.
# TODO: Should we have a better fallback mechanism?
if function_group is None:
function_group = self.onnx_registry.get_functions(
namespace=internal_opname.namespace,
op_name=internal_opname.op_name,
overload=None,
)

# NOTE: Currently, most of torchlib functions are not registered with overload
# in ONNX registry. So we will only log a warning in SARIF if we can't find the overload
# to avoid spammy warnings in printout.
# TODO: https://github.com/microsoft/onnxscript/issues/828
op_full_name = internal_opname.qualified_name()
diagnostic = diagnostic_context.inflight_diagnostic()
diagnostic.with_additional_message(
"### The operator overload is not found in onnx registry!\n"
"Cannot find the operator overload in onnx registry, but"
"the default overload is found. Please check the ONNX output carefully. \n",
)
diagnostic.level = diagnostics.levels.WARNING
if function_group is not None:
# NOTE: Currently, most of torchlib functions are not registered with overload
# in ONNX registry. So we will only log a warning in SARIF if we can't find the overload
# to avoid spammy warnings in printout.
# TODO: https://github.com/microsoft/onnxscript/issues/828
op_full_name = internal_opname.qualified_name()
diagnostic = diagnostic_context.inflight_diagnostic()
diagnostic.with_additional_message(
"### The operator overload is not found in onnx registry!\n"
"Cannot find the operator overload in onnx registry, but "
"the default overload is found. Please check the ONNX output carefully. \n",
)
diagnostic.level = diagnostics.levels.WARNING

# NOTE: If the ATen/Custom operators are not registered, the group will be None.
if function_group is not None:
Expand Down

0 comments on commit 1544291

Please sign in to comment.