diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index dfa53550fc84f..996824c0be9fd 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -347,11 +347,18 @@ def _verify_symmetric_xnnpack_qat_graph_helper( self.assertEqual(eps, 1e-5) -class BaseTestQuantizePT2EQAT_ConvBn(PT2EQATTestCase): +class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): """ Base TestCase to be used for all conv-bn[-relu] fusion patterns. """ + def setUp(self): + # NB: Skip the test if this is a base class, this is to handle the test + # discovery logic in buck which finds and runs all tests here including + # the base class which we don't want to run + if self.id() and "_Base" in self.id(): + self.skipTest("Skipping test running from base class") + def test_qat_conv_no_bias(self): m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True) m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False) @@ -759,7 +766,7 @@ def test_qat_per_channel_weight_custom_dtype(self): # TODO: enable this in the next PR @skipIfNoQNNPACK -class TestQuantizePT2EQAT_ConvBn1d(BaseTestQuantizePT2EQAT_ConvBn): +class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): dim = 1 example_inputs = (torch.randn(1, 3, 5),) conv_class = torch.nn.Conv1d @@ -767,7 +774,7 @@ class TestQuantizePT2EQAT_ConvBn1d(BaseTestQuantizePT2EQAT_ConvBn): @skipIfNoQNNPACK -class TestQuantizePT2EQAT_ConvBn2d(BaseTestQuantizePT2EQAT_ConvBn): +class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base): dim = 2 example_inputs = (torch.randn(1, 3, 5, 5),) conv_class = torch.nn.Conv2d