Skip to content

Commit

Permalink
Support Load scales from quant model (PaddlePaddle#1790)
Browse files Browse the repository at this point in the history
  • Loading branch information
RachelXu7 authored Aug 21, 2023
1 parent cafd985 commit e9011d3
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 3 deletions.
4 changes: 2 additions & 2 deletions paddleslim/quant/advanced/piecewise_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def search(self, layer_name, sampled_input, act_abs_max, weight):
else:
smooth_scale_out += final_smooth_scale

if cur_loss < global_loss:
global_loss = cur_loss
if calibration_loss < global_loss:
global_loss = calibration_loss
best_scale = smooth_scale_out
if self.search_piece:
print('Find Better K-Piece {}'.format(k_piece))
Expand Down
3 changes: 3 additions & 0 deletions paddleslim/quant/observers/abs_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def cal_min_max(self, inputs):
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
if self._scale is not None:
self._zero_point = 0
return
self._scale, self._zero_point = self.cal_scales_zero_points()

def min_value(self) -> float:
Expand Down
3 changes: 2 additions & 1 deletion paddleslim/quant/observers/abs_max_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def max_value(self) -> float:
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
self._scale = self._max
if self._scale is None:
self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale)

def scales(self):
Expand Down
3 changes: 3 additions & 0 deletions paddleslim/quant/observers/avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def cal_min_max(self, inputs):
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
if self._scale is not None:
self._zero_point = 0
return
self._min, self._max = self._avg_min, paddle.mean(
paddle.to_tensor(self._avg_list))
self._scale, self._zero_point = self.cal_scales_zero_points()
Expand Down
3 changes: 3 additions & 0 deletions paddleslim/quant/observers/emd.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def cal_min_max(self, inputs):
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
if self._scale is not None:
self._zero_point = 0
return
self._min, self._max = self._emd_min, self._emd_max
self._scale, self._zero_point = self.cal_scales_zero_points()

Expand Down
3 changes: 3 additions & 0 deletions paddleslim/quant/observers/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def cal_min_max(self, inputs):
def cal_thresholds(self):
""" Compute thresholds for MAX function.
"""
if self._scale is not None:
self._zero_point = 0
return
self._min, self._max = self._mse_min, self._mse_max
self._scale, self._zero_point = self.cal_scales_zero_points()

Expand Down

0 comments on commit e9011d3

Please sign in to comment.