[quant][pt2e] Add fold_quantize=True for all convert_pt2e calls (pytorch#117797)

Summary: In preparation for enabling fold_quantize=True by default

Test Plan: CI

Differential Revision: D52879612

Pull Request resolved: pytorch#117797
Approved by: https://github.com/andrewor14
jerryzh168 authored and pytorchmergebot committed Jan 24, 2024
1 parent 90b3cf3 commit af1ebc4
Showing 3 changed files with 28 additions and 31 deletions.
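For context, the calls being updated sit in the standard PT2E workflow. Below is a minimal sketch (not part of this commit) with fold_quantize passed explicitly; the toy model, quantizer choice, and example inputs are illustrative assumptions:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

m = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
example_inputs = (torch.randn(1, 4),)

m = capture_pre_autograd_graph(m, example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate
# With fold_quantize=True, weight quantize ops are constant-folded, so the
# converted graph carries pre-quantized weight constants (get_attr nodes).
m = convert_pt2e(m, fold_quantize=True)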
2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_duplicate_dq.py
@@ -110,7 +110,7 @@ def _test_duplicate_dq(
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
for n in m.graph.nodes:
@@ -94,6 +94,6 @@ def test_quantize_pt2e_preserve_handle(self):
debug_handle_map = _extract_conv2d_pattern_debug_handle_map(m)
self.assertEqual(debug_handle_map, debug_handle_map_ref)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
debug_handle_map = _extract_conv2d_pattern_debug_handle_map(m)
self.assertEqual(debug_handle_map, debug_handle_map_ref)
55 changes: 26 additions & 29 deletions test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -161,7 +161,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper(
if verify_convert:
# We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
- model_pt2e = convert_pt2e(model_pt2e)
+ model_pt2e = convert_pt2e(model_pt2e, fold_quantize=True)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
model_fx = _convert_to_reference_decomposed_fx(
@@ -631,7 +631,7 @@ def forward(self, x):
m = capture_pre_autograd_graph(m, example_inputs)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

# Extract the conv and relu nodes (bn was folded into conv)
first_conv, first_relu, second_conv, second_relu = None, None, None, None
@@ -649,24 +649,26 @@ def forward(self, x):
# Extract the conv weight and bias nodes
def get_conv_weight_and_bias(conv_node: torch.fx.Node):
weight_dq_node = conv_node.args[1]
- weight_q_node = weight_dq_node.args[0]
- weight_node = weight_q_node.args[0]
+ qweight_node = weight_dq_node.args[0]
bias_node = conv_node.args[2]
- assert isinstance(weight_node, torch.fx.Node)
+ assert isinstance(qweight_node, torch.fx.Node)
assert isinstance(bias_node, torch.fx.Node)
- return (weight_node, bias_node)
+ return (qweight_node, bias_node)

- first_conv_weight, first_conv_bias = get_conv_weight_and_bias(first_conv)
- second_conv_weight, second_conv_bias = get_conv_weight_and_bias(second_conv)
+ first_conv_qweight, first_conv_bias = get_conv_weight_and_bias(first_conv)
+ second_conv_qweight, second_conv_bias = get_conv_weight_and_bias(second_conv)

# Assert that each set of conv, conv weight, and conv bias are in the same partition
def get_source_fn(node: torch.fx.Node):
# E.g. [('l__self___backbone1_conv', <class 'torch.nn.modules.conv.Conv2d'>)]
return node.meta["source_fn_stack"][0][0]

- self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_weight))
+ # we don't preserve this for the quantized weight currently since it's folded,
+ # but the user can attach "quantization_tag" to the node and it will be preserved
+ # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
+ # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))

self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
- self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_weight))
self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))

# Assert that different sets of convs and relus have different partitions
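The hunk above reflects that, after folding, the conv weight argument is a dequantize node whose input is a get_attr constant (the pre-quantized weight) rather than a quantize_per_* call. A standalone sketch of the lookup the updated helper performs, assuming a conv node taken from a converted PT2E graph:

import torch

def get_folded_conv_weight(conv_node: torch.fx.Node) -> torch.fx.Node:
    weight_dq = conv_node.args[1]  # dequantize_per_* node feeding the conv
    qweight = weight_dq.args[0]    # folded, pre-quantized weight constant
    assert isinstance(qweight, torch.fx.Node) and qweight.op == "get_attr"
    return qweight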
@@ -688,7 +690,7 @@ def test_qat_conv_bn_bias_derived_qspec(self):
quantizer = ConvBnDerivedBiasQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
m(*example_inputs)

# Assert that both weight and bias are quantized
@@ -703,15 +705,15 @@ def test_qat_conv_bn_bias_derived_qspec(self):
bias_dq.target,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
)
- weight_q = weight_dq.args[0]
- bias_q = bias_dq.args[0]
+ weight_getattr = weight_dq.args[0]
+ bias_getattr = bias_dq.args[0]
self.assertEqual(
- weight_q.target,
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ weight_getattr.op,
+ "get_attr",
)
self.assertEqual(
- bias_q.target,
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
+ bias_getattr.op,
+ "get_attr",
)

# Assert that bias scale = weight scale * input scale
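As a worked illustration of the derived qspec checked by the (collapsed) assertions that follow, the bias scale is simply the product of the input and weight scales; the values below are made up:

input_scale = 0.02    # activation scale observed at the conv input
weight_scale = 0.005  # per-tensor weight scale
bias_scale = input_scale * weight_scale  # 1e-4, the scale the derived qspec assigns to the bias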
@@ -735,7 +737,7 @@ def test_qat_per_channel_weight_custom_dtype(self):
quantizer = ConvBnInt32WeightQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
m(*example_inputs)

# Assert that conv weight is quantized per channel
@@ -745,23 +747,18 @@ def test_qat_per_channel_weight_custom_dtype(self):
weight_dq.target,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
)
- weight_q = weight_dq.args[0]
+ weight_getattr = weight_dq.args[0]
self.assertEqual(
- weight_q.target,
- torch.ops.quantized_decomposed.quantize_per_channel.default,
+ weight_getattr.op,
+ "get_attr",
)

- # Assert that args for the weight's quantize and dequantize ops
+ # Assert that args for the weight's dequantize ops
# are copied correctly after subgraph rewriting
- (q_axis, q_qmin, q_qmax, q_dtype) = weight_q.args[3:]
(dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
- self.assertEqual(q_axis, 0)
self.assertEqual(dq_axis, 0)
- self.assertEqual(q_qmin, 0)
self.assertEqual(dq_qmin, 0)
- self.assertEqual(q_qmax, 2**31 - 1)
self.assertEqual(dq_qmax, 2**31 - 1)
- self.assertEqual(q_dtype, torch.int32)
self.assertEqual(dq_dtype, torch.int32)
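For reference, a self-contained sketch of the per-channel quantize/dequantize pair the folded int32 weight relies on; the argument order (input, scales, zero_points, axis, qmin, qmax, dtype) is assumed from the args[3:] unpacking above, and the tensor shapes and values are made up:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers the quantized_decomposed ops)

w = torch.randn(2, 3)
scales = torch.tensor([0.1, 0.2])
zero_points = torch.zeros(2, dtype=torch.int32)
qw = torch.ops.quantized_decomposed.quantize_per_channel(
    w, scales, zero_points, 0, 0, 2**31 - 1, torch.int32
)
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
    qw, scales, zero_points, 0, 0, 2**31 - 1, torch.int32
)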


@@ -1002,7 +999,7 @@ def test_mixing_qat_ptq(self):
quantizer.set_global(quantization_config)
model_pt2e = prepare_pt2e(model_pt2e, quantizer)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
- model_pt2e = convert_pt2e(model_pt2e)
+ model_pt2e = convert_pt2e(model_pt2e, fold_quantize=True)
quant_result_pt2e = model_pt2e(*example_inputs)

exported_model = torch.export.export(model_pt2e, example_inputs)
@@ -1012,7 +1009,7 @@ def test_mixing_qat_ptq(self):
# 3 x linear: 1 for act, 1 for output
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
- ): 9,
+ ): 8,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 9,
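The expected quantize_per_tensor count drops from 9 to 8, consistent with a weight quantize op being constant-folded out of the graph. A hedged sketch (the helper name is made up) of tallying quantized_decomposed ops in an exported graph, in the spirit of the node_occurrence check above:

from collections import Counter
import torch

def count_q_dq_ops(gm: torch.fx.GraphModule) -> Counter:
    counts = Counter()
    for n in gm.graph.nodes:
        if n.op == "call_function" and "quantized_decomposed" in str(n.target):
            counts[n.target] += 1
    return counts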
