Skip to content

Commit

Permalink
[Fix] mse observer uses per_channel affine (ModelTC#150)
Browse files (browse the repository at this point in the history)
(cherry picked from commit d4abf393b017ecb9b470a9824d2a5c0397f082e8)

Co-authored-by: fanyunqian <[email protected]>
  • Loading branch information
PannenetsF and fanyunqian authored Aug 10, 2022
1 parent c57a9a2 commit 9335ab3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 1 addition & 3 deletions mqbench/fake_quantize/adaround_quantizer.py
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
import torch
from torch.nn.parameter import Parameter

from mqbench.fake_quantize.quantize_base import QuantizeBase
from mqbench.fake_quantize.quantize_base import QuantizeBase, _version_under_1100
from mqbench.utils.hook import PerChannelLoadHook

_version_under_1100 = int(torch.__version__.split('.')[1]) < 10

def _rectified_sigmoid(alpha, zeta, gamma):
"""Function to generate rounding mask.
Expand Down
1 change: 1 addition & 0 deletions mqbench/fake_quantize/quantize_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@

from mqbench.utils import is_symmetric_quant

# True when the installed torch is older than 1.10; importers use this flag to
# pick version-specific code paths. Compare (major, minor) as a tuple: looking
# at the minor component alone misclassifies torch 2.x (minor 0-9) as
# "under 1.10". Only the first two dotted fields are parsed, so suffixes such
# as "1.10.0+cu113" do not matter.
_version_under_1100 = tuple(
    int(part) for part in torch.__version__.split('.')[:2]
) < (1, 10)

class QuantizeBase(FakeQuantizeBase):
r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
Expand Down
5 changes: 3 additions & 2 deletions mqbench/observer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch.quantization.observer import _ObserverBase

from mqbench.fake_quantize.quantize_base import _version_under_1100
from mqbench.utils import sync_tensor, pot_quantization, is_symmetric_quant
from mqbench.utils.logger import logger
from mqbench.utils.hook import PerChannelLoadHook
Expand Down Expand Up @@ -523,7 +524,7 @@ def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tens
new_max = x_max * (1.0 - (i * 0.01))
scale, zero_point = self._calculate_qparams(new_min, new_max)
x_q = torch.fake_quantize_per_channel_affine(
x, scale, zero_point.long(), ch_axis,
x, scale, zero_point.long() if _version_under_1100 else zero_point, ch_axis,
self.quant_min, self.quant_max)
score = self.lp_loss(x_q, x, reduce_dim)
update_idx = (score < best_score)
Expand Down Expand Up @@ -602,7 +603,7 @@ def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tens
new_max = x_max * (1.0 - (i * 0.01))
scale, zero_point = self._calculate_qparams(new_min, new_max)
x_q = torch.fake_quantize_per_channel_affine(
x, scale, zero_point.long(), ch_axis,
x, scale, zero_point.long() if _version_under_1100 else zero_point, ch_axis,
self.quant_min, self.quant_max)
score = self.lp_loss(x_q, x, reduce_dim)
update_idx = (score < best_score)
Expand Down

0 comments on commit 9335ab3

Please sign in to comment.