Skip to content

Commit

Permalink
Upgrade pytorch-quantization to v2.1.1
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
rajeevsrao committed Sep 22, 2021
1 parent 02bc8e3 commit 4834ab5
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 103 deletions.
2 changes: 1 addition & 1 deletion tools/pytorch-quantization/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0
2.1.1
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,17 @@ def __init__(self,
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)

if quantize:
self.conv1 = quant_nn.QuantConv2d(3,
self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False)
else:
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)

self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
Expand All @@ -245,7 +255,11 @@ def __init__(self,
dilate=replace_stride_with_dilation[2],
quantize=quantize)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)

if quantize:
self.fc = quant_nn.QuantLinear(512 * block.expansion, num_classes)
else:
self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
Expand Down Expand Up @@ -295,7 +309,8 @@ def _make_layer(self,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer))
norm_layer=norm_layer,
quantize=quantize))

return nn.Sequential(*layers)

Expand Down
76 changes: 56 additions & 20 deletions tools/pytorch-quantization/pytorch_quantization/calib/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,19 @@ class HistogramCalibrator(_Calibrator):
num_bins: An integer. Number of histograms bins. Default 2048.
grow_method: A string. DEPRECATED. default None.
skip_zeros: A boolean. If True, skips zeros when collecting data for histogram. Default False.
torch_hist: A boolean. If True, collect histogram by torch.histc instead of np.histogram. If input tensor
is on GPU, histc will also be running on GPU. Default False.
"""
def __init__(self, num_bits, axis, unsigned, num_bins=2048, grow_method=None, skip_zeros=False):
def __init__(self, num_bits, axis, unsigned, num_bins=2048, grow_method=None, skip_zeros=False, torch_hist=False):
super(HistogramCalibrator, self).__init__(num_bits, axis, unsigned)
self._num_bins = num_bins
self._skip_zeros = skip_zeros

self._calib_bin_edges = None
self._calib_hist = None

self._torch_hist = torch_hist

if axis is not None:
raise NotImplementedError("Calibrator histogram collection only supports per tensor scaling")

Expand All @@ -68,25 +72,50 @@ def collect(self, x):
"Make sure this is the right tensor to calibrate."),
1)
x = x.abs()
x_np = x.cpu().detach().numpy()

if self._skip_zeros:
x_np = x_np[np.where(x_np != 0)]
x = x.float()

if not self._torch_hist:
x_np = x.cpu().detach().numpy()

if self._skip_zeros:
x_np = x_np[np.where(x_np != 0)]

if self._calib_bin_edges is None and self._calib_hist is None:
# first time it uses num_bins to compute histogram.
self._calib_hist, self._calib_bin_edges = np.histogram(x_np, bins=self._num_bins)
if self._calib_bin_edges is None and self._calib_hist is None:
# first time it uses num_bins to compute histogram.
self._calib_hist, self._calib_bin_edges = np.histogram(x_np, bins=self._num_bins)
else:
temp_amax = np.max(x_np)
if temp_amax > self._calib_bin_edges[-1]:
# increase the number of bins
width = self._calib_bin_edges[1] - self._calib_bin_edges[0]
# NOTE: np.arange may create an extra bin after the one containing temp_amax
new_bin_edges = np.arange(self._calib_bin_edges[-1] + width, temp_amax + width, width)
self._calib_bin_edges = np.hstack((self._calib_bin_edges, new_bin_edges))
hist, self._calib_bin_edges = np.histogram(x_np, bins=self._calib_bin_edges)
hist[:len(self._calib_hist)] += self._calib_hist
self._calib_hist = hist
else:
temp_amax = np.max(x_np)
if temp_amax > self._calib_bin_edges[-1]:
# increase the number of bins
width = self._calib_bin_edges[1] - self._calib_bin_edges[0]
# NOTE: np.arange may create an extra bin after the one containing temp_amax
new_bin_edges = np.arange(self._calib_bin_edges[-1] + width, temp_amax + width, width)
self._calib_bin_edges = np.hstack((self._calib_bin_edges, new_bin_edges))
hist, self._calib_bin_edges = np.histogram(x_np, bins=self._calib_bin_edges)
hist[:len(self._calib_hist)] += self._calib_hist
self._calib_hist = hist
# This branch of code is designed to match numpy version as close as possible
with torch.no_grad():
if self._skip_zeros:
x = x[torch.where(x != 0)]

# Because we collect histogram on absolute value, setting min=0 simplifying the rare case where
# minimum value is not exactly 0 and first batch collected has larger min value than later batches
x_max = x.max()
if self._calib_bin_edges is None and self._calib_hist is None:
self._calib_hist = torch.histc(x, bins=self._num_bins, min=0, max=x_max)
self._calib_bin_edges = torch.linspace(0, x_max, self._num_bins + 1)
else:
if x_max > self._calib_bin_edges[-1]:
width = self._calib_bin_edges[1] - self._calib_bin_edges[0]
self._num_bins = int((x_max / width).ceil().item())
self._calib_bin_edges = torch.arange(0, x_max + width, width, device=x.device)

hist = torch.histc(x, bins=self._num_bins, min=0, max=self._calib_bin_edges[-1])
hist[:self._calib_hist.numel()] += self._calib_hist
self._calib_hist = hist

def reset(self):
"""Reset the collected histogram"""
Expand All @@ -108,14 +137,21 @@ def compute_amax(
Returns:
amax: a tensor
"""
if isinstance(self._calib_hist, torch.Tensor):
calib_hist = self._calib_hist.int().cpu().numpy()
calib_bin_edges = self._calib_bin_edges.cpu().numpy()
else:
calib_hist = self._calib_hist
calib_bin_edges = self._calib_bin_edges

if method == 'entropy':
calib_amax = _compute_amax_entropy(
self._calib_hist, self._calib_bin_edges, self._num_bits, self._unsigned, stride, start_bin)
calib_hist, calib_bin_edges, self._num_bits, self._unsigned, stride, start_bin)
elif method == 'mse':
calib_amax = _compute_amax_mse(
self._calib_hist, self._calib_bin_edges, self._num_bits, self._unsigned, stride, start_bin)
calib_hist, calib_bin_edges, self._num_bits, self._unsigned, stride, start_bin)
elif method == 'percentile':
calib_amax = _compute_amax_percentile(self._calib_hist, self._calib_bin_edges, percentile)
calib_amax = _compute_amax_percentile(calib_hist, calib_bin_edges, percentile)
else:
raise TypeError("Unknown calibration method {}".format(method))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self, num_bits=8, name=None, **kwargs):
self._scale_amax = kwargs.pop('scale_amax', None)
self._calib_method = kwargs.pop('calib_method', "max")
self._unsigned = kwargs.pop('unsigned', False)
self._narrow_range = kwargs.pop('narrow_range', True)
self._narrow_range = kwargs.pop('narrow_range', False)

if kwargs:
raise TypeError("Unused keys: {}".format(kwargs.keys()))
Expand Down
15 changes: 0 additions & 15 deletions tools/pytorch-quantization/scripts/onnx_export_per_channel.patch

This file was deleted.

57 changes: 0 additions & 57 deletions tools/pytorch-quantization/scripts/patch_onnx_export.sh

This file was deleted.

48 changes: 42 additions & 6 deletions tools/pytorch-quantization/tests/calibrator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,39 @@ def test_skip_zeros(self, verbose):
# amax should be close to 5
assert (amax - 5.).abs() < 10/2048

def test_torch_hist(self):
x_1 = torch.rand(1023, device="cuda")
x_1[0] = 0
x_2 = torch.rand(1023, device="cuda") + 1 # Make sure histogram bins need to be grown
x_2[1] = 0

calibrator_np = calib.HistogramCalibrator(8, None, False, num_bins=19, torch_hist=False)
calibrator_torch = calib.HistogramCalibrator(8, None, False, num_bins=19, torch_hist=True)

calibrator_np.collect(x_1)
calibrator_torch.collect(x_1)
assert calibrator_torch._calib_hist.numel() == calibrator_torch._calib_bin_edges.numel() - 1
np.testing.assert_array_equal(calibrator_np._calib_hist, calibrator_torch._calib_hist.cpu().numpy())
np.testing.assert_array_almost_equal(
calibrator_np._calib_bin_edges, calibrator_torch._calib_bin_edges.cpu().numpy())

# Test multiple collections with some of them needs to expand range
for _ in range(3):
calibrator_np.collect(x_2)
calibrator_torch.collect(x_2)
calibrator_np.collect(x_1)
calibrator_torch.collect(x_1)

# Test compute_amax function doesn't convert _calib_hist and _calib_bin_edges unnecessarily
calibrator_np.compute_amax("percentile", percentile=99.99)
calibrator_torch.compute_amax("percentile", percentile=99.99)

np.testing.assert_array_equal(calibrator_np._calib_hist, calibrator_torch._calib_hist.cpu().numpy())
np.testing.assert_array_almost_equal(
calibrator_np._calib_bin_edges, calibrator_torch._calib_bin_edges.cpu().numpy())
assert calibrator_torch._calib_hist.numel() == calibrator_torch._calib_bin_edges.numel() - 1


class TestEntropyCalibrator():

def test_one_tensor(self, verbose):
Expand Down Expand Up @@ -170,8 +203,9 @@ def test_unsigned(self, verbose):

assert amax < 1.1

def test_two_tensor(self, verbose):
hist_calibrator = calib.HistogramCalibrator(8, None, False, grow_method='stretch')
@pytest.mark.parametrize("torch_hist", [False, True])
def test_two_tensor(self, torch_hist, verbose):
hist_calibrator = calib.HistogramCalibrator(8, None, False, torch_hist=torch_hist)

x_2 = torch.rand(11, 7, 3, 3).cuda() # uniform in (0,1)
x_2[1, 1, 1, 1] = 10. # create outlier
Expand Down Expand Up @@ -227,8 +261,9 @@ def test_unsigned_one_tensor(self, verbose):
# amax should be closer to 512
assert (amax - 512.).abs() < (amax - 513.).abs()

def test_two_tensor(self, verbose):
calibrator = calib.HistogramCalibrator(8, None, False)
@pytest.mark.parametrize("torch_hist", [False, True])
def test_two_tensor(self, torch_hist, verbose):
calibrator = calib.HistogramCalibrator(8, None, False, torch_hist=torch_hist)

x_1 = torch.ones(11, 7, 3, 3).cuda() * 255.
x_1[1, 1, 1, 1] = 256. # create an outlier
Expand Down Expand Up @@ -278,8 +313,9 @@ def test_unsigned_one_tensor(self, verbose):
# amax should be approximately 79
assert (amax - 79.).abs() < 100/2048

def test_two_tensor(self, verbose):
calibrator = calib.HistogramCalibrator(8, None, False)
@pytest.mark.parametrize("torch_hist", [False, True])
def test_two_tensor(self, torch_hist, verbose):
calibrator = calib.HistogramCalibrator(8, None, False, torch_hist=torch_hist)

x_1 = torch.arange(100)
calibrator.collect(x_1)
Expand Down

0 comments on commit 4834ab5

Please sign in to comment.