Add QuantizedParallelLinear & Update Uniform (PaddlePaddle#1694)
RachelXu7 authored Mar 22, 2023
1 parent 3b2ed2c commit e2b5c16
Showing 7 changed files with 743 additions and 5 deletions.
17 changes: 17 additions & 0 deletions paddleslim/quant/layers/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .parallel_linear import QuantizedColumnParallelLinear, QuantizedRowParallelLinear

__all__ = ["QuantizedColumnParallelLinear", "QuantizedRowParallelLinear"]
140 changes: 140 additions & 0 deletions paddleslim/quant/layers/parallel_linear.py
@@ -0,0 +1,140 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.nn import Layer
from paddle.nn import functional as F

from paddle.nn.quant.format import ConvertibleQuantedLayer


class QuantizedRowParallelLinear(ConvertibleQuantedLayer):
    """
    The computational logic of QuantizedRowParallelLinear is the same as RowParallelLinear.
    The only difference is that its inputs are all fake quantized.
    """

    def __init__(self, layer: Layer, q_config):
        super().__init__()
        # For Linear
        self.weight = layer.weight
        self.bias = layer.bias
        self._name = layer._name
        self.input_is_parallel = layer.input_is_parallel
        self.is_mp = layer.is_mp
        self.model_parallel_group = layer.model_parallel_group
        self.linear = layer.linear

        # For FakeQuant
        self.weight_quanter = None
        self.activation_quanter = None
        if q_config.weight is not None:
            self.weight_quanter = q_config.weight._instance(layer)
        if q_config.activation is not None:
            self.activation_quanter = q_config.activation._instance(layer)

    def forward(self, input):
        quant_input = input
        quant_weight = self.weight
        if self.activation_quanter is not None:
            quant_input = self.activation_quanter(input)
        if self.weight_quanter is not None:
            quant_weight = self.weight_quanter(self.weight)
        return self._linear_forward(quant_input, quant_weight)

    def _linear_forward(self, input, weight):
        if self.input_is_parallel or (not self.is_mp):
            input_parallel = input
        else:
            # split last dim
            input_parallel = paddle.distributed.collective._c_split(
                input, group=self.model_parallel_group)

        if self.is_mp:
            output_parallel = self.linear(
                input_parallel, weight, name=self._name)
            output_ = paddle.distributed.collective._mp_allreduce(
                output_parallel,
                group=self.model_parallel_group,
                use_calc_stream=True,
                use_model_parallel=True)
            output = output_ + self.bias if self.bias is not None else output_
        else:
            output = self.linear(
                input_parallel, weight, self.bias, name=self._name)
        return output

    def weights_to_quanters(self):
        return [('weight', 'weight_quanter')]

    def activation_quanters(self):
        return ['activation_quanter']


class QuantizedColumnParallelLinear(ConvertibleQuantedLayer):
    """
    The computational logic of QuantizedColumnParallelLinear is the same as ColumnParallelLinear.
    The only difference is that its inputs are all fake quantized.
    """

    def __init__(self, layer: Layer, q_config):
        super().__init__()
        # For Linear
        self.weight = layer.weight
        self.bias = layer.bias
        self._name = layer._name
        self.is_mp = layer.is_mp
        self.model_parallel_group = layer.model_parallel_group
        self.gather_output = layer.gather_output
        self.linear = layer.linear

        # For FakeQuant
        self.weight_quanter = None
        self.activation_quanter = None
        if q_config.weight is not None:
            self.weight_quanter = q_config.weight._instance(layer)
        if q_config.activation is not None:
            self.activation_quanter = q_config.activation._instance(layer)

    def forward(self, input):
        quant_input = input
        quant_weight = self.weight
        if self.activation_quanter is not None:
            quant_input = self.activation_quanter(input)
        if self.weight_quanter is not None:
            quant_weight = self.weight_quanter(self.weight)
        return self._linear_forward(quant_input, quant_weight)

    def _linear_forward(self, input, weight):
        if self.is_mp:
            input_parallel = paddle.distributed.collective._c_identity(
                input, group=self.model_parallel_group)
        else:
            input_parallel = input

        output_parallel = self.linear(
            input_parallel, weight, self.bias, name=self._name)

        if self.gather_output and self.is_mp:
            output = paddle.distributed.collective._c_concat(
                output_parallel, group=self.model_parallel_group)
        else:
            output = output_parallel
        return output

    def weights_to_quanters(self):
        return [('weight', 'weight_quanter')]

    def activation_quanters(self):
        return ['activation_quanter']
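For reference, the sketch below shows one way the new wrappers could be applied by hand; it is not part of the commit. It assumes a hybrid-parallel job has already been launched (for example with python -m paddle.distributed.launch on two devices) so that ColumnParallelLinear can build its model-parallel group, and it uses a SimpleNamespace as a stand-in for the per-layer quant config object that PaddleSlim's quantization flow would normally pass in: anything exposing weight and activation quanter factories with an _instance(layer) method, as the constructors above expect. FakeQuanterWithAbsMaxObserver is just one quanter factory that fits that contract.

from types import SimpleNamespace

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

from paddleslim.quant.layers import QuantizedColumnParallelLinear

# Initialize 2-way tensor (model) parallelism; requires a distributed launch.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# A column-parallel projection whose output stays sharded across ranks.
mp_linear = ColumnParallelLinear(1024, 4096, has_bias=True, gather_output=False)

# Stand-in for the per-layer config: each field is either None or a quanter
# factory whose _instance(layer) call builds the layer-specific fake quanter.
layer_cfg = SimpleNamespace(
    weight=FakeQuanterWithAbsMaxObserver(bit_length=8),
    activation=FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8))

q_mp_linear = QuantizedColumnParallelLinear(mp_linear, layer_cfg)

x = paddle.randn([4, 1024])
y = q_mp_linear(x)  # same math as ColumnParallelLinear, on fake-quantized input and weight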
5 changes: 3 additions & 2 deletions paddleslim/quant/observers/emd.py
@@ -63,6 +63,7 @@ def cal_min_max(self, inputs):
         abs_max_value = float(paddle.max(paddle.flatten(inputs)))
         abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
         s = 0.3
+        scale_emd = abs_max_value
         while s <= 1.0:
             scale = s * abs_max_value
             s += 0.02
@@ -78,8 +79,8 @@ def cal_min_max(self, inputs):
             emd_loss = float(emd_loss)
             if emd_loss <= self._calibration_loss:
                 self._calibration_loss = emd_loss
-
-        return 0, scale
+                scale_emd = scale
+        return 0, scale_emd

     def cal_thresholds(self):
         """ Compute thresholds for MAX function.
5 changes: 3 additions & 2 deletions paddleslim/quant/observers/mse.py
@@ -64,6 +64,7 @@ def cal_min_max(self, inputs):
         abs_max_value = float(paddle.max(paddle.abs(inputs.flatten())))
         abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
         s = 0.3
+        scale_mse = abs_max_value
         while s <= 1.0:
             scale = s * abs_max_value
             s += 0.02
@@ -75,8 +76,8 @@ def cal_min_max(self, inputs):
             mse_loss = float(((inputs - quant_dequant_var)**2).mean())
             if mse_loss <= self.calibration_loss:
                 self.calibration_loss = mse_loss
-
-        return 0, scale
+                scale_mse = scale
+        return 0, scale_mse

     def cal_thresholds(self):
         """ Compute thresholds for MAX function.
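Both observers get the same fix here: previously cal_min_max returned whatever candidate scale the sweep over s happened to end on, not the candidate that actually minimized the calibration loss. The new scale_emd / scale_mse variable tracks the best candidate (falling back to abs_max_value if no candidate improves on the stored loss) and is what gets returned. The snippet below is a simplified, self-contained illustration of that corrected search pattern, not the observer code itself; the quant-dequant details are assumptions made only to keep the example runnable.

import paddle

def search_abs_max_scale(x, bits=8, start=0.3, step=0.02):
    """Sweep s in [start, 1.0] and score the candidate scale s * abs_max by quant-dequant MSE."""
    abs_max = float(paddle.max(paddle.abs(paddle.flatten(x))))
    abs_max = 1e-8 if abs_max == 0.0 else abs_max
    bins = 2 ** (bits - 1) - 1
    best_loss = float("inf")
    best_scale = abs_max           # fallback, mirrors `scale_mse = abs_max_value`
    s = start
    while s <= 1.0:
        scale = s * abs_max
        s += step
        q = paddle.clip(paddle.round(x / scale * bins), -bins - 1, bins)
        x_qdq = q / bins * scale
        loss = float(((x - x_qdq) ** 2).mean())
        if loss <= best_loss:      # keep the best candidate, not the last one
            best_loss = loss
            best_scale = scale
    return best_scale

print(search_abs_max_scale(paddle.randn([1024])))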
2 changes: 1 addition & 1 deletion paddleslim/quant/observers/uniform.py
@@ -89,7 +89,7 @@ def cal_scales_zero_points(self) -> Tuple[float, float]:
         _max = max(self.max_value(), 0.)

         if self._symmetric:
-            self._scale = max(-_min, _max) / (float(_qmax - _qmin) / 2)
+            self._scale = max(-_min, _max)
             if self._sign:
                 self._zero_point = 0
             else:
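With this change, in the symmetric case _scale now holds the observed clipping range max(-_min, _max) directly instead of that range divided by (qmax - qmin) / 2. The quick check below just makes the numeric difference concrete; the signed 8-bit bounds are illustrative assumptions, not values read from uniform.py.

# Illustrative only: qmin/qmax for signed 8-bit are assumed, not taken from the diff.
abs_max = 6.4                                    # max(-_min, _max) observed on some tensor
qmin, qmax = -128, 127

old_scale = abs_max / (float(qmax - qmin) / 2)   # previous behaviour: 6.4 / 127.5, about 0.0502
new_scale = abs_max                              # new behaviour: the raw clipping range, 6.4

print(old_scale, new_scale)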