Commit

[PT2][Quant] Enable symbolic shape in linear quantization (pytorch#104473)

When tracing with symbolic shapes, arbitrary sym_size nodes can appear in the
graph. Earlier changes did not account for this, and the quantizer fails to
annotate the right nodes. This diff fixes that by not annotating sym_size nodes,
which should not be relevant for quantization.

As next steps, we should (a) validate in the quant workflow that sym_int nodes
are not being quantized, and (b) add similar support, as in this diff, for
generic annotations.
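
As a hedged illustration (not part of this commit), here is a minimal sketch of the filtering idea: drop nodes whose target is aten.sym_size before picking annotation candidates. The helper name below is hypothetical; it simply mirrors the `_is_sym_size_node` check added to torch/ao/quantization/_pt2e/quantizer/utils.py in this diff.

```python
# Hedged sketch: sym_size nodes only carry shape information, so they are
# filtered out before any quantization annotation is applied.
from typing import List

import torch
from torch.fx import Node


def _filter_sym_size(nodes: List[Node]) -> List[Node]:
    # Hypothetical helper name; keeps only nodes that are not aten.sym_size calls.
    return [
        n
        for n in nodes
        if not (n.op == "call_function" and n.target == torch.ops.aten.sym_size)
    ]
```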

Differential Revision: [D47132050](https://our.internmc.facebook.com/intern/diff/D47132050/)
Pull Request resolved: pytorch#104473
Approved by: https://github.com/jerryzh168
kimishpatel authored and pytorchmergebot committed Jul 1, 2023
1 parent 4e27e6c commit bd0f0f4
Showing 3 changed files with 116 additions and 34 deletions.
55 changes: 43 additions & 12 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -221,17 +221,19 @@ def _test_quantizer(
expected_node_list=None,
check_against_fx_quant=False,
fx_qconfig_mapping=None,
export_with_dynamic_shape=False,
):
m_eager = model.eval()

# program capture
m = copy.deepcopy(m_eager)
m, guards = torchdynamo.export(
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)
with torchdynamo.config.patch(dynamic_shapes=export_with_dynamic_shape):
m, guards = torchdynamo.export(
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="symbolic" if export_with_dynamic_shape else "real",
)

m = prepare_pt2e_quantizer(m, quantizer)
# Calibrate
@@ -258,12 +260,13 @@ def _test_quantizer(
m_fx = _convert_to_reference_decomposed_fx(
m_fx, backend_config=backend_config
)
m_fx, guards = torchdynamo.export(
m_fx,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)
with torchdynamo.config.patch(dynamic_shapes=export_with_dynamic_shape):
m_fx, guards = torchdynamo.export(
m_fx,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="symbolic" if export_with_dynamic_shape else "real",
)
node_occurrence = {}
for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
if k in expected_node_occurrence:
@@ -1094,6 +1097,34 @@ def test_qnnpack_quantizer_conv_linear(self):
qconfig_mapping,
)

def test_qnnpack_quantizer_linear_with_dynamic_shape(self):
quantizer = QNNPackQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = TestHelperModules.TwoLinearModule().eval()

# Test with 3d inputs
example_inputs_3d = (torch.randn(9, 10, 8),)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
torch.ops.quantized_decomposed.quantize_per_channel.default: 2,
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
}
qconfig = default_per_channel_symmetric_qnnpack_qconfig
qconfig_mapping = QConfigMapping().set_global(qconfig)
self._test_quantizer(
m_eager,
example_inputs_3d,
quantizer,
node_occurrence,
[],
True,
qconfig_mapping,
export_with_dynamic_shape=True,
)

def test_qnnpack_quantizer_obs_sharing_ops(self):
quantizer = QNNPackQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
64 changes: 42 additions & 22 deletions torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
@@ -16,6 +16,8 @@
from torch.ao.quantization._pt2e.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
_is_sym_size_node,
_node_only_used_for_sym_size,
get_bias_qspec,
get_input_act_qspec,
get_output_act_qspec,
@@ -40,10 +42,10 @@
OperatorPatternType,
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpecBase,
QuantizationSpec,
SharedQuantizationSpec,
QuantizationSpecBase,
Quantizer,
SharedQuantizationSpec,
)


@@ -517,9 +519,18 @@ def _annotate_linear(
bias_qspec = get_bias_qspec(quantization_config)
for module_or_fn_type, partitions in module_partitions.items():
for p in partitions:
if len(p.input_nodes) > 1:
raise ValueError(f"More than one input node found for {module_or_fn_type} partition")
act_node = p.input_nodes[0]
act_nodes = [
n
for n in p.input_nodes
if not _node_only_used_for_sym_size(n, p.nodes)
]
if len(act_nodes) > 1:
raise ValueError(
f"Multiple activation nodes found for partition {p} {act_nodes}"
)
if len(act_nodes) == 0:
raise ValueError(f"No activation node found for partition {p}")
act_node = act_nodes[0]
output_node = p.output_nodes[0]
weight_node = None
bias_node = None
@@ -533,14 +544,21 @@
raise ValueError("No weight found in Linear pattern")
# find use of act node within the matched pattern
act_use_node = None
for node in p.nodes:
if node in act_node.users: # type: ignore[union-attr]
act_use_node = node
break
if act_use_node is None:
# When doing tracing with dynamic shape, we end up with sym_size nodes
# These nodes do not need quantization, so skip them.
# We could also have the quant workflow throw an exception when sym_size
# nodes are annotated.
# This is not specific to linear, so in future diffs we should streamline
# this.
act_node_users = list(
filter((lambda x: (_is_sym_size_node(x) is False)), act_node.users)
)
act_use_node_in_p = set(act_node_users).intersection(set(p.nodes))
if len(act_use_node_in_p) != 1:
raise ValueError(
"Could not find an user of act node within matched pattern."
f"Could not find a valid use of act node. All uses {act_use_node_in_p}"
)
act_use_node = act_use_node_in_p.pop()
if _is_annotated([act_use_node]) is False: # type: ignore[list-item]
_annotate_input_qspec_map(
act_use_node,
@@ -560,9 +578,7 @@ def _annotate_linear(
def _annotate_gru(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
gru_partitions = get_source_partitions(
gm.graph, [torch.nn.GRU]
)
gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU])
gru_partitions = list(itertools.chain(*gru_partitions.values()))
for gru_partition in gru_partitions:
output_nodes = gru_partition.output_nodes
@@ -581,7 +597,7 @@ def _annotate_gru(
input_qspec_map={
input_act: get_input_act_qspec(quantization_config),
},
_annotated=True
_annotated=True,
)

hidden_state = input_nodes[1]
@@ -592,7 +608,7 @@ def _annotate_gru(
input_qspec_map={
hidden_state: get_input_act_qspec(quantization_config),
},
_annotated=True
_annotated=True,
)

assert len(output_nodes) == 2, "expecting GRU to have two outputs"
@@ -624,9 +640,11 @@ def _annotate_maxpool2d(
assert isinstance(input_act, Node)

# only annotate maxpool when the output of the input node is annotated
if "quantization_annotation" not in input_act.meta or \
not input_act.meta["quantization_annotation"]._annotated or \
input_act.meta["quantization_annotation"].output_qspec is None:
if (
"quantization_annotation" not in input_act.meta
or not input_act.meta["quantization_annotation"]._annotated
or input_act.meta["quantization_annotation"].output_qspec is None
):
continue
# input and output of maxpool will share quantization parameter with input of maxpool
act_qspec = SharedQuantizationSpec(input_act)
@@ -663,9 +681,11 @@ def _annotate_input_out_obs_sharing_op(

# only annotate input output sharing operator
# when the output of the input node is annotated
if "quantization_annotation" not in input_act.meta or \
not input_act.meta["quantization_annotation"]._annotated or \
input_act.meta["quantization_annotation"].output_qspec is None:
if (
"quantization_annotation" not in input_act.meta
or not input_act.meta["quantization_annotation"]._annotated
or input_act.meta["quantization_annotation"].output_qspec is None
):
continue

act_qspec = SharedQuantizationSpec(input_act)
31 changes: 31 additions & 0 deletions torch/ao/quantization/_pt2e/quantizer/utils.py
@@ -1,3 +1,5 @@
from typing import List

import torch
from torch.ao.quantization._pt2e.quantizer.quantizer import (
QuantizationAnnotation,
@@ -79,3 +81,32 @@ def _annotate_output_qspec(node: Node, qspec):
)
quantization_annotation.output_qspec = qspec
node.meta["quantization_annotation"] = quantization_annotation


def _is_sym_size_node(node: Node):
return node.op == "call_function" and node.target == torch.ops.aten.sym_size


def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
"""
This utility is used to handle cases when dynami_shape=True tracing leads
to symint nodes in the pattern of linear module. In those cases, we need to
distinguish between the nodes that are in input for just extracting value of
some dimentions (and symint nodes) vs. the one that is activation.
For example:
graph(x, y, weight):
size_0 = torch.ops.aten.sym_size([x], [0])
size_1 = torch.ops.aten.sym_size([y], [1])
view_size = size_0 * size_1
size_3 = torch.ops.aten.sym_size([x], [2])
vie_out = torch.ops.aten.view(x, [view_size, size_3])
return mm(view_out, weight)
In the example above y node is not actual input. It exist only to extract size_1
"""
if _is_sym_size_node(node):
return True

return all(
((user not in partition_nodes) or _is_sym_size_node(user))
for user in node.users
)
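
As a quick, hedged illustration of the new helper's behavior (not part of the diff), here is a hand-built toy FX graph along the lines of the docstring example. The graph is constructed manually rather than traced, and the import path is the one added in this file (it may differ in later releases):

```python
# Hedged usage sketch: `y` is consumed only by a sym_size node, while `x`
# feeds real compute, so only `y` is reported as "only used for sym_size".
import torch
import torch.fx as fx
from torch.ao.quantization._pt2e.quantizer.utils import _node_only_used_for_sym_size

g = fx.Graph()
x = g.placeholder("x")
y = g.placeholder("y")
weight = g.placeholder("weight")
size_1 = g.call_function(torch.ops.aten.sym_size, (y, 1))       # shape-only use of y
view = g.call_function(torch.ops.aten.view, (x, [size_1, -1]))  # real compute on x
out = g.call_function(torch.ops.aten.mm, (view, weight))
g.output(out)

partition = [size_1, view, out]
print(_node_only_used_for_sym_size(x, partition))  # False: x feeds aten.view
print(_node_only_used_for_sym_size(y, partition))  # True: y only feeds aten.sym_size
```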
