Use new torch.ao.quantization instead of torch.quantization (pytorch#4554)

Co-authored-by: Vasilis Vryniotis <[email protected]>
NicolasHug and datumbox authored Nov 29, 2021
1 parent 1c9ccb7 commit b3cdec1
Showing 10 changed files with 46 additions and 46 deletions.
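For context: torch.quantization was deprecated in favor of torch.ao.quantization, which exposes the same eager-mode quantization API under the new torch.ao ("architecture optimization") namespace; the old module remains as an alias in recent releases. A minimal before/after sketch, assuming PyTorch >= 1.10 where the new namespace is available:

    import torch
    import torchvision

    model = torchvision.models.quantization.resnet18(pretrained=False, quantize=False)

    # Before: deprecated namespace
    import torch.quantization
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")

    # After: new namespace, same API surface
    import torch.ao.quantization
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")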
20 changes: 10 additions & 10 deletions references/classification/train_quantization.py
@@ -4,7 +4,7 @@
import time

import torch
-import torch.quantization
+import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
@@ -62,8 +62,8 @@ def main(args):

if not (args.test_only or args.post_training_quantize):
model.fuse_model()
-model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
-torch.quantization.prepare_qat(model, inplace=True)
+model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+torch.ao.quantization.prepare_qat(model, inplace=True)

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -96,12 +96,12 @@ def main(args):
)
model.eval()
model.fuse_model()
-model.qconfig = torch.quantization.get_default_qconfig(args.backend)
-torch.quantization.prepare(model, inplace=True)
+model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+torch.ao.quantization.prepare(model, inplace=True)
# Calibrate first
print("Calibrating")
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-torch.quantization.convert(model, inplace=True)
+torch.ao.quantization.convert(model, inplace=True)
if args.output_dir:
print("Saving quantized model")
if utils.is_main_process():
@@ -114,8 +114,8 @@ def main(args):
evaluate(model, criterion, data_loader_test, device=device)
return

-model.apply(torch.quantization.enable_observer)
-model.apply(torch.quantization.enable_fake_quant)
+model.apply(torch.ao.quantization.enable_observer)
+model.apply(torch.ao.quantization.enable_fake_quant)
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
@@ -126,7 +126,7 @@
with torch.inference_mode():
if epoch >= args.num_observer_update_epochs:
print("Disabling observer for subseq epochs, epoch = ", epoch)
-model.apply(torch.quantization.disable_observer)
+model.apply(torch.ao.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
print("Freezing BN for subseq epochs, epoch = ", epoch)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
@@ -136,7 +136,7 @@
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device("cpu"))
-torch.quantization.convert(quantized_eval_model, inplace=True)
+torch.ao.quantization.convert(quantized_eval_model, inplace=True)

print("Evaluate Quantized model")
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
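Taken together, the hunks above leave the reference script's two quantization flows unchanged except for the namespace. A condensed sketch of the post-training flow under the new names (illustrative model and random calibration input, assuming PyTorch >= 1.10 and torchvision's quantizable models):

    import torch
    import torch.ao.quantization
    import torchvision

    model = torchvision.models.quantization.resnet18(pretrained=False, quantize=False)
    model.eval()
    model.fuse_model()

    # Insert observers, run calibration data through, then convert to int8
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    torch.ao.quantization.prepare(model, inplace=True)
    with torch.inference_mode():
        model(torch.randn(4, 3, 224, 224))  # stand-in for a real calibration set
    torch.ao.quantization.convert(model, inplace=True)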
4 changes: 2 additions & 2 deletions references/classification/utils.py
@@ -345,8 +345,8 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
# Quantized Classification
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
-model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
-_ = torch.quantization.prepare_qat(model, inplace=True)
+model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
+_ = torch.ao.quantization.prepare_qat(model, inplace=True)
print(store_model_weights(model, './qat.pth'))
# Object Detection
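The docstring above rebuilds a QAT-prepared model so that trained weights can be loaded back into the same structure. For reference, a minimal QAT round trip under the new namespace might look like this (fine-tuning loop elided; same version assumptions as above):

    import torch
    import torch.ao.quantization
    import torchvision.models as M

    model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
    model.fuse_model()
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
    torch.ao.quantization.prepare_qat(model, inplace=True)

    # ... fine-tune here with fake-quant and observers enabled ...

    model.eval()
    torch.ao.quantization.convert(model.cpu(), inplace=True)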
10 changes: 5 additions & 5 deletions test/test_models.py
@@ -781,19 +781,19 @@ def test_quantized_classification_model(model_fn):
model = model_fn(**kwargs)
if eval_mode:
model.eval()
-model.qconfig = torch.quantization.default_qconfig
+model.qconfig = torch.ao.quantization.default_qconfig
else:
model.train()
-model.qconfig = torch.quantization.default_qat_qconfig
+model.qconfig = torch.ao.quantization.default_qat_qconfig

model.fuse_model()
if eval_mode:
-torch.quantization.prepare(model, inplace=True)
+torch.ao.quantization.prepare(model, inplace=True)
else:
-torch.quantization.prepare_qat(model, inplace=True)
+torch.ao.quantization.prepare_qat(model, inplace=True)
model.eval()

-torch.quantization.convert(model, inplace=True)
+torch.ao.quantization.convert(model, inplace=True)

try:
torch.jit.script(model)
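Besides the namespace swap, the test's final step is worth noting: converted eager-mode models are expected to remain TorchScript-scriptable. A standalone version of that check, under the same assumptions as the sketches above:

    import torch
    import torch.ao.quantization
    import torchvision.models as M

    model = M.quantization.resnet18(pretrained=False, quantize=False)
    model.eval()
    model.fuse_model()
    model.qconfig = torch.ao.quantization.default_qconfig
    torch.ao.quantization.prepare(model, inplace=True)
    model(torch.randn(1, 3, 224, 224))  # minimal calibration pass
    torch.ao.quantization.convert(model, inplace=True)
    scripted = torch.jit.script(model)  # should compile without errors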
6 changes: 3 additions & 3 deletions torchvision/models/quantization/googlenet.py
@@ -31,7 +31,7 @@ def forward(self, x: Tensor) -> Tensor:
return x

def fuse_model(self) -> None:
-torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInception(Inception):
@@ -74,8 +74,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__( # type: ignore[misc]
blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
)
-self.quant = torch.quantization.QuantStub()
-self.dequant = torch.quantization.DeQuantStub()
+self.quant = torch.ao.quantization.QuantStub()
+self.dequant = torch.ao.quantization.DeQuantStub()

def forward(self, x: Tensor) -> GoogLeNetOutputs:
x = self._transform_input(x)
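QuantStub and DeQuantStub mark the boundaries of the quantized region: convert replaces them with real quantize/dequantize ops, so tensors flowing through the layers in between are int8. A minimal sketch of the pattern with a hypothetical module (not from this diff):

    import torch
    import torch.ao.quantization
    from torch import nn

    class TinyQuantizable(nn.Module):  # hypothetical example module
        def __init__(self) -> None:
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv = nn.Conv2d(3, 8, kernel_size=3)
            self.relu = nn.ReLU()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.quant(x)        # becomes a quantize op after convert
            x = self.relu(self.conv(x))
            return self.dequant(x)   # becomes a dequantize op after convert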
6 changes: 3 additions & 3 deletions torchvision/models/quantization/inception.py
@@ -36,7 +36,7 @@ def forward(self, x: Tensor) -> Tensor:
return x

def fuse_model(self) -> None:
-torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)
+torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInceptionA(inception_module.InceptionA):
@@ -144,8 +144,8 @@ def __init__(
QuantizableInceptionAux,
],
)
-self.quant = torch.quantization.QuantStub()
-self.dequant = torch.quantization.DeQuantStub()
+self.quant = torch.ao.quantization.QuantStub()
+self.dequant = torch.ao.quantization.DeQuantStub()

def forward(self, x: Tensor) -> InceptionOutputs:
x = self._transform_input(x)
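fuse_modules merges adjacent conv/bn/relu children into single fused modules before observers are inserted, which removes redundant quantization boundaries and folds batch norm into the convolution. A small illustration on a throwaway nn.Sequential, assuming the same API:

    import torch
    import torch.ao.quantization
    from torch import nn

    block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    block.eval()  # eager-mode fusion for inference requires eval mode
    torch.ao.quantization.fuse_modules(block, [["0", "1", "2"]], inplace=True)
    # block[0] is now a fused ConvReLU2d (bn folded in);
    # block[1] and block[2] are nn.Identity
    print(block)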
2 changes: 1 addition & 1 deletion torchvision/models/quantization/mobilenetv2.py
@@ -2,7 +2,7 @@

from torch import Tensor
from torch import nn
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls

from ..._internally_replaced_utils import load_state_dict_from_url
8 changes: 4 additions & 4 deletions torchvision/models/quantization/mobilenetv3.py
@@ -2,7 +2,7 @@

import torch
from torch import nn, Tensor
-from torch.quantization import QuantStub, DeQuantStub, fuse_modules
+from torch.ao.quantization import QuantStub, DeQuantStub, fuse_modules

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation, SqueezeExcitation
@@ -136,13 +136,13 @@ def _mobilenet_v3_model(
backend = "qnnpack"

model.fuse_model()
-model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
-torch.quantization.prepare_qat(model, inplace=True)
+model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
+torch.ao.quantization.prepare_qat(model, inplace=True)

if pretrained:
_load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)

-torch.quantization.convert(model, inplace=True)
+torch.ao.quantization.convert(model, inplace=True)
model.eval()
else:
if pretrained:
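One related detail not changed by this commit: the qconfig backend string must match the quantized kernel engine selected at inference time. A short note in code, using the standard PyTorch backends API and assuming a qnnpack-enabled build:

    import torch
    import torch.ao.quantization

    # "fbgemm" targets x86 servers; "qnnpack" targets ARM/mobile
    torch.backends.quantized.engine = "qnnpack"
    qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")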
12 changes: 6 additions & 6 deletions torchvision/models/quantization/resnet.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from torch import Tensor
-from torch.quantization import fuse_modules
+from torch.ao.quantization import fuse_modules
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls

from ..._internally_replaced_utils import load_state_dict_from_url
@@ -42,9 +42,9 @@ def forward(self, x: Tensor) -> Tensor:
return out

def fuse_model(self) -> None:
-torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
+torch.ao.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
if self.downsample:
-torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)


class QuantizableBottleneck(Bottleneck):
@@ -75,15 +75,15 @@ def forward(self, x: Tensor) -> Tensor:
def fuse_model(self) -> None:
fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
if self.downsample:
-torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
+torch.ao.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)


class QuantizableResNet(ResNet):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

-self.quant = torch.quantization.QuantStub()
-self.dequant = torch.quantization.DeQuantStub()
+self.quant = torch.ao.quantization.QuantStub()
+self.dequant = torch.ao.quantization.DeQuantStub()

def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
10 changes: 5 additions & 5 deletions torchvision/models/quantization/shufflenetv2.py
@@ -41,8 +41,8 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) # type: ignore[misc]
-self.quant = torch.quantization.QuantStub()
-self.dequant = torch.quantization.DeQuantStub()
+self.quant = torch.ao.quantization.QuantStub()
+self.dequant = torch.ao.quantization.DeQuantStub()

def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
@@ -60,12 +60,12 @@ def fuse_model(self) -> None:

for name, m in self._modules.items():
if name in ["conv1", "conv5"]:
-torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
+torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
for m in self.modules():
if type(m) is QuantizableInvertedResidual:
if len(m.branch1._modules.items()) > 0:
-torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
-torch.quantization.fuse_modules(
+torch.ao.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
+torch.ao.quantization.fuse_modules(
m.branch2,
[["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
inplace=True,
14 changes: 7 additions & 7 deletions torchvision/models/quantization/utils.py
@@ -24,19 +24,19 @@ def quantize_model(model: nn.Module, backend: str) -> None:
model.eval()
# Make sure that weight qconfig matches that of the serialized models
if backend == "fbgemm":
-model.qconfig = torch.quantization.QConfig( # type: ignore[assignment]
-activation=torch.quantization.default_observer,
-weight=torch.quantization.default_per_channel_weight_observer,
+model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment]
+activation=torch.ao.quantization.default_observer,
+weight=torch.ao.quantization.default_per_channel_weight_observer,
)
elif backend == "qnnpack":
-model.qconfig = torch.quantization.QConfig( # type: ignore[assignment]
-activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer
+model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment]
+activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer
)

# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
model.fuse_model() # type: ignore[operator]
-torch.quantization.prepare(model, inplace=True)
+torch.ao.quantization.prepare(model, inplace=True)
model(_dummy_input_data)
-torch.quantization.convert(model, inplace=True)
+torch.ao.quantization.convert(model, inplace=True)

return
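
A QConfig is just a pair of observer factories, one for activations and one for weights; the fbgemm branch above opts into per-channel weight observation, which the x86 kernels support. A sketch of constructing and inspecting one by hand, under the same API assumptions:

    import torch
    import torch.ao.quantization as tq

    # Mirrors the fbgemm branch: per-tensor activations, per-channel weights
    qconfig = tq.QConfig(
        activation=tq.default_observer,
        weight=tq.default_per_channel_weight_observer,
    )
    # Each field is a factory; calling it instantiates the observer module
    print(qconfig.activation(), qconfig.weight())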
