[quant][pt2e][xnnpack] XNNPACKQuantizer: skip quantization for input and output to work around histogram observer problem (pytorch#113405)

Summary:
As titled. HistogramObserver does not handle a corner case in MobileBERT (observing a scalar tensor holding the float32 max value), because the histc operator errors out when a value exceeds a certain upper bound.
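
For context, a minimal repro sketch of the observer failure (the exact error message depends on the PyTorch build; the failure mode is as described above):

import torch
from torch.ao.quantization.observer import HistogramObserver

# Scalar tensor holding the float32 max value -- the MobileBERT corner case.
scalar = torch.tensor(3.4028235e38)

obs = HistogramObserver()
try:
    # forward() builds the histogram via torch.histc, which errors out for
    # values above a certain upper bound (see HISTC_UPPER_BOUND in the diff).
    obs(scalar)
except RuntimeError as e:
    print(f"HistogramObserver failed: {e}")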

Test Plan:
python test/test_quantization.py -k test_mul_float32_max

Pull Request resolved: pytorch#113405
Approved by: https://github.com/mcr229
jerryzh168 authored and pytorchmergebot committed Dec 2, 2023
Commit 8f16401 (parent 7bbc19a)
Showing 2 changed files with 53 additions and 0 deletions.
25 changes: 25 additions & 0 deletions test/quantization/pt2e/test_xnnpack_quantizer.py
@@ -698,6 +698,31 @@ def test_add_mul_scalar(self):
            node_list,
        )

    def test_mul_float32_max(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x * 3.4028235e38

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(quantization_config)
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # not quantized
        node_occurrence = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
        }
        node_list = [
            torch.ops.aten.mul.Tensor,
        ]
        self._test_quantizer(
            M(),
            example_inputs,
            quantizer,
            node_occurrence,
            node_list,
        )


# TODO: express this using self._test_quantizer, add test for inception_v4
class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
28 changes: 28 additions & 0 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -616,6 +616,18 @@ def _annotate_adaptive_avg_pool2d(
    return annotated_partitions


def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Check if the input is a large scalar value, so that we can skip
    quantization for the node, since the histc op (used by HistogramObserver)
    only works for values up to a certain upper bound.
    """
    if node.op == "get_attr":
        tensor = getattr(gm, node.target)  # type: ignore[arg-type]
        # torch.histc works up to this upper bound
        HISTC_UPPER_BOUND = 3.4028235e15
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
    return False


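A hypothetical usage sketch of the new helper (the module `M`, buffer name `big`, and use of `symbolic_trace` are illustrative, not part of the diff):

import torch
from torch.fx import symbolic_trace
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _is_input_large_scalar,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Scalar constant above HISTC_UPPER_BOUND; traced as a get_attr node.
        self.register_buffer("big", torch.tensor(3.4028235e38))

    def forward(self, x):
        return x * self.big

gm = symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        print(node.target, _is_input_large_scalar(node, gm))  # -> big True

The annotators below call this helper on each input node and `continue` past the whole partition when it returns True, leaving the op unquantized.
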
@register_annotator("add_relu")
def _annotate_add_relu(
    gm: torch.fx.GraphModule,
@@ -645,10 +657,14 @@ def _annotate_add_relu(
        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
@@ -685,10 +701,14 @@ def _annotate_add(
        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
@@ -728,10 +748,14 @@ def _annotate_mul_relu(
        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
@@ -768,10 +792,14 @@ def _annotate_mul(
        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(