[ao][sparsity][fx] make quant prepare -> sparse prepare compose (pytorch#81992)

Summary: sparse prepare automatically composes with quantization prepare, even in cases with fusion.
However, the convert step needed to be updated to handle parametrized modules (the sparsifier attaches
its masks via torch.nn.utils.parametrize, so convert now resolves module classes with
type_before_parametrizations).
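
A minimal sketch of the flow this enables (not part of the commit; adapted from the new
TestFxComposability test added below, so the toy model, input shape, and the "0.0.weight" fqn
are assumptions based on that test, where prepare_fx auto-fuses Linear + ReLU and nests the
original Linear one level deeper):

import torch
import torch.ao.quantization as tq
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao import sparsity

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
example = torch.randn(1, 4)

# quantization prepare first; prepare_fx auto-fuses Linear + ReLU under module "0"
qconfig_mapping = tq.QConfigMapping().set_module_name("0", tq.get_default_qconfig("fbgemm"))
model = prepare_fx(model, qconfig_mapping, (example,))

# sparse prepare second; because of fusion the Linear weight now lives at "0.0.weight"
sparsifier = sparsity.WeightNormSparsifier(
    sparsity_level=0.7, sparse_block_shape=(1, 4), zeros_per_block=4)
sparsifier.prepare(model, config=[{"tensor_fqn": "0.0.weight"}])  # attaches mask parametrizations

sparsifier.step()          # update the sparsity masks
model(example)             # calibrate observers
model = convert_fx(model)  # convert handles the still-parametrized weights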

Test Plan: python test/test_ao_sparsity.py TestFxComposability

Pull Request resolved: pytorch#81992
Approved by: https://github.com/jerryzh168
HDCharles authored and pytorchmergebot committed Jul 27, 2022
1 parent e0faa02 commit 8533951
Showing 3 changed files with 141 additions and 64 deletions.
187 changes: 129 additions & 58 deletions test/ao/sparsity/test_composability.py
@@ -9,6 +9,8 @@
from torch import nn
from torch.ao import sparsity
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.sparsity import fqn_to_module

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -20,50 +22,47 @@
"zeros_per_block": 4,
}

def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
model = nn.Sequential(
nn.Linear(4, 4), # 0
nn.ReLU(),
nn.Linear(4, 4), # 2
nn.ReLU(),
tq.QuantStub(),
nn.Linear(4, 4), # 5
nn.ReLU(),
tq.DeQuantStub(),
)
if qconfig:
model[4].qconfig = qconfig
model[5].qconfig = qconfig

sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults)

sparse_config = [
{
"tensor_fqn": '5.weight',
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
},
{"tensor_fqn": "0.weight"},
]
return model, sparsifier, sparse_config

def _squash_mask_calibrate_and_convert(model, sparsifier, input):
sparsifier.step()
sparsifier.squash_mask()
model(input)
tq.convert(model, inplace=True)

def _calculate_sparsity(tensor):
return ((tensor == 0).sum() / tensor.numel()).item()

# This series of tests is to check the composability goals for sparsity and quantization. Namely
# that performing quantization and sparsity model manipulations in various orderings
# does not cause problems
class TestComposability(TestCase):
def _get_model_and_sparsifier_and_sparse_config(self, qconfig=None):
model = nn.Sequential(
nn.Linear(4, 4), # 0
nn.ReLU(),
nn.Linear(4, 4), # 2
nn.ReLU(),
tq.QuantStub(),
nn.Linear(4, 4), # 5
nn.ReLU(),
tq.DeQuantStub(),
)
if qconfig is None:
model[4].qconfig = tq.get_default_qconfig("fbgemm")
model[5].qconfig = tq.get_default_qconfig("fbgemm")
else:
model[4].qconfig = qconfig
model[5].qconfig = qconfig

sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults)

sparse_config = [
{
"tensor_fqn": '5.weight',
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
},
{"tensor_fqn": "0.weight"},
]
return model, sparsifier, sparse_config

def _squash_mask_calibrate_and_convert(self, model, sparsifier, input):
sparsifier.step()
sparsifier.squash_mask()
model(input)
tq.convert(model, inplace=True)

def _calculate_sparsity(self, tensor):
return ((tensor == 0).sum() / tensor.numel()).item()

# This test checks whether performing quantization prepare before sparse prepare
# causes any issues and verifies that the correct observers are inserted and that
# the quantized model works as expected
@@ -72,7 +71,7 @@ def test_q_prep_before_s_prep(self):
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

tq.prepare(mod, inplace=True)
sparsifier.prepare(mod, config=sparse_config)
@@ -83,7 +82,7 @@ def test_q_prep_before_s_prep(self):
# check that correct observers were inserted
self.assertTrue(hasattr(mod[5], "activation_post_process"))

self._squash_mask_calibrate_and_convert(
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)

@@ -101,7 +100,7 @@ def test_s_prep_before_q_prep(self):
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

sparsifier.prepare(mod, config=sparse_config)
tq.prepare(mod, inplace=True)
@@ -115,7 +114,7 @@ def test_s_prep_before_q_prep(self):
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))

self._squash_mask_calibrate_and_convert(
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)

@@ -132,7 +131,7 @@ def test_convert_without_squash_mask(self):
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))

sparsifier.prepare(mod, config=sparse_config)
tq.prepare(mod, inplace=True)
@@ -146,7 +145,7 @@ def test_convert_without_squash_mask(self):
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
sparsifier.step()
sparsity_level = self._calculate_sparsity(mod[5].weight)
sparsity_level = _calculate_sparsity(mod[5].weight)
mod(torch.randn(1, 4, 4, 4))
tq.convert(mod, inplace=True)

@@ -155,7 +154,7 @@ def test_convert_without_squash_mask(self):
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
self.assertGreaterAlmostEqual(
sparsity_level, sparse_config[0]["sparsity_level"]
@@ -170,7 +169,7 @@ def test_s_prep_before_fusion(self):
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
sparsifier.prepare(mod, config=sparse_config)
tq.fuse_modules(mod, [["5", "6"]], inplace=True)
mod[5].qconfig = tq.get_default_qconfig("fbgemm")
@@ -184,7 +183,7 @@ def test_s_prep_before_fusion(self):
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self._squash_mask_calibrate_and_convert(
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)

@@ -199,7 +198,7 @@ def test_fusion_before_s_prep(self):
mod,
sparsifier,
_,
) = self._get_model_and_sparsifier_and_sparse_config()
) = _get_model_and_sparsifier_and_sparse_config(tq.get_default_qconfig("fbgemm"))
tq.fuse_modules(mod, [["5", "6"]], inplace=True)

# it's absolutely broken by fusion but will still work if you put the correct fqn in
@@ -226,7 +225,7 @@ def test_fusion_before_s_prep(self):
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
sparsifier.step()
sparsity_level = self._calculate_sparsity(mod[5][0].weight)
sparsity_level = _calculate_sparsity(mod[5][0].weight)
mod(torch.randn(1, 4, 4, 4))
tq.convert(mod, inplace=True)

@@ -235,7 +234,7 @@ def test_fusion_before_s_prep(self):
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
self.assertGreaterAlmostEqual(
sparsity_level, sparse_config[0]["sparsity_level"]
@@ -251,7 +250,7 @@ def test_s_prep_before_qat_prep(self):
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config(
) = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qat_qconfig("fbgemm")
)
sparsifier.prepare(mod, config=sparse_config)
@@ -263,20 +262,20 @@ def test_s_prep_before_qat_prep(self):
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))
self._squash_mask_calibrate_and_convert(
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

# This tests whether performing qat prepare before sparse prepare causes issues.
def test_qat_prep_before_s_prep(self):
mod, sparsifier, _ = self._get_model_and_sparsifier_and_sparse_config(
mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
tq.get_default_qat_qconfig("fbgemm")
)
tq.prepare_qat(mod, inplace=True)
@@ -303,7 +302,7 @@ def test_qat_prep_before_s_prep(self):
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))

self._squash_mask_calibrate_and_convert(
_squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)

@@ -312,5 +311,77 @@ def test_qat_prep_before_s_prep(self):
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

def _module_has_activation_post_process(model, fqn_of_module):
for node in model.graph.nodes:
# look for an observer whose arg is the target module
if "activation_post_process" in node.name:
if node.args[0].target == fqn_of_module:
return True
return False

class TestFxComposability(TestCase):
r"""This series of tests checks that various steps of the quantization and sparsity flow
compose cleanly despite variation in sequencing.
"""
def test_q_prep_fx_before_s_prep(self):
r"""
This test checks that the ordering of prepare_fx, sparse prepare and convert_fx
composes cleanly without issue and that the final result is sparsified without
having to call squash mask between sparse prepare and convert_fx. This also tests the
automatic fusion that occurs during prepare_fx.
"""
(
mod,
sparsifier,
_,
) = _get_model_and_sparsifier_and_sparse_config()

example = torch.randn(1, 4, 4, 4)
qconfig = tq.get_default_qconfig("fbgemm")
qconfig_mapping = tq.QConfigMapping() \
.set_module_name("4", qconfig) \
.set_module_name("5", qconfig)


mod = prepare_fx(mod, qconfig_mapping, (example,))

# it's absolutely broken by auto fusion in fx
# but will still work if you put the correct fqn in
sparse_config = [
{
"tensor_fqn": "5.0.weight",
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
},
{"tensor_fqn": "0.0.weight"},
]
sparsifier.prepare(mod, config=sparse_config)

# check that correct modules had parametrizations added and
# that none were lost during prepare
self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(_module_has_activation_post_process(mod, "5"))
sparsifier.step()
sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
mod(example)
mod = convert_fx(mod)

# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU))
self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

# check that module was actually sparsified
cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
self.assertGreaterAlmostEqual(
sparsity_level, sparse_config[0]["sparsity_level"]
)
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
1 change: 1 addition & 0 deletions test/test_ao_sparsity.py
@@ -23,6 +23,7 @@

# Composability
from ao.sparsity.test_composability import TestComposability # noqa: F401
from ao.sparsity.test_composability import TestFxComposability # noqa: F401

# Utilities
from ao.sparsity.test_sparsity_utils import TestSparsityUtilFunctions # noqa: F401
17 changes: 11 additions & 6 deletions torch/ao/quantization/fx/convert.py
@@ -45,6 +45,7 @@
is_observed_standalone_module,
)
from ._equalize import update_obs_for_equalization, convert_eq_obs
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import (
get_custom_module_class_keys,
get_quantize_node_info,
@@ -461,8 +462,10 @@ def convert_weighted_module(
# root_module_to_quantized_reference_module: module mapping from root (floating point) module class
# to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
ref_qmodule_cls = root_module_to_quantized_reference_module.get(type(float_module), None)
assert ref_qmodule_cls is not None, f"No reference quantized module class configured for {type(float_module)}"
ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
assert (
ref_qmodule_cls is not None
), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined]
if fused_module is not None:
fused_module[0] = ref_qmodule # type: ignore[operator]
@@ -757,16 +760,18 @@ def replace_observer_with_dequantize_node(node: Node, graph: Graph):
elif is_observed_standalone_module(modules[node.target]):
convert_standalone_module(
node, modules, model, is_reference, backend_config_dict)
elif type(modules[node.target]) in set(
# below this point `type_before_parametrizations` is used
# instead of `type` to handle situations with fx quant + sparsity
elif type_before_parametrizations(modules[node.target]) in set(
root_module_classes).union(qat_module_classes).union(fused_module_classes):
# extra check for fused module classes to make sure they are fused module classes
# of target modules
if type(modules[node.target]) in fused_module_classes and \
type(modules[node.target][0]) not in root_module_classes:
if type_before_parametrizations(modules[node.target]) in fused_module_classes and \
type_before_parametrizations(modules[node.target][0]) not in root_module_classes:
continue
convert_weighted_module(
node, modules, observed_node_names, qconfig_map, backend_config_dict)
elif type(modules[node.target]) in custom_module_classes:
elif type_before_parametrizations(modules[node.target]) in custom_module_classes:
convert_custom_module(
node, model.graph, modules, custom_module_class_mapping,
statically_quantized_custom_module_nodes)
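
To illustrate the convert-side change above: the sparsifier attaches its masks through
torch.nn.utils.parametrize, which swaps the module's class for a dynamically generated subclass,
so lookups keyed on plain type() stop matching. A standalone sketch of the behavior
type_before_parametrizations handles (the Identity parametrization below is only a stand-in for
the sparsifier's real mask parametrization):

import torch
from torch import nn
from torch.nn.utils import parametrize

class Identity(nn.Module):
    # trivial parametrization used only to trigger the class swap
    def forward(self, x):
        return x

lin = nn.Linear(4, 4)
parametrize.register_parametrization(lin, "weight", Identity())

print(type(lin) is nn.Linear)                                       # False: class was swapped
print(parametrize.type_before_parametrizations(lin) is nn.Linear)   # True: original class recovered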
