[quant] Remove deprecated torch.jit.quantized APIs (pytorch#118406)
The `torch.jit.quantized` interface has been deprecated since pytorch#40102 (June 2020).

BC-breaking message:

All functions and classes under `torch.jit.quantized` now raise an error when
called or instantiated. This API has long been deprecated in favor of
`torch.ao.nn.quantized`.
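
For code that previously called `torch.jit.quantized.quantize_linear_modules`, a
minimal migration sketch (illustrative only, not part of this change) uses
`torch.ao.quantization.quantize_dynamic`, the maintained dynamic-quantization entry
point; the toy model and shapes below are placeholder assumptions.

import torch
import torch.ao.quantization
import torch.nn as nn

# Toy float model standing in for a network with nn.Linear layers.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()

# Previously (now raises a RuntimeError):
#   qmodel = torch.jit.quantized.quantize_linear_modules(model)

# Supported replacement: dynamically quantize Linear weights to int8
# (pass dtype=torch.float16 for the fp16 variant).
qmodel = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# The dynamically quantized module can still be traced into TorchScript.
traced = torch.jit.trace(qmodel, torch.randn(2, 16))
out = traced(torch.randn(2, 16))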

Pull Request resolved: pytorch#118406
Approved by: https://github.com/jerryzh168
peterbell10 authored and pytorchmergebot committed Jan 27, 2024
1 parent d03173e commit 1460334
Showing 3 changed files with 51 additions and 907 deletions.
46 changes: 8 additions & 38 deletions test/jit/test_models.py
@@ -16,7 +16,6 @@
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.common_quantization import skipIfNoFBGEMM

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -305,7 +304,7 @@ def test_reinforcement_learning_cuda(self):
self._test_reinforcement_learning(self, device='cuda', test_export_import=False)

@staticmethod
def _test_snli(self, device, check_export_import=True, quantized=False):
def _test_snli(self, device, check_export_import=True):
class Bottle(nn.Module):

def forward(self, input):
@@ -392,27 +391,13 @@ class Config:
premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)

if quantized:
snli = SNLIClassifier(Config()).cpu()
torch.jit.quantized.quantize_linear_modules(snli)
# we don't do export/import checks because we would need to call
# _pack/_unpack
self.checkTrace(snli, (premise, hypothesis), inputs_require_grads=False,
export_import=False)
else:
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
inputs_require_grads=False, export_import=check_export_import)
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
inputs_require_grads=False, export_import=check_export_import)

@slowTest
def test_snli(self):
self._test_snli(self, device='cpu')

@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_snli_quantized(self):
self._test_snli(self, device='cpu', quantized=True)

@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
@@ -504,7 +489,7 @@ def forward(self, input):
export_import=False)

@staticmethod
def _test_vae(self, device, check_export_import=True, quantized=False):
def _test_vae(self, device, check_export_import=True):
class VAE(nn.Module):
def __init__(self):
super().__init__()
@@ -536,29 +521,14 @@ def forward(self, x):
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar

if quantized:
vae = VAE().to(device).eval()
torch.jit.quantized.quantize_linear_modules(vae)
# We don't do export/import checks because we would need to call
# _unpack and _pack
self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
export_import=False, allow_unused=True,
inputs_require_grads=False)
else:
with enable_profiling_mode_for_profiling_tests():
# eval() is present because randn_like makes this nondeterministic
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
export_import=check_export_import)
with enable_profiling_mode_for_profiling_tests():
# eval() is present because randn_like makes this nondeterministic
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
export_import=check_export_import)

def test_vae(self):
self._test_vae(self, device='cpu')

@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_vae_quantized(self):
self._test_vae(self, device='cpu', quantized=True)

@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
147 changes: 10 additions & 137 deletions test/quantization/jit/test_deprecated_jit_quant.py
@@ -4,11 +4,8 @@
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM
)
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase

from typing import Tuple
import copy

class TestDeprecatedJitQuantized(JitTestCase):
@skipIfNoFBGEMM
@@ -54,54 +51,8 @@ def test_rnn_cell_quantized(self):
torch.tensor(vals, dtype=torch.float),
requires_grad=False)

ref = copy.deepcopy(cell)

cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
x = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float)
h0_vals = [[-155, 100],
[-155, 155],
[100, -155]]
hx = torch.tensor(h0_vals, dtype=torch.float)
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
cx = torch.tensor(h0_vals, dtype=torch.float)
hiddens = (hx, cx)
else:
hiddens = hx

if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell

@torch.jit.script_method
def forward(self, x: torch.Tensor,
hiddens: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.cell(x, hiddens)
else:

class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell

@torch.jit.script_method
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
return self.cell(x, hiddens)

cell = ScriptWrapper(cell)
outs = cell(x, hiddens)
cell = self.getExportImportCopyWithPacking(cell)

outs = cell(x, hiddens)
ref_outs = ref(x, hiddens)

self.assertEqual(len(outs), len(ref_outs))
for out, ref_out in zip(outs, ref_outs):
torch.testing.assert_close(out, ref_out)
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_cell_modules function is no longer supported"):
cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)

@skipIfNoFBGEMM
def test_rnn_quantized(self):
@@ -143,85 +94,14 @@ def test_rnn_quantized(self):
torch.tensor(vals, dtype=torch.float),
requires_grad=False)

ref = copy.deepcopy(cell)
cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)

niter = 10
x = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
h0_vals = [[-155, 100],
[-155, 155],
[100, -155]]
hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)

if isinstance(ref, torch.nn.LSTM):
hiddens = (hx, cx)
elif isinstance(ref, torch.nn.GRU):
hiddens = hx

ref_out, ref_hid = ref(x, hiddens)
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"):
cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)

# Compare int8 quantized to unquantized
output_int8, final_hiddens_int8 = cell_int8(x, hiddens)
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"):
cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)

torch.testing.assert_close(output_int8, ref_out)
for out, ref in zip(final_hiddens_int8, ref_hid):
torch.testing.assert_close(out, ref)

# Compare fp16 quantized to unquantized
output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)

torch.testing.assert_close(output_fp16, ref_out)
for out, ref in zip(final_hiddens_fp16, ref_hid):
torch.testing.assert_close(out, ref)

def compare_quantized_unquantized(ScriptWrapper, cell):
wrapper = ScriptWrapper(cell)

# Compare quantize scripted module to unquantized
script_out, script_hid = wrapper(x, hiddens)
torch.testing.assert_close(script_out, ref_out)
for out, ref in zip(script_hid, ref_hid):
torch.testing.assert_close(out, ref)

# Compare export/import to unquantized
export_import_wrapper = self.getExportImportCopyWithPacking(wrapper)
ei_out, ei_hid = export_import_wrapper(x, hiddens)
torch.testing.assert_close(ei_out, ref_out)
for out, ref in zip(ei_hid, ref_hid):
torch.testing.assert_close(out, ref)

if isinstance(cell, torch.jit.quantized.QuantizedGRU):
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell

@torch.jit.script_method
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self.cell(x, hiddens)

compare_quantized_unquantized(ScriptWrapper, cell)
elif isinstance(cell, torch.jit.quantized.QuantizedLSTM):
for cell in [cell_int8, cell_fp16]:
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell

@torch.jit.script_method
def forward(self, x, hiddens):
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
# -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
return self.cell(x, hiddens)
compare_quantized_unquantized(ScriptWrapper, cell)

if 'fbgemm' in torch.backends.quantized.supported_engines:
# Suppression: using deprecated quant api
@suppress_warnings
def test_quantization_modules(self):
K1, N1 = 2, 2

@@ -244,18 +124,11 @@ def forward(self, x):

y_ref = fb(value)

fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)
traced_int8 = torch.jit.trace(fb_int8, (x,))
fb_int8 = self.getExportImportCopyWithPacking(traced_int8)
y_int8 = fb_int8(value)

fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)
traced_fp16 = torch.jit.trace(fb_fp16, (x,))
fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16)
y_fp16 = fb_fp16(value)
with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"):
fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)

torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3)
torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"):
fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)

@skipIfNoFBGEMM
def test_erase_class_tensor_shapes(self):
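
For the RNN paths exercised by these tests, a comparable sketch (again illustrative,
not taken from this PR) of the supported replacement for
`torch.jit.quantized.quantize_rnn_modules` / `quantize_rnn_cell_modules`:

import torch
import torch.ao.quantization
import torch.nn as nn

# Float LSTM standing in for the RNN modules exercised above.
lstm = nn.LSTM(input_size=2, hidden_size=2, num_layers=1).eval()

# Previously (both entry points now raise a RuntimeError):
#   torch.jit.quantized.quantize_rnn_modules(lstm, dtype=torch.int8)
#   torch.jit.quantized.quantize_rnn_cell_modules(nn.LSTMCell(2, 2))

# Supported replacement: dynamic quantization of the LSTM weights.
qlstm = torch.ao.quantization.quantize_dynamic(lstm, {nn.LSTM}, dtype=torch.qint8)

x = torch.randn(3, 1, 2)   # (seq_len, batch, input_size)
h0 = torch.zeros(1, 1, 2)  # (num_layers, batch, hidden_size)
c0 = torch.zeros(1, 1, 2)
out, (hn, cn) = qlstm(x, (h0, c0))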