forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
torch.ao migration: stubs.py phase 1 (pytorch#64861)
Summary: Pull Request resolved: pytorch#64861 1. move the file ``` hg mv caffe2/torch/quantization/stubs.py caffe2/torch/ao/quantization/ ``` 2. create a new file in the old location and copy the imports 3. fix all call sites inside `torch` ghstack-source-id: 137885365 Test Plan: buck test mode/dev //caffe2/test:quantization Reviewed By: jerryzh168 Differential Revision: D30879678 fbshipit-source-id: a2d24f25d01064212aca15e94e8c78240ba48953
- Loading branch information
1 parent
c08b249
commit 9d52651
Showing
5 changed files
with
85 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
|
||
from torch import nn | ||
|
||
class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,
    it will be swapped as `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """
    def __init__(self, qconfig=None):
        # Modern zero-argument super() — equivalent to the legacy
        # super(QuantStub, self) form in Python 3.
        super().__init__()
        # Attach qconfig only when a truthy one is given; leaving the
        # attribute unset lets qconfig propagation pull it from a parent.
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity pass-through until `convert` swaps in nnq.Quantize.
        return x
|
||
|
||
class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,
    this will be swapped as `nnq.DeQuantize` in `convert`.
    """
    def __init__(self):
        # Modern zero-argument super() — equivalent to the legacy
        # super(DeQuantStub, self) form in Python 3.
        super().__init__()

    def forward(self, x):
        # Identity pass-through until `convert` swaps in nnq.DeQuantize.
        return x
|
||
|
||
class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surround the call to module with call to quant and dequant
    modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules, before `convert` function `QuantStub` will just be observer,
    it observes the input tensor, after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize` which does actual quantization. Similarly
    for `DeQuantStub`.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # getattr with a default replaces the hasattr/ternary dance; None
        # means the wrapped module declared no qconfig of its own.
        qconfig = getattr(module, 'qconfig', None)
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub())
        self.add_module('module', module)
        # Mirror the wrapped module's train/eval mode onto the wrapper.
        self.train(module.training)

    def forward(self, X):
        # quant -> wrapped module -> dequant, matching the eventual
        # quantized execution order after `convert`.
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,14 @@ | ||
|
||
from torch import nn | ||
|
||
class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,
    it will be swapped as `nnq.Quantize` in `convert`.
    Args:
        qconfig: quantization configuration for the tensor,
        if qconfig is not provided, we will get qconfig from parent modules
    """

    def __init__(self, qconfig=None):
        super(QuantStub, self).__init__()
        # A falsy/absent qconfig leaves the attribute unset so it can be
        # inherited from an enclosing module later.
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        # Behaves as the identity; replaced by nnq.Quantize at convert time.
        return x
|
||
|
||
class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,
    this will be swapped as `nnq.DeQuantize` in `convert`.
    """

    def __init__(self):
        super(DeQuantStub, self).__init__()

    def forward(self, x):
        # Plain identity; replaced by nnq.DeQuantize at convert time.
        return x
|
||
|
||
class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surround the call to module with call to quant and dequant
    modules.
    This is used by the `quantization` utility functions to add the quant and
    dequant modules, before `convert` function `QuantStub` will just be observer,
    it observes the input tensor, after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize` which does actual quantization. Similarly
    for `DeQuantStub`.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super(QuantWrapper, self).__init__()
        # Take the wrapped module's qconfig when it declares one.
        if hasattr(module, 'qconfig'):
            qconfig = module.qconfig
        else:
            qconfig = None
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub())
        self.add_module('module', module)
        # Keep the wrapper's train/eval mode in sync with the wrapped module.
        self.train(module.training)

    def forward(self, X):
        # quant -> module -> dequant pipeline, written as one chained call.
        return self.dequant(self.module(self.quant(X)))
# flake8: noqa: F401 | ||
r""" | ||
This file is in the process of migration to `torch/ao/quantization`, and | ||
is kept here for compatibility while the migration process is ongoing. | ||
If you are adding a new entry/functionality, please, add it to the | ||
`torch/ao/quantization/stubs.py`, while adding an import statement | ||
here. | ||
""" | ||
|
||
from torch.ao.quantization.stubs import ( | ||
QuantStub, | ||
DeQuantStub, | ||
QuantWrapper | ||
) |