[quant][fx] Add an option in convert_fx to accept qconfig_dict to skip quantization (pytorch#66878)

Summary:
Pull Request resolved: pytorch#66878

Currently, convert_fx quantizes every layer that was prepared, as determined by the prepare qconfig_dict.
This PR adds support for passing a variation of that qconfig_dict to convert_fx, which can be used to skip quantizing certain layers.

This makes it possible to prepare/observe all operators once and then quantize only a subset of them (for example, based on quantization error), instead of re-running prepare for each subset.

The qconfig_dict passed to convert_fx may only have its values set to `None`, and its keys follow the same format as those allowed in the prepare qconfig_dict.
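
For illustration, a minimal sketch of the intended workflow (the toy model, the module name "0", and the specific qconfig entries below are assumptions for the example, not part of this PR; the only new API surface is the `qconfig_dict` argument to `convert_fx`):

```python
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Toy model: two linear layers (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.Linear(5, 5)).eval()

# Prepare/observe everything once.
prepare_qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, prepare_qconfig_dict)
prepared(torch.rand(5, 5))  # calibrate

# At convert time, skip quantization for a specific submodule by mapping it to None;
# all other entries must match what was passed at prepare time.
convert_qconfig_dict = {
    "": get_default_qconfig("fbgemm"),
    "module_name": [("0", None)],  # keep the first Linear in fp32
}
quantized = convert_fx(prepared, qconfig_dict=convert_qconfig_dict)
```

In this sketch, only the second Linear would be converted to a quantized module; the first one stays in fp32 without having to re-run prepare.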

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_convert_qconfig_dict

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D31808247

fbshipit-source-id: a4f5dca1090f0083fc3fea14aff56924033eb24f
supriyar authored and facebook-github-bot committed Oct 23, 2021
1 parent d13829e commit 8460fa5
Showing 6 changed files with 291 additions and 63 deletions.
94 changes: 94 additions & 0 deletions test/quantization/fx/test_quantize_fx.py
@@ -3182,6 +3182,100 @@ def forward(self, x):
# checking result match
self.assertEqual(out_ref, out)

def test_convert_qconfig_dict(self):
class Linear(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.ones(5, 5)
self.b = torch.zeros(5)

def forward(self, x):
return torch.nn.functional.linear(x, self.w, self.b)


class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods1 = torch.nn.Sequential(
Linear(),
Linear()
)
self.mods3 = torch.nn.Linear(5, 5)

def forward(self, x):
x = self.mods1(x)
x = torch.add(x, 4)
z = torch.mul(x, 5)
x = self.mods3(z)
return x

model = M().train()

for check in ["module_name", "object_type"]:
qconfig_dict = {"": None,
"object_type": [
(nn.functional.linear, get_default_qat_qconfig("fbgemm")),
(torch.add, get_default_qat_qconfig("fbgemm")),
(nn.Linear, get_default_qat_qconfig("fbgemm")),
],
}
prepared = prepare_qat_fx(model, qconfig_dict)
prepared(torch.rand(5, 5))
if check == "module_name":
convert_qconfig_dict = {"": None,
"object_type": [
(nn.functional.linear, get_default_qat_qconfig("fbgemm")),
(torch.add, get_default_qat_qconfig("fbgemm")),
(nn.Linear, get_default_qat_qconfig("fbgemm")),
],
"module_name": [("mods1.0", None)]}

node_occurrence = {
ns.call_function(torch.quantize_per_tensor): 2,
ns.call_function(torch.nn.functional.linear): 1,
ns.call_function(torch.ops.quantized.linear): 1,
ns.call_function(torch.ops.quantized.add): 1,
ns.call_method("dequantize"): 2
}
order_check = [
ns.call_function(torch.nn.functional.linear),
ns.call_function(torch.quantize_per_tensor),
ns.call_function(torch.ops.quantized.linear),
ns.call_function(torch.ops.quantized.add),
ns.call_method("dequantize"),
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Linear),
ns.call_method("dequantize"),
]
elif check == "object_type":
convert_qconfig_dict = {"": None,
"object_type": [
(nn.functional.linear, get_default_qat_qconfig("fbgemm")),
(torch.add, get_default_qat_qconfig("fbgemm")),
(nn.Linear, None),
]}

node_occurrence = {
ns.call_function(torch.quantize_per_tensor): 1,
ns.call_function(torch.ops.quantized.linear): 2,
ns.call_function(torch.ops.quantized.add): 1,
ns.call_method("dequantize"): 1
}
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_function(torch.ops.quantized.linear),
ns.call_function(torch.ops.quantized.linear),
ns.call_function(torch.ops.quantized.add),
ns.call_method("dequantize"),
ns.call_module(nn.Linear),
]

converted = convert_fx(prepared, qconfig_dict=convert_qconfig_dict)
converted(torch.rand(5, 5))
self.checkGraphModuleNodes(
converted,
expected_node_occurrence=node_occurrence,
expected_node_list=order_check)

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
64 changes: 62 additions & 2 deletions torch/ao/quantization/fx/convert.py
@@ -1,5 +1,6 @@
from typing import Any, Dict, Tuple, List, Callable, Optional, Union
from collections import defaultdict
import copy
import torch
from torch.fx import (
GraphModule,
@@ -12,7 +13,7 @@
)
from torch.fx.node import Argument
from .quantization_types import Pattern
from ..qconfig import QConfigAny
from ..qconfig import QConfigAny, qconfig_equals
from .match_utils import (
find_matches,
)
@@ -24,6 +25,13 @@
from .quantization_patterns import (
QuantizeHandler,
)
from .qconfig_utils import (
convert_dict_to_ordered_dict,
generate_qconfig_map,
compare_prepare_convert_qconfig_dict,
update_qconfig_for_fusion,
update_qconfig_for_qat,
)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from .utils import (
is_get_tensor_info_node,
@@ -46,6 +54,9 @@
)

from .lower_to_fbgemm import lower_to_fbgemm
from ..quantization_mappings import (
DEFAULT_QAT_MODULE_MAPPINGS,
)

# weight prepacking ops
WEIGHT_PREPACK_OPS = {
@@ -131,6 +142,24 @@ def load_arg(a):
quantized = QuantizedGraphModule(quantized_root, folded_graph, quantized_root.preserved_attr_names)
return quantized

def remove_quant_dequant_pairs(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
quantized_root = quantized
for node in quantized.graph.nodes:
if node.op == "call_function" and node.target in [torch.quantize_per_tensor, torch.quantize_per_channel]:
users = list(node.users)
user = users[0] if users else None
if len(users) == 1 and user.op == "call_method" and user.target == "dequantize":
user.replace_all_uses_with(node.args[0])
quantized.graph.erase_node(user)
orig_args = list(node.args)
quantized.graph.erase_node(node)
for arg in orig_args:
if isinstance(arg, Node) and len(list(arg.users)) == 0:
quantized.graph.erase_node(arg)

quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized

def restore_state(
observed: GraphModule
) -> Tuple[Dict[Pattern, QuantizeHandler], Dict[str, Tuple[str, type]], Dict[str, Any]]:
@@ -145,7 +174,8 @@ def restore_state(
def convert(model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
_remove_qconfig_flag: bool = True,
convert_qconfig_dict: Dict[str, Any] = None) -> QuantizedGraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
@@ -178,6 +208,29 @@ def convert(model: GraphModule, is_reference: bool = False,
# the same activation_post_process module instance but different names
modules = dict(model.named_modules(remove_duplicate=False))

# TODO refactor this code once we update the prepare logic to have additional information on
# which graph nodes have been observed and share that with convert to decide which observers to ignore.
if convert_qconfig_dict:
prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict # type: ignore[assignment]
modules_copy = copy.deepcopy(modules)
convert_dict_to_ordered_dict(convert_qconfig_dict)
if model._is_training:
additional_qat_module_mapping = prepare_custom_config_dict.get(
"additional_qat_module_mapping", {})
convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, additional_qat_module_mapping)
convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)

compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict) # type: ignore
convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
# Check the generated convert_qconfig_map to ensure that every value either matches what was set in the prepare qconfig_map
# or is set to None in the convert_qconfig_map.
for k, v in qconfig_map.items():
assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
if convert_qconfig_map[k] is not None:
assert qconfig_equals(v, convert_qconfig_map[k]), 'Expected k {} to have the same value in prepare qconfig_dict \
and convert qconfig_dict, found {} updated to {}.'.format(k, v, convert_qconfig_map[k])
qconfig_map = convert_qconfig_map

custom_module_classes = get_custom_module_class_keys(
convert_custom_config_dict,
"observed_to_quantized_custom_module_class")
@@ -445,6 +498,12 @@ def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> Non
result = quantized_graph.node_copy(
node, load_non_quantized)
quantized = False
# If there are QAT-swapped modules in the graph that we don't want to quantize, revert them back to FP32 modules.
if node.op == 'call_module' and type(modules[node.target]) in DEFAULT_QAT_MODULE_MAPPINGS.values():
float_mod = modules[node.target].to_float()
setattr(model, node.name, float_mod)
with model.graph.inserting_before(node):
new_float_node = model.graph.create_node('call_module', node.name, node.args, node.kwargs)
else:
assert obj is not None
# We will get whether the output is quantized or not before
@@ -538,4 +597,5 @@ def load_arg_remove(a: Argument) -> Argument:
if not is_reference:
model = fold_weight(model, node_name_to_scope)
model = lower_to_fbgemm(model)
model = remove_quant_dequant_pairs(model)
return model
4 changes: 3 additions & 1 deletion torch/ao/quantization/fx/graph_module.py
@@ -30,7 +30,9 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, p
'_qconfig_map',
'_prepare_custom_config_dict',
'_equalization_qconfig_map',
'_node_name_to_scope']).union(preserved_attr_names)
'_node_name_to_scope',
'_qconfig_dict',
'_is_training']).union(preserved_attr_names)
preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
super().__init__(root, graph)
for attr in preserved_attrs:
64 changes: 8 additions & 56 deletions torch/ao/quantization/fx/prepare.py
@@ -16,11 +16,13 @@
from ..observer import (
ObserverBase,
)
from ..qconfig import QConfigAny, qconfig_equals
from ..qconfig import QConfigAny
from .qconfig_utils import (
convert_dict_to_ordered_dict,
generate_qconfig_map,
get_flattened_qconfig_dict,
update_qconfig_for_fusion,
update_qconfig_for_qat,
)

from .quantization_patterns import (
@@ -61,8 +63,6 @@
BIAS_INDEX_DICT,
)

from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD

from ..quantization_mappings import (
get_default_qat_module_mappings,
)
@@ -146,58 +146,6 @@ def qat_swap_modules(
get_default_qat_module_mappings(), additional_qat_module_mapping)
convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)

def update_qconfig_for_qat(
qconfig_dict: Any,
additional_qat_module_mapping: Dict[Callable, Callable]
) -> Any:
"""
Update the qconfig_dict to account for module swaps during QAT.
During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
"""
all_qat_mappings = get_combined_dict(
get_default_qat_module_mappings(), additional_qat_module_mapping)
object_type_dict = qconfig_dict.get("object_type", None)
new_object_type_dict = object_type_dict.copy()
for k, v in new_object_type_dict.items():
if k in all_qat_mappings:
object_type_dict[all_qat_mappings[k]] = v
return qconfig_dict

def update_qconfig_for_fusion(
model: GraphModule,
qconfig_dict: Any,
) -> Any:
"""
Update the qconfig_dict to account for fused modules such as LinearReLU.
"""
object_type_dict = qconfig_dict.get("object_type", None)
if object_type_dict is None:
return qconfig_dict

modules = dict(model.named_modules())

for node in model.graph.nodes:
if node.op == 'call_module':
module_type = type(modules[str(node.target)])
if module_type not in list(DEFAULT_OP_LIST_TO_FUSER_METHOD.values()):
continue

for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items():
if module_type == fuser:
fused_qconfig = object_type_dict.get(ops[0], None)

# Raise an error if the modules in the fused module have
# different qconfigs specified in the qconfig_dict
for op in ops:
if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig):
raise LookupError("During fusion, we need to specify the same " +
f"qconfigs for both modules in {module_type}.")

if fused_qconfig is not None:
object_type_dict[module_type] = fused_qconfig

return qconfig_dict

def insert_observer(
node: Node,
observed_op: Node,
@@ -1091,13 +1039,17 @@ def save_state(
patterns: Dict[Pattern, QuantizeHandler],
prepare_custom_config_dict: Dict[str, Any],
equalization_qconfig_map: Dict[str, Any],
qconfig_dict: Dict[str, Dict[Any, Any]],
is_training: bool,
) -> None:
observed._patterns = patterns # type: ignore[assignment]
observed._qconfig_map = qconfig_map # type: ignore[assignment]
observed._prepare_custom_config_dict = \
prepare_custom_config_dict # type: ignore[assignment]
observed._node_name_to_scope = node_name_to_scope # type: ignore[assignment]
observed._equalization_qconfig_map = equalization_qconfig_map # type: ignore[assignment]
observed._qconfig_dict = qconfig_dict # type: ignore[assignment]
observed._is_training = is_training # type: ignore[assignment]

def prepare(
model: GraphModule,
@@ -1209,7 +1161,7 @@ def prepare(
input_quantized_idxs, output_quantized_idxs)

save_state(model, qconfig_map, node_name_to_scope, patterns,
prepare_custom_config_dict, equalization_qconfig_map)
prepare_custom_config_dict, equalization_qconfig_map, qconfig_dict, model.training)
preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
model = ObservedGraphModule(model, model.graph, preserved_attributes)
if is_standalone_module: