[ONNX] Add quantization support to more single output ops (pytorch#83008)

pytorch#80039

- Implement quantization support for the following single-output ops:
  - quantized::sigmoid
  - quantized::instance_norm
  - aten::reshape
  - aten::reshape_as
  - aten::sum
  - aten::mean
  - aten::prod
  - aten::t
  - aten::numpy_T
  - aten::expand
  - aten::expand_as
  - aten::embedding
  - aten::embedding_bag
  - aten::view
  - aten::select
  - aten::eq
  - aten::ne
  - aten::gt
  - aten::lt
  - aten::le
  - aten::ge
  - quantized::layer_norm
  - aten::elu
  - aten::selu
  - aten::maximum
  - aten::minimum
  - aten::amax
  - aten::amin
  - aten::hardtanh
  - aten::hardswish
  - quantized::group_norm
  - aten::as_strided
  - quantized::leaky_relu
  - aten::transpose
- Avoid modifying functions in `quantized_args`; have the wrapper close over `scale` and `zero_point` instead (for purity; see the sketch after this list)
- Replace a magic number with the named constant `INT64_MAX`
- Implement `_unpack_quantized_tensor` to separate quantized-tensor unpacking from generic tuple unpacking and to provide clearer error handling
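
A minimal standalone sketch of the closure-based decorator pattern referenced above (simplified plain Python; the real `quantized_args` operates on JIT graph values and emits `Constant` nodes, which is elided here, and `q_foo` echoes the example name in its docstring):

```python
import functools

def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
    # The wrapper closes over `scale`/`zero_point` instead of stashing
    # them on the decorated function (the old `fn._scale = scale`
    # approach), so decoration no longer mutates `fn` itself.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Fixed output parameters come from the closure; the real
            # helper falls back to the first quantized input's values.
            _scale = scale if scale is not None else 1.0
            _zero_point = zero_point if zero_point is not None else 0
            return (fn(*args, **kwargs), _scale, _zero_point)
        return wrapper
    return decorator

@quantized_args(True, scale=2.0, zero_point=1)
def q_foo(x):
    return x + 1

print(q_foo(3))                  # (4, 2.0, 1)
print(hasattr(q_foo, "_scale"))  # False: the function object is untouched
```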
Pull Request resolved: pytorch#83008
Approved by: https://github.com/BowenBao
justinchuby authored and pytorchmergebot committed Aug 23, 2022
1 parent 1e4383f commit 80cfafc
Showing 5 changed files with 315 additions and 78 deletions.
119 changes: 103 additions & 16 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -11778,26 +11778,113 @@ def test_quantized_conv2d_relu(self):
         q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
         self.run_test(model, q_input)
 
+    @common_utils.parametrize(
+        "function_or_module",
+        [
+            common_utils.subtest(
+                torch.nn.ReLU(),
+                name="relu",
+            ),
+            common_utils.subtest(
+                torch.nn.LeakyReLU(),
+                name="leaky_relu",
+            ),
+            common_utils.subtest(
+                torch.nn.quantized.LeakyReLU(2.0, 1),
+                name="quantized_leaky_relu",
+            ),
+            common_utils.subtest(
+                torch.nn.quantized.Hardswish(2.0, 1),
+                name="quantized_hardswish",
+            ),
+            common_utils.subtest(
+                torch.nn.Sigmoid(),
+                name="sigmoid",
+            ),
+            common_utils.subtest(
+                torch.nn.quantized.Sigmoid(2.0, 1),
+                name="quantized_sigmoid",
+            ),
+            common_utils.subtest(
+                torch.nn.Hardsigmoid(),
+                name="hardsigmoid",
+            ),
+            common_utils.subtest(
+                torch.nn.Tanh(),
+                name="tanh",
+            ),
+            common_utils.subtest(
+                torch.nn.Hardtanh(),
+                name="hardtanh",
+            ),
+            common_utils.subtest(
+                lambda x: torch.transpose(x, 0, 1),
+                name="transpose",
+            ),
+            common_utils.subtest(
+                lambda x: x.expand(2, 4, 2, 3),
+                name="expand",
+            ),
+            common_utils.subtest(
+                lambda x: x.view(1, 4, 6),
+                name="view",
+            ),
+            common_utils.subtest(
+                lambda x: x.select(1, 1),
+                name="select",
+            ),
+            common_utils.subtest(
+                torch.nn.quantized.LayerNorm(
+                    [4, 2, 3],
+                    torch.nn.Parameter(torch.ones([4, 2, 3])),
+                    torch.nn.Parameter(torch.zeros([4, 2, 3])),
+                    2.0,
+                    1,
+                ),
+                name="layer_norm",
+            ),
+            common_utils.subtest(
+                torch.nn.quantized.InstanceNorm1d(
+                    2,
+                    torch.nn.Parameter(torch.ones(4)),
+                    torch.nn.Parameter(torch.zeros(4)),
+                    2.0,
+                    1,
+                ),
+                name="instance_norm",
+            ),
+            common_utils.subtest(
+                torch.nn.quantized.GroupNorm(
+                    2,
+                    4,
+                    torch.nn.Parameter(torch.zeros(4)),
+                    torch.nn.Parameter(torch.zeros(4)),
+                    2.0,
+                    1,
+                ),
+                name="group_norm",
+            ),
+            common_utils.subtest(
+                lambda x: torch.as_strided(x, (2, 2), (1, 2)),
+                name="as_strided",
+            ),
+        ],
+    )
+    @skipScriptTest()
     @skipIfUnsupportedMinOpsetVersion(10)
-    def test_quantized_hardswish(self):
-        model = torch.nn.quantized.Hardswish(1.0, 0)
-        input = torch.randn(2, 6)
+    def test_quantized_unary_ops(self, function_or_module):
+        input = torch.randn(1, 4, 2, 3)
         q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
-        self.run_test(model, q_input)
 
-    @skipIfUnsupportedMinOpsetVersion(10)
-    def test_quantized_hardsigmoid(self):
-        model = torch.nn.Hardsigmoid()
-        input = torch.randn(2, 6)
-        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
-        self.run_test(model, q_input)
+        class Model(torch.nn.Module):
+            def __init__(self, function_or_module):
+                super().__init__()
+                self.function_or_module = function_or_module
 
-    @skipIfUnsupportedMinOpsetVersion(10)
-    def test_quantized_sigmoid(self):
-        model = torch.nn.Sigmoid()
-        input = torch.randn(2, 6)
-        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
-        self.run_test(model, q_input)
+            def forward(self, x):
+                return self.function_or_module(x)
+
+        self.run_test(Model(function_or_module), q_input)
 
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_quantized_flatten(self):
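For orientation, a rough standalone sketch of what one parametrized case above exercises; `run_test` additionally compares ONNX Runtime's output against eager mode, which is omitted here, and the output path `model.onnx` is arbitrary:

```python
import torch

# One of the new subtests: quantized sigmoid with output scale 2.0
# and zero point 1, applied to a per-tensor-quantized input.
model = torch.nn.quantized.Sigmoid(2.0, 1)
q_input = torch.quantize_per_tensor(
    torch.randn(1, 4, 2, 3), 0.26, 128, torch.quint8
)

# Quantized ops need opset >= 10 (QuantizeLinear/DequantizeLinear).
torch.onnx.export(model, q_input, "model.onnx", opset_version=10)
```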
95 changes: 69 additions & 26 deletions torch/onnx/symbolic_helper.py
@@ -14,7 +14,7 @@
 from torch import _C
 
 # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
-from torch.onnx import _patch_torch, _type_utils, errors  # noqa: F401
+from torch.onnx import _constants, _patch_torch, _type_utils, errors  # noqa: F401
 from torch.onnx._globals import GLOBALS
 
 # Note [Edit Symbolic Files]
@@ -212,7 +212,7 @@ def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
 
 def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
     tuple_node = tuple_value.node()
-    if tuple_node.kind() != "prim::TupleConstruct":
+    if not _is_tuple_construct(tuple_value):
         raise errors.SymbolicValueError(
             f"ONNX symbolic expected node type 'prim::TupleConstruct', "
             f"got '{tuple_node.kind()}'.",
@@ -221,6 +221,27 @@ def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
     return tuple(tuple_node.inputs())
 
 
+def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
+    """Unpacks a quantized tensor into a tuple of tensor, scale, and zero_point.
+    Args:
+        tuple_value: A tuple of tensor, scale, zero_point, and optionally axis.
+    Returns:
+        A tuple of tensor, scale, zero_point, and optionally axis.
+    """
+    tuple_node = tuple_value.node()
+    # A quantized tensor is represented as a tuple of the form (tensor, scale, zero_point, <axis>)
+    if not _is_tuple_construct(tuple_value):
+        raise errors.SymbolicValueError(
+            f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized "
+            f"tensor. This is likely due to missing support for quantized "
+            f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}",
+            tuple_value,
+        )
+    unpacked = tuple(tuple_node.inputs())
+    assert len(unpacked) == 3 or len(unpacked) == 4
+    return unpacked
+
+
 # Check if list_value is output from prim::ListConstruct
 # This is usually called before _unpack_list to ensure the list can be unpacked.
 def _is_packed_list(list_value: _C.Value) -> bool:
@@ -349,44 +370,57 @@ def q_foo(g, x, y):
     """
 
     def decorator(fn):
-        fn._scale = scale
-        fn._zero_point = zero_point
-
         @functools.wraps(fn)
         def wrapper(g, *args, **kwargs):
-            _scale = fn._scale
-            if _scale is not None:
-                _scale = g.op("Constant", value_t=torch.tensor(_scale))
-            _zero_point = fn._zero_point
-            if _zero_point is not None:
-                _zero_point = g.op("Constant", value_t=torch.tensor(_zero_point))
+            nonlocal scale
+            nonlocal zero_point
+            if scale is not None:
+                _scale = g.op("Constant", value_t=torch.tensor(scale))
+            else:
+                _scale = None
+            if zero_point is not None:
+                _zero_point = g.op("Constant", value_t=torch.tensor(zero_point))
+            else:
+                _zero_point = None
 
             # Support variable length arguments by marking unspecified ones as non-quantized
             arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
                 len(args) - len(arg_q_descriptors)
             )
             descriptor_args = tuple(zip(arg_q_descriptors_extended, args))
 
             # Run the regular symbolic function if none of the arguments is a QTensor.
             if not any(
-                (descriptor and arg.node().kind() == "prim::TupleConstruct")
+                (descriptor and _is_value(arg) and _is_tuple_construct(arg))
                 for descriptor, arg in descriptor_args
             ):
                 return fn(g, *args, **kwargs)
 
-            dequantized_args = []
+            # Dequantize arguments that are quantized
+            non_quantized_args = []
             for descriptor, arg in descriptor_args:
-                if descriptor:
-                    dequantized_arg, scale, zero_point, _ = dequantize_helper(g, arg)
-                    dequantized_args.append(dequantized_arg)
+                if descriptor and _is_value(arg) and _is_tuple_construct(arg):
+                    # Quantized arg is a tuple of (value, scale, zero_point)
+                    dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper(
+                        g, arg
+                    )
+                    non_quantized_args.append(dequantized_arg)
                     # Set scale and zero_point to the first quantized input if not already set
                     if _scale is None:
-                        _scale = scale
+                        _scale = arg_scale
                     if _zero_point is None:
-                        _zero_point = zero_point
+                        _zero_point = arg_zero_point
                 else:
-                    dequantized_args.append(arg)
+                    # Non-quantized arg
+                    non_quantized_args.append(arg)
             # TODO(justinchuby): Only single output is supported for now. We may want to
             # support multiple outputs in the future.
-            output = fn(g, *dequantized_args, **kwargs)
+            output = fn(g, *non_quantized_args, **kwargs)
+
+            assert _scale is not None, "Bug: Scale must be set for quantized operator"
+            assert (
+                _zero_point is not None
+            ), "Bug: Zero point must be set for quantized operator"
 
             return quantize_helper(g, output, _scale, _zero_point)
 
@@ -472,6 +506,10 @@ def _is_scalar_list(x: _C.Value) -> bool:
     )
 
 
+def _is_tuple_construct(x: _C.Value) -> bool:
+    return x.node().kind() == "prim::TupleConstruct"
+
+
 def is_caffe2_aten_fallback() -> bool:
     return (
         GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
@@ -1376,13 +1414,15 @@ def dequantize_helper(
     Args:
         g: Graph, the ONNX IR graph that is under construction.
-        qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point) for per tensor quantization,
-            or (quantized_tensor, scale, zero_point, axis) for per channel quantization.
-            Representing the quantized tensor.
-        qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the data type of quantized tensor.
-            It must be either torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8.
+        qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point)
+            for per tensor quantization, or
+            (quantized_tensor, scale, zero_point, axis) for per channel quantization,
+            representing the quantized tensor.
+        qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the
+            data type of quantized tensor. It must be either
+            torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8.
     """
-    unpacked_qtensors = _unpack_tuple(qtensor)
+    unpacked_qtensors = _unpack_quantized_tensor(qtensor)
     tensor, scale, zero_point = unpacked_qtensors[:3]
     axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None
     axis_i = _get_const(axis, "i", "axis")
Expand Down Expand Up @@ -1430,6 +1470,9 @@ def quantize_helper(
zero_point: torch._C.Value, quantized zero point.
axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization.
Otherwise, represents per channel quantization, along given axis.
Returns:
A TupleConstruct storing information of the quantized tensor.
"""
if (
axis is not None
Expand Down
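As background for `_unpack_quantized_tensor` and `dequantize_helper` above: in the exporter's graph, a quantized tensor travels as a `prim::TupleConstruct` of `(tensor, scale, zero_point)`, optionally with a per-channel `axis` as a fourth element. A rough eager-mode analogue of that layout (plain tensors here, not the JIT graph values the helpers actually handle):

```python
import torch

q = torch.quantize_per_tensor(torch.randn(2, 3), 0.26, 128, torch.quint8)

# The same three pieces of information the symbolic helpers shuttle
# around as a tuple; a fourth element (axis) appears for per-channel
# quantization.
tensor, scale, zero_point = q.int_repr(), q.q_scale(), q.q_zero_point()
print(tensor.dtype, scale, zero_point)  # torch.uint8 0.26 128
```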
51 changes: 51 additions & 0 deletions torch/onnx/symbolic_opset10.py
@@ -610,6 +610,57 @@ def hardswish(g, x, op_scale, op_zero_point):
 
         return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
 
+    @staticmethod
+    def sigmoid(g, x, op_scale, op_zero_point):
+        x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
+
+        output = opset9.sigmoid(g, x)
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def leaky_relu(g, x, negative_slope, inplace, op_scale, op_zero_point):
+        x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
+
+        output = opset9.leaky_relu(g, x, negative_slope, inplace)
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def layer_norm(g, x, normalized_shape, weight, bias, eps, op_scale, op_zero_point):
+        x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
+
+        output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False)
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def group_norm(g, x, num_groups, weight, bias, eps, op_scale, op_zero_point):
+        x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
+
+        output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False)
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
+    def instance_norm(
+        g,
+        q_input,
+        weight,
+        bias,
+        eps,
+        op_scale,
+        op_zero_point,
+    ):
+        input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
+
+        output = opset9.instance_norm(
+            g, input, weight, bias, None, None, False, 0, eps, False
+        )
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
     @staticmethod
     def conv2d_relu(
         g,
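Each symbolic above lowers to the same ONNX sandwich: `DequantizeLinear` → float op → `QuantizeLinear`. A hand-built sketch of the resulting graph shape for `quantized::sigmoid`, constructed with the `onnx` helper API purely for illustration (the exporter emits this structure itself via `dequantize_helper`/`quantize_helper`; all names here are made up):

```python
import onnx
from onnx import TensorProto, helper

dq = helper.make_node("DequantizeLinear", ["q_x", "x_scale", "x_zp"], ["x"])
op = helper.make_node("Sigmoid", ["x"], ["y"])
q = helper.make_node("QuantizeLinear", ["y", "op_scale", "op_zp"], ["q_y"])

graph = helper.make_graph(
    [dq, op, q],
    "quantized_sigmoid",
    inputs=[helper.make_tensor_value_info("q_x", TensorProto.UINT8, [1, 4, 2, 3])],
    outputs=[helper.make_tensor_value_info("q_y", TensorProto.UINT8, [1, 4, 2, 3])],
    initializer=[
        helper.make_tensor("x_scale", TensorProto.FLOAT, [], [0.26]),
        helper.make_tensor("x_zp", TensorProto.UINT8, [], [128]),
        helper.make_tensor("op_scale", TensorProto.FLOAT, [], [2.0]),
        helper.make_tensor("op_zp", TensorProto.UINT8, [], [1]),
    ],
)
# Opset 10 is the first opset with QuantizeLinear/DequantizeLinear.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
onnx.checker.check_model(model)
```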
4 changes: 3 additions & 1 deletion torch/onnx/symbolic_opset11.py
@@ -91,8 +91,9 @@
 ]
 
 
+@symbolic_helper.quantized_args(True)
 @symbolic_helper.parse_args("v", "f", "f")
-def hardtanh(g, self, min_val, max_val):
+def hardtanh(g, self: _C.Value, min_val: float, max_val: float):
     dtype = self.type().scalarType()
     if dtype is None:
         scalar_type = _type_utils.JitScalarType.FLOAT
@@ -189,6 +190,7 @@ def relu6(g, input):


 # Opset 11 gather accepts negative indices
+@symbolic_helper.quantized_args(True)
 @symbolic_helper.parse_args("v", "i", "v")
 def select(g, self, dim, index):
     return g.op("Gather", self, index, axis_i=dim)
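A note on the opset 11 change: `quantized_args(True)` with no fixed `scale`/`zero_point` makes pass-through ops such as `select` inherit the input's quantization parameters, matching eager-mode behavior. A small sanity check of that invariant (not from the test suite):

```python
import torch

q = torch.quantize_per_tensor(torch.randn(2, 3), 0.26, 128, torch.quint8)
s = q.select(1, 1)  # pure data movement, no requantization

# The exported graph reproduces this: the wrapper requantizes the
# Gather output with the scale/zero_point taken from the input.
assert s.q_scale() == q.q_scale()
assert s.q_zero_point() == q.q_zero_point()
```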
124 changes: 89 additions & 35 deletions (remaining changed file; diff not loaded)