Commit

[ao][sparsity][fx] make sparse prepare->quant prepare compose (pytorch#81993)

Summary: The primary issue was that fusion and matching had to be updated to handle parametrized modules.
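
A minimal sketch of why the plain type() checks used by fusion and matching break once a module is parametrized. The FakeSparsity class below is an illustrative stand-in for the mask parametrization that the sparsifier's prepare() attaches, not the exact class used by torch.ao:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

class FakeSparsity(nn.Module):
    """Illustrative mask parametrization: multiplies the weight by a fixed mask."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return self.mask * weight

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, "weight", FakeSparsity(torch.ones(4, 4)))

# register_parametrization swaps the instance's class for a dynamically
# created subclass, so an exact type comparison no longer matches nn.Linear ...
print(type(linear) is nn.Linear)                          # False
# ... while type_before_parametrizations recovers the original class, which is
# what fusion and pattern matching compare against after this change.
print(type_before_parametrizations(linear) is nn.Linear)  # True
```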

Test Plan: python test/test_ao_sparsity.py TestFxComposability

Pull Request resolved: pytorch#81993
Approved by: https://github.com/jerryzh168
HDCharles authored and pytorchmergebot committed Jul 27, 2022
1 parent d537f86 commit 8d82367
Showing 3 changed files with 50 additions and 5 deletions.
47 changes: 46 additions & 1 deletion test/ao/sparsity/test_composability.py
@@ -328,7 +328,7 @@ class TestFxComposability(TestCase):
"""
def test_q_prep_fx_before_s_prep(self):
r"""
This test checks that the ordering of prepare_fx, sparse prepare and convert_fx
This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
compose cleanly without issue and that the final result is sparsified without
having to call squash mask between sparse prepare and convert_fx. This also tests the
automatic fusion that occurs during prepare_fx.
@@ -385,3 +385,48 @@ def test_q_prep_fx_before_s_prep(self):
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_q_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_fx -> convert_fx
        compose cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = tq.QConfigMapping() \
            .set_module_name("4", qconfig) \
            .set_module_name("5", qconfig)
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(fqn_to_module(mod, "5"), torch.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
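
The new test above boils down to the following flow; here is a condensed, hedged sketch of the same composition outside the test harness. The Mask class stands in for the mask parametrization a real sparsifier's prepare() attaches, and the module names, shapes, and the global qconfig mapping are illustrative assumptions rather than the exact fixtures returned by _get_model_and_sparsifier_and_sparse_config():

```python
import torch
import torch.ao.quantization as tq
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.nn.utils import parametrize

# Stand-in for sparse prepare: attach a mask parametrization to the Linear
# weight, the same kind of state a real sparsifier leaves behind.
class Mask(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.register_buffer("mask", torch.ones(shape))

    def forward(self, weight):
        return self.mask * weight

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
parametrize.register_parametrization(model[0], "weight", Mask(model[0].weight.shape))

# prepare_fx still fuses Linear -> ReLU and inserts observers, because fusion
# and matching now look at the module's pre-parametrization type.
example = torch.randn(1, 4)
qconfig_mapping = tq.QConfigMapping().set_global(tq.get_default_qconfig("fbgemm"))
model = prepare_fx(model, qconfig_mapping, (example,))

model(example)             # calibrate
model = convert_fx(model)  # convert without squashing the mask first

# With this change, submodule "0" is expected to convert to a quantized
# LinearReLU even though its weight was still parametrized at convert time.
print(type(getattr(model, "0")))
```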
4 changes: 2 additions & 2 deletions torch/ao/quantization/fx/fusion_patterns.py
@@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, Optional, Union, List
from .custom_config import FuseCustomConfig
from .match_utils import MatchAllNode

from torch.nn.utils.parametrize import type_before_parametrizations

__all__ = [
"DefaultFuseHandler",
@@ -91,7 +91,7 @@ def get_matched_types(m):
            if isinstance(m, tuple):
                return tuple(map(get_matched_types, m))
            if isinstance(m, torch.nn.Module):
-               return type(m)
+               return type_before_parametrizations(m)
            return m

matched_module_types = get_matched_types(matched_modules)
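
For context, a tiny sketch of the pattern-type extraction that this fusion change affects. The helper below mirrors get_matched_types from the hunk above (lifted to module level for illustration); the nn.Identity parametrization is just a placeholder for whatever the sparsifier registers. With the old type() call, the resulting tuple would contain a dynamically created ParametrizedLinear class and the (Linear, ReLU) fusion pattern would no longer match:

```python
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

# Simplified version of get_matched_types: map matched modules
# (possibly nested tuples) to the types the fusion patterns expect.
def get_matched_types(m):
    if isinstance(m, tuple):
        return tuple(map(get_matched_types, m))
    if isinstance(m, nn.Module):
        return type_before_parametrizations(m)
    return m

linear, relu = nn.Linear(4, 4), nn.ReLU()
parametrize.register_parametrization(linear, "weight", nn.Identity())
# Even with the weight parametrized, the matched types are still (Linear, ReLU),
# so the Linear-ReLU fusion pattern keeps matching.
print(get_matched_types((linear, relu)) == (nn.Linear, nn.ReLU))  # True
```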
4 changes: 2 additions & 2 deletions torch/ao/quantization/fx/match_utils.py
@@ -17,7 +17,7 @@
from .graph_module import (
is_observed_standalone_module,
)

from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set


@@ -58,7 +58,7 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize):
    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
-       if not type(modules[node.target]) == self_match:
+       if not type_before_parametrizations(modules[node.target]) == self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
