Skip quantization tests running from BaseTestQuantizePT2EQAT_ConvBn (pytorch#114829)

Summary: This is a follow-up to D51428979. These tests should run only from `TestQuantizePT2EQAT_ConvBn1d` and `TestQuantizePT2EQAT_ConvBn2d`; the base class doesn't have the setup needed to run them, so they are expected to fail there. I previously ignored those failures on D51428979, and the failing tests have since been disabled.
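The mechanics of the fix can be sketched in plain `unittest` (a minimal hypothetical model, not the PyTorch code — class and test names here are invented for illustration). Test discovery instantiates every `TestCase` subclass, including the base class, so `setUp` inspects `self.id()` and skips when the run originates from the base class itself:

```python
import unittest

class TestConvBn_Base(unittest.TestCase):
    """Hypothetical base class holding shared tests."""

    def setUp(self):
        # self.id() looks like "module.ClassName.test_name"; if the class
        # name contains "_Base", discovery is running the base class
        # directly, which lacks the per-dimension setup, so skip.
        if self.id() and "_Base" in self.id():
            self.skipTest("Skipping test running from base class")

    def test_fusion(self):
        # Stand-in for a real shared test body.
        self.assertTrue(True)

class TestConvBn1d(TestConvBn_Base):
    # Inherits test_fusion; its id() has no "_Base", so it runs normally.
    pass

# Collect and run both classes, as discovery would.
loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
for cls in (TestConvBn_Base, TestConvBn1d):
    suite.addTests(loader.loadTestsFromTestCase(cls))
result = unittest.TestResult()
suite.run(result)
```

Running this collects two tests: the base-class copy is recorded as skipped, while the subclass copy executes — the same split the Test Plan below demonstrates with buck.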

Test Plan:
Run an example test and confirm that the two versions from `TestQuantizePT2EQAT_ConvBn1d` and `TestQuantizePT2EQAT_ConvBn2d` run while the one from `BaseTestQuantizePT2EQAT_ConvBn` is skipped:

```
$ buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/quantization:test_quantization -- --run-disabled 'caffe2/test/quantization:test_quantization - test_qat_conv_bn_fusion_literal_args'
File changed: fbcode//caffe2/test/quantization/pt2e/test_quantize_pt2e_qat.py
↷ Skip: caffe2/test/quantization:test_quantization - test_qat_conv_bn_fusion_literal_args (caffe2.test.quantization.pt2e.test_quantize_pt2e_qat.BaseTestQuantizePT2EQAT_ConvBn) (0.0s)

/data/users/huydo/fbsource/buck-out/v2/gen/fbcode/689edf96bfbb5738/caffe2/test/quantization/__test_quantization__/test_quantization#link-tree/torch/_utils_internal.py:230: NCCL_DEBUG env var is set to None
/data/users/huydo/fbsource/buck-out/v2/gen/fbcode/689edf96bfbb5738/caffe2/test/quantization/__test_quantization__/test_quantization#link-tree/torch/_utils_internal.py:239: NCCL_DEBUG is WARN from /etc/nccl.conf
INFO:2023-11-29 19:20:33 3049620:3049620 CuptiActivityProfiler.cpp:225] CUDA versions. CUPTI: 18; Runtime: 12000; Driver: 12000
/data/users/huydo/fbsource/buck-out/v2/gen/fbcode/689edf96bfbb5738/caffe2/test/quantization/__test_quantization__/test_quantization#link-tree/torch/_utils_internal.py:158: DeprecationWarning: This is a NOOP in python >= 3.7, its just too dangerous with how we write code at facebook. Instead we patch os.fork and multiprocessing which can raise exceptions if a deadlock would happen.
  threadSafeForkRegisterAtFork()
test_qat_conv_bn_fusion_literal_args (caffe2.test.quantization.pt2e.test_quantize_pt2e_qat.BaseTestQuantizePT2EQAT_ConvBn) ... skipped 'Skipping test running from BaseTestQuantizePT2EQAT_ConvBn'

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK (skipped=1)

Skipped: Skipping test running from BaseTestQuantizePT2EQAT_ConvBn

Buck UI: https://www.internalfb.com/buck2/7b70fb33-44cb-4745-92e1-64031bb413b8
Test UI: https://www.internalfb.com/intern/testinfra/testrun/6473924660765251
Network: Up: 12KiB  Down: 0B  (reSessionID-0399f0c3-e671-4770-a41c-75c06ae709d5)
Jobs completed: 11. Time elapsed: 1:07.2s.
Cache hits: 0%. Commands: 1 (cached: 0, remote: 0, local: 1)
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 1. Build failure 0
```

Differential Revision: D51694959

Pull Request resolved: pytorch#114829
Approved by: https://github.com/clee2000
huydhn authored and pytorchmergebot committed Dec 1, 2023
1 parent d6c0d1b commit 5687285
Showing 1 changed file with 10 additions and 3 deletions.
```diff
--- a/test/quantization/pt2e/test_quantize_pt2e_qat.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -347,11 +347,18 @@ def _verify_symmetric_xnnpack_qat_graph_helper(
         self.assertEqual(eps, 1e-5)
 
 
-class BaseTestQuantizePT2EQAT_ConvBn(PT2EQATTestCase):
+class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     """
     Base TestCase to be used for all conv-bn[-relu] fusion patterns.
     """
 
+    def setUp(self):
+        # NB: Skip the test if this is a base class, this is to handle the test
+        # discovery logic in buck which finds and runs all tests here including
+        # the base class which we don't want to run
+        if self.id() and "_Base" in self.id():
+            self.skipTest("Skipping test running from base class")
+
     def test_qat_conv_no_bias(self):
         m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
         m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
@@ -759,15 +766,15 @@ def test_qat_per_channel_weight_custom_dtype(self):
 
 # TODO: enable this in the next PR
 @skipIfNoQNNPACK
-class TestQuantizePT2EQAT_ConvBn1d(BaseTestQuantizePT2EQAT_ConvBn):
+class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
     dim = 1
     example_inputs = (torch.randn(1, 3, 5),)
     conv_class = torch.nn.Conv1d
     bn_class = torch.nn.BatchNorm1d
 
 
 @skipIfNoQNNPACK
-class TestQuantizePT2EQAT_ConvBn2d(BaseTestQuantizePT2EQAT_ConvBn):
+class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
     dim = 2
     example_inputs = (torch.randn(1, 3, 5, 5),)
     conv_class = torch.nn.Conv2d
```
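As a design note: an alternative pattern some test suites use (not what this commit does, since buck discovery here collects every `TestCase` subclass and a runtime skip was the targeted fix) is to keep the shared tests in a mixin that does not inherit from `TestCase`, so discovery never instantiates the base at all. A hypothetical sketch, with invented names:

```python
import unittest

class ConvBnFusionTestsMixin:
    """Hypothetical mixin: not a TestCase, so never collected on its own."""

    def test_fusion(self):
        # Shared test body; relies on the concrete class defining `dim`.
        self.assertIn(self.dim, (1, 2))

class TestConvBn1d(ConvBnFusionTestsMixin, unittest.TestCase):
    dim = 1

class TestConvBn2d(ConvBnFusionTestsMixin, unittest.TestCase):
    dim = 2

# Only the concrete classes are collected; no skip machinery is needed.
loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
for cls in (TestConvBn1d, TestConvBn2d):
    suite.addTests(loader.loadTestsFromTestCase(cls))
result = unittest.TestResult()
suite.run(result)
```

With this layout every collected test runs; the trade-off is that the mixin cannot call `TestCase` helpers until it is mixed into a concrete class, which can be awkward when the base needs substantial shared setup, as `PT2EQATTestCase` provides here.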
