torch.ao migration: stubs.py phase 1 (pytorch#64861)
Summary:
Pull Request resolved: pytorch#64861

1. Move the file:
   ```
   hg mv caffe2/torch/quantization/stubs.py caffe2/torch/ao/quantization/
   ```
2. Create a new file in the old location and copy the imports.
3. Fix all call sites inside `torch` (a quick check of the resulting behavior is sketched below).
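
Steps 2 and 3 together mean that the old and the new import path keep resolving to the same class objects. A minimal check of that invariant (illustrative only; it assumes a PyTorch build that includes this change):

```
# The legacy module re-exports the migrated classes, so both paths
# must hand back the very same class object, not a copy.
from torch.quantization.stubs import QuantStub as legacy_QuantStub
from torch.ao.quantization.stubs import QuantStub as ao_QuantStub

assert legacy_QuantStub is ao_QuantStub
```
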
ghstack-source-id: 137885365

Test Plan: buck test mode/dev //caffe2/test:quantization

Reviewed By: jerryzh168

Differential Revision: D30879678

fbshipit-source-id: a2d24f25d01064212aca15e94e8c78240ba48953
supriyar authored and facebook-github-bot committed Sep 13, 2021
1 parent c08b249 commit 9d52651
Showing 5 changed files with 85 additions and 60 deletions.
11 changes: 11 additions & 0 deletions test/quantization/ao_migration/test_quantize.py
@@ -61,3 +61,14 @@ def test_function_import(self):
            'swap_module',
        ]
        self._test_function_import('quantize', function_list)

    def test_package_import_stubs(self):
        self._test_package_import('stubs')

    def test_function_import_stubs(self):
        function_list = [
            'QuantStub',
            'DeQuantStub',
            'QuantWrapper',
        ]
        self._test_function_import('stubs', function_list)
2 changes: 1 addition & 1 deletion torch/ao/quantization/quantize.py
@@ -17,7 +17,7 @@
    _get_special_act_post_process,
)

from torch.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.quantization.qconfig import (
    add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
58 changes: 58 additions & 0 deletions torch/ao/quantization/stubs.py
@@ -0,0 +1,58 @@

from torch import nn

class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,
    it will be swapped as `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """
    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        return x


class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,
    this will be swapped as `nnq.DeQuantize` in `convert`.
    """
    def __init__(self):
        super(DeQuantStub, self).__init__()

    def forward(self, x):
        return x


class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surround the call to module with call to quant and dequant
    modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules, before `convert` function `QuantStub` will just be observer,
    it observes the input tensor, after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize` which does actual quantization. Similarly
    for `DeQuantStub`.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super(QuantWrapper, self).__init__()
        qconfig = module.qconfig if hasattr(module, 'qconfig') else None
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub())
        self.add_module('module', module)
        self.train(module.training)

    def forward(self, X):
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)
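
As the docstrings above explain, `QuantStub` and `DeQuantStub` are placeholders: an observer and an identity before `convert`, swapped for `nnq.Quantize` and `nnq.DeQuantize` afterwards. A hedged sketch of the typical eager-mode post-training quantization flow that uses them (the model, sizes, and calibration data are illustrative; `get_default_qconfig`, `prepare`, and `convert` are the standard `torch.quantization` entry points of this era):

```
import torch
import torch.nn as nn
from torch.ao.quantization.stubs import QuantStub, DeQuantStub

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # observer now, nnq.Quantize after convert
        self.fc = nn.Linear(8, 4)
        self.dequant = DeQuantStub()  # identity now, nnq.DeQuantize after convert

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

model = SmallModel().eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepared = torch.quantization.prepare(model)      # inserts observers
prepared(torch.randn(4, 8))                       # calibration pass
quantized = torch.quantization.convert(prepared)  # swaps the stubs for real quant/dequant
```

`QuantWrapper` automates the same pattern by wrapping an arbitrary module with a `quant`/`dequant` pair, which is how the eager-mode quantization utilities insert the stubs for you.
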
2 changes: 1 addition & 1 deletion torch/quantization/quantization_mappings.py
@@ -15,7 +15,7 @@

from typing import Optional, Union, Dict, Set, Callable, Any

from .stubs import QuantStub, DeQuantStub
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
from .fake_quantize import (
    default_affine_fixed_qparams_fake_quant,
    default_symmetric_fixed_qparams_fake_quant,
72 changes: 14 additions & 58 deletions torch/quantization/stubs.py
@@ -1,58 +1,14 @@

from torch import nn

class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,
    it will be swapped as `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """
    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        return x


class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,
    this will be swapped as `nnq.DeQuantize` in `convert`.
    """
    def __init__(self):
        super(DeQuantStub, self).__init__()

    def forward(self, x):
        return x


class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surround the call to module with call to quant and dequant
    modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules, before `convert` function `QuantStub` will just be observer,
    it observes the input tensor, after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize` which does actual quantization. Similarly
    for `DeQuantStub`.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super(QuantWrapper, self).__init__()
        qconfig = module.qconfig if hasattr(module, 'qconfig') else None
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub())
        self.add_module('module', module)
        self.train(module.training)

    def forward(self, X):
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)
# flake8: noqa: F401
r"""
This file is in the process of migration to `torch/ao/quantization`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
`torch/ao/quantization/stubs.py`, while adding an import statement
here.
"""

from torch.ao.quantization.stubs import (
    QuantStub,
    DeQuantStub,
    QuantWrapper
)
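
Per the shim's docstring, any future addition goes into `torch/ao/quantization/stubs.py` first and is only re-exported from this legacy module. A hedged sketch of what that would look like for a hypothetical new entry named `NewStub` (not part of this commit; shown as two file fragments):

```
# torch/ao/quantization/stubs.py -- the real definition lives here
from torch import nn

class NewStub(nn.Module):  # hypothetical new entry
    def forward(self, x):
        return x

# torch/quantization/stubs.py -- the legacy module only re-exports it
from torch.ao.quantization.stubs import (
    QuantStub,
    DeQuantStub,
    QuantWrapper,
    NewStub,  # keeps the old import path working
)
```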
