Skip to content

Commit

Permalink
[ONNX] Add primitives formatting for diagnostics (pytorch#105889)
Browse files Browse the repository at this point in the history
E.g., `<type: int>` -> `2`.
Pull Request resolved: pytorch#105889
Approved by: https://github.com/thiagocrepaldi, https://github.com/titaiwangms
  • Loading branch information
BowenBao authored and pytorchmergebot committed Jul 25, 2023
1 parent 00c6a2e commit 8282c53
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils

from torch.onnx._internal.fx import type_utils as fx_type_utils
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils

# NOTE: Symbolic shapes could be a calculation of values, such as
# Tensor(i64[s0, 64, (s1//2) - 2, (s1//2) - 2]) where s0 and s1 are symbolic
Expand Down Expand Up @@ -104,6 +104,29 @@ def _torch_tensor(obj: torch.Tensor) -> str:
return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"


@_format_argument.register
def _int(obj: int) -> str:
return str(obj)


@_format_argument.register
def _float(obj: float) -> str:
return str(obj)


@_format_argument.register
def _bool(obj: bool) -> str:
return str(obj)


@_format_argument.register
def _registration_symbolic_function(obj: registration.SymbolicFunction) -> str:
# TODO: Compact display of `param_schema`.
return (
f"registration.SymbolicFunction({obj.op_full_name}, is_custom={obj.is_custom})"
)


@_format_argument.register
def _list(obj: list) -> str:
list_string = f"List[length={len(obj)}](\n"
Expand Down

0 comments on commit 8282c53

Please sign in to comment.