FixedQParamsFakeQuantize: adjust default quant_min and quant_max (pytorch#47423)

Summary:
Pull Request resolved: pytorch#47423

Since the dtype of this fake_quant is `quint8`, the output range should be
from 0 to 255; fix the defaults accordingly. This should address the numerical
inaccuracies of sigmoid and hardsigmoid with `FixedQParamsFakeQuantize` attached,
compared to their quantized counterparts.
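
To make the range issue concrete, here is a small illustration (not part of this PR)
using `torch.fake_quantize_per_tensor_affine` with the fixed qparams used for sigmoid
(scale = 1/256, zero_point = 0): with the old `quant_min=-128, quant_max=127` defaults,
every sigmoid output above 127/256 ≈ 0.496 is clamped, while `quant_min=0, quant_max=255`
covers the full (0, 1) output range up to 255/256.

```python
import torch

x = torch.linspace(-6.0, 6.0, steps=9)
y = torch.sigmoid(x)  # values in (0, 1)

# Old (incorrect) integer range for a quint8 fake quant:
# round(y / scale) is clamped to at most 127, so y saturates at 127/256 ~= 0.496.
y_old = torch.fake_quantize_per_tensor_affine(y, 1.0 / 256.0, 0, -128, 127)

# Corrected range matching quint8: the full sigmoid output range is representable.
y_new = torch.fake_quantize_per_tensor_affine(y, 1.0 / 256.0, 0, 0, 255)

print(y)
print(y_old)  # saturates around 0.496
print(y_new)  # tracks y up to 255/256 ~= 0.996
```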

In a future PR, it might be safer to also make the activation functions
that use `FixedQParamsFakeQuantize` explicitly specify their expected
output range and zero_point. Leaving that for later, as this bugfix
should land urgently.

Test Plan:
Manual script which gives low SQNR before this PR and high SQNR after
this PR: https://gist.github.com/vkuzo/9906bae29223da72b10d6b6aafadba42
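
For reference, a minimal sketch of the kind of comparison such a script might make
(the gist is the authoritative version; the `sqnr` helper below is an assumption,
defined as 10·log10 of signal power over error power):

```python
import torch

def sqnr(ref: torch.Tensor, approx: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB (assumed definition).
    err = ref - approx
    return (10 * torch.log10(ref.pow(2).mean() / err.pow(2).mean())).item()

torch.manual_seed(0)
y = torch.sigmoid(torch.randn(10_000))
# Before this PR: the fixed-qparams fake quant for sigmoid clamps at 127/256.
y_before = torch.fake_quantize_per_tensor_affine(y, 1.0 / 256.0, 0, -128, 127)
# After this PR: full quint8 range.
y_after = torch.fake_quantize_per_tensor_affine(y, 1.0 / 256.0, 0, 0, 255)

print(sqnr(y, y_before))  # low: outputs above ~0.5 are clamped
print(sqnr(y, y_after))   # high: only rounding error remains
```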

pytorch#47376, which can be landed after
this, adds a proper test.

Imported from OSS

Reviewed By: ayush29feb, jerryzh168

Differential Revision: D24751497

fbshipit-source-id: 4c32e22a30116caaceeedb4cd47146d066054a89
vkuzo authored and facebook-github-bot committed Nov 5, 2020
1 parent 745899f commit 5977d1d
Showing 2 changed files with 32 additions and 3 deletions.
test/quantization/test_quantize.py (26 additions, 0 deletions)
@@ -1250,6 +1250,28 @@ def test_leaky_relu(self):
 
 
 class TestEagerModeQATOps(QuantizationTestCase):
+    def _test_activation_impl(self, Act, data):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.act = Act()
+                self.quant = QuantStub()
+                self.dequant = DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.act(x)
+                x = self.dequant(x)
+                return x
+
+        m = M().train()
+        m.qconfig = default_qat_qconfig
+        m = prepare_qat(m)
+        before_convert = m(data)
+        m = convert(m)
+        after_convert = m(data)
+        self.assertEqual(before_convert, after_convert)
+
     def test_fixed_qparam_ops(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -1273,7 +1295,11 @@ def forward(self, x):
         m = prepare_qat(m)
         for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
             self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
+        data = torch.randn(1, 3, 2, 4)
+        before_convert = m(data)
         m = convert(m)
+        after_convert = m(data)
+        self.assertEqual(before_convert, after_convert)
         # make sure activation post process is removed
         for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
             # verify fake quant module is removed
torch/quantization/fake_quantize.py (6 additions, 3 deletions)
@@ -204,7 +204,7 @@ class FixedQParamsFakeQuantize(FakeQuantizeBase):
     def __init__(self,
                  scale,
                  zero_point,
-                 dtype,
+                 dtype=torch.quint8,
                  qscheme=torch.per_tensor_affine,
                  quant_min=0,
                  quant_max=255):
@@ -243,10 +243,13 @@ def extra_repr(self):
                                             dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
 default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
+
+# TODO(future PR): remove these defaults and enforce activation functions
+# to explicitly specify their output range
 default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
-    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8)
+    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
 default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
-    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=-128, quant_max=127)
+    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
 
 default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                quant_min=-128,
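
As a sanity check on the new defaults, a hedged usage sketch: this assumes the import path
of the file edited above and that the module is callable on a tensor, as `FakeQuantizeBase`
subclasses are; it constructs the fake quant directly from the `__init__` signature shown
in the diff rather than via the `default_affine_fixed_qparams_fake_quant` partial.

```python
import torch
from torch.quantization.fake_quantize import FixedQParamsFakeQuantize

# Affine fixed-qparams fake quant with the corrected quint8 range.
fq = FixedQParamsFakeQuantize(scale=1.0 / 256.0, zero_point=0,
                              dtype=torch.quint8, quant_min=0, quant_max=255)

y = torch.sigmoid(torch.randn(8))
print(fq(y))  # fake-quantized onto the 256-level grid covering [0, 255/256]
```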
